mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
fix: add browser-level download monitor for CDP downloads bypassing Fetch (#5089)
This commit is contained in:
parent
cece22f21b
commit
e80ded3a97
2 changed files with 298 additions and 50 deletions
|
|
@ -21,13 +21,16 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
import urllib.request
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import structlog
|
||||
from playwright.async_api import CDPSession, Page
|
||||
from playwright.async_api import Browser, BrowserContext, CDPSession, Page
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
|
@ -161,7 +164,7 @@ def is_download_response(headers: dict[str, str], status_code: int, resource_typ
|
|||
return False
|
||||
|
||||
|
||||
def extract_filename(headers: dict[str, str], url: str, index: int) -> str:
|
||||
def extract_filename(headers: dict[str, str], url: str) -> str:
|
||||
"""
|
||||
Extract filename from response headers or URL.
|
||||
|
||||
|
|
@ -169,7 +172,7 @@ def extract_filename(headers: dict[str, str], url: str, index: int) -> str:
|
|||
1. Content-Disposition filename*= (RFC 5987, UTF-8)
|
||||
2. Content-Disposition filename=
|
||||
3. URL path last segment (if it has an extension)
|
||||
4. Fallback: download_{timestamp}_{index}
|
||||
4. Empty string (caller is responsible for fallback via _resolve_save_path)
|
||||
"""
|
||||
content_disposition = headers.get("content-disposition", "")
|
||||
|
||||
|
|
@ -192,7 +195,7 @@ def extract_filename(headers: dict[str, str], url: str, index: int) -> str:
|
|||
if "." in last_segment:
|
||||
return last_segment
|
||||
|
||||
return f"download_{int(time.time())}_{index}"
|
||||
return ""
|
||||
|
||||
|
||||
class CDPDownloadInterceptor:
|
||||
|
|
@ -223,6 +226,10 @@ class CDPDownloadInterceptor:
|
|||
# Track auth attempts per requestId to prevent infinite retry loops
|
||||
# when proxy credentials are rejected (407 → ProvideCredentials → 407 → …)
|
||||
self._auth_attempts: dict[str, int] = {}
|
||||
# Track URLs already downloaded (dedup between Fetch interception and browser download monitor)
|
||||
self._downloaded_urls: set[str] = set()
|
||||
self._browser_session: CDPSession | None = None
|
||||
self._browser_context: BrowserContext | None = None
|
||||
|
||||
def set_download_dir(self, download_dir: str) -> None:
|
||||
"""Set or update the download directory. Can be called after init when run_id becomes available."""
|
||||
|
|
@ -230,6 +237,31 @@ class CDPDownloadInterceptor:
|
|||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||
LOG.info("CDP download interceptor download dir set", download_dir=download_dir)
|
||||
|
||||
def _resolve_save_path(self, filename: str = "") -> tuple[Path, str]:
|
||||
"""Generate a unique save path under _output_dir.
|
||||
|
||||
Sanitizes the filename (path traversal prevention), falls back to a UUID-based
|
||||
name when empty, increments _download_index, and logs a warning if a file with
|
||||
the same name already exists. Returns (save_path, sanitized_filename).
|
||||
|
||||
Callers can pass a raw or empty filename — this method handles all normalization.
|
||||
"""
|
||||
assert self._output_dir is not None
|
||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._download_index += 1
|
||||
# Sanitize to prevent path traversal (e.g. "../../etc/evil")
|
||||
filename = Path(filename).name
|
||||
if not filename:
|
||||
filename = f"download_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
save_path = self._output_dir / filename
|
||||
# TODO: implement proper filename dedup (e.g., content hash or UUID suffix)
|
||||
if save_path.exists():
|
||||
LOG.warning("Download filename collision, overwriting", filename=filename, save_path=str(save_path))
|
||||
|
||||
return save_path, filename
|
||||
|
||||
async def enable_for_page(self, page: Page) -> None:
|
||||
"""Create a CDP session for the given page and enable Fetch interception.
|
||||
|
||||
|
|
@ -274,8 +306,157 @@ class CDPDownloadInterceptor:
|
|||
proxy_auth_enabled=has_proxy_auth,
|
||||
)
|
||||
|
||||
async def enable_browser_download_monitor(self, browser: Browser, browser_context: BrowserContext) -> None:
    """Monitor browser-initiated downloads and save them directly via HTTP.

    Many sites trigger downloads via mechanisms that bypass CDP Fetch
    (e.g., new tab for signed URL, <a download>, blob URLs). The browser's
    download manager handles these directly, so no page-level network
    request occurs.

    This method opens a browser-level CDP session, sets the download
    behavior to "deny" with events enabled, and schedules
    ``_handle_browser_download`` for every ``Browser.downloadWillBegin``
    event. The given BrowserContext is stored so the handler can issue the
    HTTP request through it.

    Calling this while a monitor session is already stored is a no-op
    (a warning is logged). If configuring the download behavior fails, the
    stored session/context are cleared and the session is detached before the
    error is re-raised, so a later call can retry.

    Args:
        browser: Browser used to open the browser-level CDP session.
        browser_context: Context whose request API is used for direct downloads.
    """
    if self._browser_session is not None:
        LOG.warning("Browser download monitor already enabled, skipping")
        return

    browser_session = await browser.new_browser_cdp_session()
    self._browser_session = browser_session
    self._browser_context = browser_context

    # Deny browser-native downloads — we download files ourselves via HTTP.
    # Using "deny" instead of "allowAndName" avoids needing a downloadPath, which is
    # critical for remote CDP browsers: downloadPath is interpreted on the browser's
    # filesystem, not the client's, so a local tempdir path would be invalid.
    # Browser.downloadWillBegin events still fire with eventsEnabled=True, giving us
    # the URL to download directly.
    try:
        await browser_session.send(
            "Browser.setDownloadBehavior",
            {"behavior": "deny", "eventsEnabled": True},
        )
    except Exception:
        # Roll back so the "already enabled" guard does not block a retry and
        # the half-configured session does not stay attached.
        self._browser_session = None
        self._browser_context = None
        try:
            await browser_session.detach()
        except Exception:
            LOG.debug("Failed to detach browser CDP session after setup error", exc_info=True)
        raise

    # Hold a strong reference to every in-flight handler task; the event loop
    # itself only keeps weak references to tasks.
    pending_tasks: set[asyncio.Future[Any]] = set()

    def _on_download_will_begin(event: dict[str, Any]) -> asyncio.Future[Any]:
        task = asyncio.ensure_future(self._handle_browser_download(event))
        pending_tasks.add(task)
        task.add_done_callback(pending_tasks.discard)
        return task

    browser_session.on("Browser.downloadWillBegin", _on_download_will_begin)
    LOG.info("Browser download monitor enabled")
|
||||
|
||||
async def _handle_browser_download(self, event: dict[str, Any]) -> None:
    """Handle a ``Browser.downloadWillBegin`` event.

    Reads ``url`` and ``suggestedFilename`` from the event and, for http(s)
    URLs not already recorded in ``_downloaded_urls``, delegates to
    ``_download_url_directly``. Empty URLs, blob: URLs and any other scheme
    are logged and skipped. Every exception is caught and logged so a failing
    download never propagates out of the event handler.

    Args:
        event: CDP event payload.
    """
    try:
        url = event.get("url", "")
        suggested_filename = event.get("suggestedFilename", "")
        LOG.info(
            "Browser download detected",
            url=url,
            suggested_filename=suggested_filename,
        )
        if not url:
            LOG.warning("Empty download URL, skipping")
            return

        # Skip if already downloaded via CDP Fetch interception.
        # Fetch always fires before Browser.downloadWillBegin (Fetch intercepts at
        # response stage, browser download manager fires after fulfillRequest), so
        # this check is purely one-directional: only _handle_download writes the set.
        if url in self._downloaded_urls:
            LOG.debug("URL already captured via Fetch, skipping direct download", url=url)
            return

        if url.startswith("blob:"):
            # blob: URLs are in-memory browser references — not fetchable over HTTP.
            # They are already handled by the Fetch path which intercepts the resolved blob.
            # TODO: handle the edge case where Fetch doesn't catch a blob download.
            LOG.warning(
                "blob: URL download not yet supported, skipping", url=url, suggested_filename=suggested_filename
            )
            return

        # Match the full scheme prefix: a bare "http" prefix test would also
        # accept unrelated schemes that merely start with those letters.
        if url.lower().startswith(("http://", "https://")):
            await self._download_url_directly(url, suggested_filename)
            return

        LOG.warning("Download URL scheme not supported, skipping", url=url)
    except Exception:
        LOG.warning("Error handling browser download event", exc_info=True)
|
||||
|
||||
async def _download_url_directly(self, url: str, suggested_filename: str) -> None:
    """Download a URL directly via HTTP and save it to the output directory.

    Tries the stored BrowserContext's request API first (the original author
    notes it shares the browser context cookies), then falls back to urllib
    when that request raises or returns a non-OK status. Payloads larger than
    ``MAX_FILE_SIZE_BYTES`` are discarded. Nothing is raised for download
    failures; they are logged and the method returns.

    Args:
        url: The http(s) URL to fetch.
        suggested_filename: Filename hint; sanitized by ``_resolve_save_path``.
    """
    if not self._output_dir:
        LOG.warning("No output_dir set, skipping direct download", url=url)
        return

    save_path, filename = self._resolve_save_path(suggested_filename)

    t0 = time.monotonic()
    data: bytes | None = None
    method = ""

    # Try Playwright's APIRequestContext which shares the BrowserContext's cookies.
    # We use the BrowserContext (not a Page) so this survives individual page closes.
    if self._browser_context:
        try:
            response = await self._browser_context.request.get(url)
            if response.ok:
                data = await response.body()
                method = "playwright_api"
            else:
                LOG.debug(
                    "Playwright APIRequestContext returned non-OK status, trying urllib",
                    url=url,
                    status=response.status,
                )
        except Exception as e:
            LOG.debug("Playwright APIRequestContext download failed, trying urllib", url=url, error=str(e))

    # Fallback: direct HTTP via urllib (works for pre-signed URLs)
    if data is None:
        # Per-socket-operation timeout so a stalled server cannot pin the
        # worker thread forever.
        urllib_timeout_s = 60.0
        # Read at most one byte past the limit: enough for the size check
        # below to reject the payload without buffering an arbitrarily
        # large response in memory.
        read_cap = int(MAX_FILE_SIZE_BYTES) + 1
        try:
            req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
            ssl_ctx = ssl.create_default_context()

            def _fetch() -> bytes:
                with urllib.request.urlopen(req, context=ssl_ctx, timeout=urllib_timeout_s) as resp:
                    return resp.read(read_cap)

            data = await asyncio.to_thread(_fetch)
            method = "urllib"
        except Exception as e:
            LOG.error("Direct HTTP download failed", url=url, error=str(e), exc_info=True)
            return

    if len(data) > MAX_FILE_SIZE_BYTES:
        # NOTE: on the urllib path the read is capped, so `size` here is a
        # lower bound on the real payload size.
        LOG.warning(
            "Direct download exceeds size limit, discarding",
            url=url,
            size=len(data),
            max_size=MAX_FILE_SIZE_BYTES,
        )
        return

    with open(save_path, "wb") as f:
        f.write(data)

    elapsed_ms = (time.monotonic() - t0) * 1000
    LOG.info(
        "CDP download saved (direct HTTP)",
        filename=filename,
        size=len(data),
        duration_ms=round(elapsed_ms, 1),
        save_path=str(save_path),
        download_index=self._download_index,
        method=method,
    )
|
||||
|
||||
async def disable(self) -> None:
|
||||
"""Disable Fetch interception on all CDP sessions."""
|
||||
"""Disable Fetch interception on all CDP sessions and clean up browser monitor."""
|
||||
session_count = len(self._cdp_sessions)
|
||||
for cdp_session in self._cdp_sessions:
|
||||
try:
|
||||
|
|
@ -283,6 +464,16 @@ class CDPDownloadInterceptor:
|
|||
except Exception:
|
||||
pass
|
||||
self._cdp_sessions.clear()
|
||||
|
||||
# Clean up browser-level download monitor session
|
||||
if self._browser_session:
|
||||
try:
|
||||
await self._browser_session.detach()
|
||||
except Exception:
|
||||
pass
|
||||
self._browser_session = None
|
||||
self._browser_context = None
|
||||
|
||||
self._enabled = False
|
||||
LOG.info(
|
||||
"CDP Fetch interception disabled",
|
||||
|
|
@ -387,6 +578,15 @@ class CDPDownloadInterceptor:
|
|||
response_headers = _parse_headers(raw_response_headers)
|
||||
resource_type = event.get("resourceType", "")
|
||||
|
||||
LOG.debug(
|
||||
"CDP Fetch response paused",
|
||||
url=url,
|
||||
resource_type=resource_type,
|
||||
status_code=response_status,
|
||||
content_type=response_headers.get("content-type", ""),
|
||||
content_disposition=response_headers.get("content-disposition", ""),
|
||||
)
|
||||
|
||||
if is_download_response(response_headers, response_status, resource_type):
|
||||
LOG.info(
|
||||
"CDP download response detected",
|
||||
|
|
@ -439,24 +639,10 @@ class CDPDownloadInterceptor:
|
|||
await self._continue_response(cdp_session, request_id)
|
||||
return
|
||||
|
||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._download_index += 1
|
||||
content_length = _parse_content_length(headers)
|
||||
content_type = headers.get("content-type", "").split(";")[0].strip()
|
||||
filename = extract_filename(headers, url, self._download_index)
|
||||
# Sanitize filename to prevent path traversal (e.g. "../../etc/evil")
|
||||
filename = Path(filename).name
|
||||
if not filename:
|
||||
filename = f"download_{int(time.time())}_{self._download_index}"
|
||||
save_path = self._output_dir / filename
|
||||
|
||||
# Deduplicate filename if a file with the same name already exists
|
||||
if save_path.exists():
|
||||
stem = Path(filename).stem
|
||||
suffix = Path(filename).suffix
|
||||
filename = f"{stem}_{self._download_index}{suffix}"
|
||||
save_path = self._output_dir / filename
|
||||
raw_filename = extract_filename(headers, url)
|
||||
save_path, filename = self._resolve_save_path(raw_filename)
|
||||
|
||||
LOG.info(
|
||||
"CDP download detected",
|
||||
|
|
@ -476,6 +662,13 @@ class CDPDownloadInterceptor:
|
|||
await self._continue_response(cdp_session, request_id)
|
||||
return
|
||||
|
||||
# Mark URL as handled BEFORE starting the (potentially slow) body extraction.
|
||||
# This prevents the browser download monitor (_handle_browser_download) from
|
||||
# racing to download the same URL while we're still streaming the body.
|
||||
# We intentionally do NOT remove the URL on failure — if Fetch extraction fails,
|
||||
# a direct HTTP re-download of the same URL would likely fail too.
|
||||
self._downloaded_urls.add(url)
|
||||
|
||||
t0 = time.monotonic()
|
||||
|
||||
try:
|
||||
|
|
@ -497,7 +690,6 @@ class CDPDownloadInterceptor:
|
|||
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
LOG.info(
|
||||
"CDP download saved",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Unit tests for CDPDownloadInterceptor pure functions and proxy auth handling."""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -160,76 +160,132 @@ class TestIsDownloadResponse:
|
|||
|
||||
|
||||
class TestExtractFilename:
|
||||
"""Tests for extract_filename()."""
|
||||
"""Tests for extract_filename().
|
||||
|
||||
extract_filename returns an empty string when no filename can be determined —
|
||||
the caller (_resolve_save_path) is responsible for generating a fallback name.
|
||||
"""
|
||||
|
||||
def test_rfc5987_filename_star(self) -> None:
|
||||
headers = {"content-disposition": "attachment; filename*=UTF-8''my%20report%282024%29.pdf"}
|
||||
result = extract_filename(headers, "https://example.com/download", 1)
|
||||
result = extract_filename(headers, "https://example.com/download")
|
||||
assert result == "my report(2024).pdf"
|
||||
|
||||
def test_regular_filename(self) -> None:
|
||||
headers = {"content-disposition": 'attachment; filename="report.csv"'}
|
||||
result = extract_filename(headers, "https://example.com/download", 1)
|
||||
result = extract_filename(headers, "https://example.com/download")
|
||||
assert result == "report.csv"
|
||||
|
||||
def test_unquoted_filename(self) -> None:
|
||||
headers = {"content-disposition": "attachment; filename=report.csv"}
|
||||
result = extract_filename(headers, "https://example.com/download", 1)
|
||||
result = extract_filename(headers, "https://example.com/download")
|
||||
assert result == "report.csv"
|
||||
|
||||
def test_filename_star_takes_priority(self) -> None:
|
||||
headers = {
|
||||
"content-disposition": "attachment; filename=\"fallback.csv\"; filename*=UTF-8''preferred.csv",
|
||||
}
|
||||
result = extract_filename(headers, "https://example.com/download", 1)
|
||||
result = extract_filename(headers, "https://example.com/download")
|
||||
assert result == "preferred.csv"
|
||||
|
||||
def test_url_path_fallback(self) -> None:
|
||||
headers: dict[str, str] = {}
|
||||
result = extract_filename(headers, "https://example.com/files/document.pdf", 1)
|
||||
result = extract_filename(headers, "https://example.com/files/document.pdf")
|
||||
assert result == "document.pdf"
|
||||
|
||||
def test_url_path_with_encoded_chars(self) -> None:
|
||||
headers: dict[str, str] = {}
|
||||
result = extract_filename(headers, "https://example.com/files/my%20report.xlsx", 1)
|
||||
result = extract_filename(headers, "https://example.com/files/my%20report.xlsx")
|
||||
assert result == "my report.xlsx"
|
||||
|
||||
def test_url_path_no_extension_uses_fallback(self) -> None:
|
||||
def test_url_path_no_extension_returns_empty(self) -> None:
|
||||
"""No extension in URL path and no Content-Disposition — returns empty string."""
|
||||
headers: dict[str, str] = {}
|
||||
result = extract_filename(headers, "https://example.com/download", 1)
|
||||
assert result.startswith("download_")
|
||||
result = extract_filename(headers, "https://example.com/download")
|
||||
assert result == ""
|
||||
|
||||
def test_fallback_format(self) -> None:
|
||||
headers: dict[str, str] = {}
|
||||
before = int(time.time())
|
||||
result = extract_filename(headers, "https://example.com/api/export", 42)
|
||||
after = int(time.time())
|
||||
# Should be download_{timestamp}_{index}
|
||||
parts = result.split("_")
|
||||
assert parts[0] == "download"
|
||||
assert before <= int(parts[1]) <= after
|
||||
assert parts[2] == "42"
|
||||
def test_no_headers_no_url_returns_empty(self) -> None:
|
||||
"""Completely empty inputs — returns empty string for _resolve_save_path to handle."""
|
||||
result = extract_filename({}, "https://example.com/api/export")
|
||||
assert result == ""
|
||||
|
||||
def test_empty_content_disposition(self) -> None:
|
||||
headers = {"content-disposition": ""}
|
||||
result = extract_filename(headers, "https://example.com/files/data.csv", 1)
|
||||
result = extract_filename(headers, "https://example.com/files/data.csv")
|
||||
assert result == "data.csv"
|
||||
|
||||
def test_content_disposition_inline(self) -> None:
|
||||
"""inline disposition without filename should fall back to URL."""
|
||||
headers = {"content-disposition": "inline"}
|
||||
result = extract_filename(headers, "https://example.com/files/report.pdf", 1)
|
||||
result = extract_filename(headers, "https://example.com/files/report.pdf")
|
||||
assert result == "report.pdf"
|
||||
|
||||
def test_path_traversal_stripped(self) -> None:
|
||||
"""Path traversal in filename should be sanitized to just the filename part."""
|
||||
def test_path_traversal_returned_raw(self) -> None:
|
||||
"""extract_filename returns raw name; sanitization is done in _resolve_save_path."""
|
||||
headers = {"content-disposition": 'attachment; filename="../../etc/cron.d/evil"'}
|
||||
result = extract_filename(headers, "https://example.com/download", 1)
|
||||
# extract_filename returns the raw name; sanitization is done in _handle_download.
|
||||
# But verify the raw output so tests document the behavior.
|
||||
result = extract_filename(headers, "https://example.com/download")
|
||||
assert result == "../../etc/cron.d/evil"
|
||||
|
||||
|
||||
class TestResolveSavePath:
    """Exercise CDPDownloadInterceptor._resolve_save_path() against a temp directory."""

    def _make_interceptor(self, tmp_path: Path) -> CDPDownloadInterceptor:
        # Every test gets an interceptor whose output dir is pytest's tmp_path.
        return CDPDownloadInterceptor(output_dir=str(tmp_path))

    def test_normal_filename(self, tmp_path: Path) -> None:
        """A plain filename is returned unchanged, placed under the output dir."""
        subject = self._make_interceptor(tmp_path)
        resolved, name = subject._resolve_save_path("report.pdf")
        assert name == "report.pdf"
        assert resolved == tmp_path / "report.pdf"

    def test_empty_filename_gets_uuid_fallback(self, tmp_path: Path) -> None:
        """An empty filename yields a generated download_{uuid} name."""
        subject = self._make_interceptor(tmp_path)
        resolved, name = subject._resolve_save_path("")
        prefix = "download_"
        assert name.startswith(prefix)
        assert len(name) > len(prefix)
        assert resolved == tmp_path / name

    def test_default_param_empty_string(self, tmp_path: Path) -> None:
        """Omitting the argument behaves like passing an empty filename."""
        subject = self._make_interceptor(tmp_path)
        name = subject._resolve_save_path()[1]
        assert name.startswith("download_")

    def test_path_traversal_sanitized(self, tmp_path: Path) -> None:
        """Directory components are dropped; only the final name survives."""
        subject = self._make_interceptor(tmp_path)
        resolved, name = subject._resolve_save_path("../../etc/cron.d/evil")
        assert name == "evil"
        assert resolved == tmp_path / "evil"

    def test_increments_download_index(self, tmp_path: Path) -> None:
        """Each call bumps _download_index by one, starting from zero."""
        subject = self._make_interceptor(tmp_path)
        assert subject._download_index == 0
        subject._resolve_save_path("a.pdf")
        assert subject._download_index == 1
        subject._resolve_save_path("b.pdf")
        assert subject._download_index == 2

    def test_collision_warning_logged(self, tmp_path: Path) -> None:
        """A pre-existing file with the same name does not change the returned path."""
        subject = self._make_interceptor(tmp_path)
        # Seed the directory with a file that collides with the requested name.
        existing = tmp_path / "report.pdf"
        existing.write_bytes(b"existing")
        resolved, name = subject._resolve_save_path("report.pdf")
        assert name == "report.pdf"
        assert resolved == tmp_path / "report.pdf"

    def test_creates_output_dir_if_missing(self, tmp_path: Path) -> None:
        """A missing (nested) output directory exists after resolving a path."""
        target_dir = tmp_path / "sub" / "dir"
        subject = CDPDownloadInterceptor(output_dir=str(target_dir))
        resolved, _ = subject._resolve_save_path("file.txt")
        assert target_dir.exists()
        assert resolved.parent == target_dir
|
||||
|
||||
|
||||
class TestCDPDownloadInterceptorProxyAuth:
|
||||
"""Tests for CDP proxy authentication handling (Fetch.authRequired + continueWithAuth)."""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue