fix: add browser-level download monitor for CDP downloads bypassing Fetch (#5089)

LawyZheng 2026-03-13 20:23:08 +08:00 committed by GitHub
parent cece22f21b
commit e80ded3a97
2 changed files with 298 additions and 50 deletions


@@ -21,13 +21,16 @@ from __future__ import annotations
import asyncio
import base64
import re
import ssl
import time
import urllib.request
import uuid
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse
import structlog
from playwright.async_api import CDPSession, Page
from playwright.async_api import Browser, BrowserContext, CDPSession, Page
LOG = structlog.get_logger()
@@ -161,7 +164,7 @@ def is_download_response(headers: dict[str, str], status_code: int, resource_typ
return False
def extract_filename(headers: dict[str, str], url: str, index: int) -> str:
def extract_filename(headers: dict[str, str], url: str) -> str:
"""
Extract filename from response headers or URL.
@@ -169,7 +172,7 @@ def extract_filename(headers: dict[str, str], url: str, index: int) -> str:
1. Content-Disposition filename*= (RFC 5987, UTF-8)
2. Content-Disposition filename=
3. URL path last segment (if it has an extension)
4. Fallback: download_{timestamp}_{index}
4. Empty string (caller is responsible for fallback via _resolve_save_path)
"""
content_disposition = headers.get("content-disposition", "")
@@ -192,7 +195,7 @@ def extract_filename(headers: dict[str, str], url: str, index: int) -> str:
if "." in last_segment:
return last_segment
return f"download_{int(time.time())}_{index}"
return ""
class CDPDownloadInterceptor:
@@ -223,6 +226,10 @@ class CDPDownloadInterceptor:
# Track auth attempts per requestId to prevent infinite retry loops
# when proxy credentials are rejected (407 → ProvideCredentials → 407 → …)
self._auth_attempts: dict[str, int] = {}
# Track URLs already downloaded (dedup between Fetch interception and browser download monitor)
self._downloaded_urls: set[str] = set()
self._browser_session: CDPSession | None = None
self._browser_context: BrowserContext | None = None
def set_download_dir(self, download_dir: str) -> None:
"""Set or update the download directory. Can be called after init when run_id becomes available."""
@@ -230,6 +237,31 @@ class CDPDownloadInterceptor:
self._output_dir.mkdir(parents=True, exist_ok=True)
LOG.info("CDP download interceptor download dir set", download_dir=download_dir)
def _resolve_save_path(self, filename: str = "") -> tuple[Path, str]:
"""Generate a unique save path under _output_dir.
Sanitizes the filename (path traversal prevention), falls back to a UUID-based
name when empty, increments _download_index, and logs a warning if a file with
the same name already exists. Returns (save_path, sanitized_filename).
Callers can pass a raw or empty filename; this method handles all normalization.
"""
assert self._output_dir is not None
self._output_dir.mkdir(parents=True, exist_ok=True)
self._download_index += 1
# Sanitize to prevent path traversal (e.g. "../../etc/evil")
filename = Path(filename).name
if not filename:
filename = f"download_{uuid.uuid4().hex[:8]}"
save_path = self._output_dir / filename
# TODO: implement proper filename dedup (e.g., content hash or UUID suffix)
if save_path.exists():
LOG.warning("Download filename collision, overwriting", filename=filename, save_path=str(save_path))
return save_path, filename
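One possible shape for the dedup that the TODO above mentions, sketched as a standalone helper rather than as part of this commit; the helper name and the UUID-suffix strategy are illustrative assumptions (Path and uuid are already imported at the top of this module):

def _dedup_filename(directory: Path, filename: str) -> str:
    """Sketch only: return a collision-free name by appending a short UUID suffix."""
    if not (directory / filename).exists():
        return filename
    stem, suffix = Path(filename).stem, Path(filename).suffix
    return f"{stem}_{uuid.uuid4().hex[:8]}{suffix}"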
async def enable_for_page(self, page: Page) -> None:
"""Create a CDP session for the given page and enable Fetch interception.
@@ -274,8 +306,157 @@ class CDPDownloadInterceptor:
proxy_auth_enabled=has_proxy_auth,
)
async def enable_browser_download_monitor(self, browser: Browser, browser_context: BrowserContext) -> None:
"""Monitor browser-initiated downloads and save them directly via HTTP.
Many sites trigger downloads via mechanisms that bypass CDP Fetch
(e.g., new tab for signed URL, <a download>, blob URLs). The browser's
download manager handles these directly; no page-level network request occurs.
This method uses Browser-level CDP events to detect such downloads,
then downloads the file directly via HTTP using the BrowserContext's
APIRequestContext (which shares cookies and outlives individual pages).
"""
if self._browser_session is not None:
LOG.warning("Browser download monitor already enabled, skipping")
return
browser_session = await browser.new_browser_cdp_session()
self._browser_session = browser_session
self._browser_context = browser_context
# Deny browser-native downloads — we download files ourselves via HTTP.
# Using "deny" instead of "allowAndName" avoids needing a downloadPath, which is
# critical for remote CDP browsers: downloadPath is interpreted on the browser's
# filesystem, not the client's, so a local tempdir path would be invalid.
# Browser.downloadWillBegin events still fire with eventsEnabled=True, giving us
# the URL to download directly.
await browser_session.send(
"Browser.setDownloadBehavior",
{"behavior": "deny", "eventsEnabled": True},
)
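# For contrast (illustrative only, not part of this commit): a setup that controls a
# *local* browser could instead let Chromium write the file itself with "allowAndName".
# That variant requires a downloadPath on the browser's own filesystem, which is exactly
# what a remote CDP browser cannot share with this client:
#
#   await browser_session.send(
#       "Browser.setDownloadBehavior",
#       {"behavior": "allowAndName", "downloadPath": "/path/on/the/browser/host", "eventsEnabled": True},
#   )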
browser_session.on(
"Browser.downloadWillBegin",
lambda event: asyncio.ensure_future(self._handle_browser_download(event)),
)
LOG.info("Browser download monitor enabled")
async def _handle_browser_download(self, event: dict[str, Any]) -> None:
"""Handle Browser.downloadWillBegin — download the file via HTTP or blob read."""
try:
url = event.get("url", "")
suggested_filename = event.get("suggestedFilename", "")
LOG.info(
"Browser download detected",
url=url,
suggested_filename=suggested_filename,
)
if not url:
LOG.warning("Empty download URL, skipping")
return
# Skip if already downloaded via CDP Fetch interception.
# Fetch always fires before Browser.downloadWillBegin (Fetch intercepts at
# response stage, browser download manager fires after fulfillRequest), so
# this check is purely one-directional: only _handle_download writes the set.
if url in self._downloaded_urls:
LOG.debug("URL already captured via Fetch, skipping direct download", url=url)
return
if url.startswith("blob:"):
# blob: URLs are in-memory browser references — not fetchable over HTTP.
# They are already handled by the Fetch path which intercepts the resolved blob.
# TODO: handle the edge case where Fetch doesn't catch a blob download.
LOG.warning(
"blob: URL download not yet supported, skipping", url=url, suggested_filename=suggested_filename
)
return
elif url.startswith("http"):
await self._download_url_directly(url, suggested_filename)
else:
LOG.warning("Download URL scheme not supported, skipping", url=url)
except Exception:
LOG.warning("Error handling browser download event", exc_info=True)
async def _download_url_directly(self, url: str, suggested_filename: str) -> None:
"""Download a URL directly via HTTP and save to the output directory.
Tries Playwright's APIRequestContext first (shares browser context cookies),
falls back to urllib for pre-signed URLs or when APIRequestContext fails.
"""
if not self._output_dir:
LOG.warning("No output_dir set, skipping direct download", url=url)
return
save_path, filename = self._resolve_save_path(suggested_filename)
t0 = time.monotonic()
data: bytes | None = None
method = ""
# Try Playwright's APIRequestContext which shares the BrowserContext's cookies.
# We use the BrowserContext (not a Page) so this survives individual page closes.
if self._browser_context:
try:
response = await self._browser_context.request.get(url)
if response.ok:
data = await response.body()
method = "playwright_api"
else:
LOG.debug(
"Playwright APIRequestContext returned non-OK status, trying urllib",
url=url,
status=response.status,
)
except Exception as e:
LOG.debug("Playwright APIRequestContext download failed, trying urllib", url=url, error=str(e))
# Fallback: direct HTTP via urllib (works for pre-signed URLs)
if data is None:
try:
req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
ssl_ctx = ssl.create_default_context()
def _fetch() -> bytes:
with urllib.request.urlopen(req, context=ssl_ctx) as resp:
return resp.read()
data = await asyncio.to_thread(_fetch)
method = "urllib"
except Exception as e:
LOG.error("Direct HTTP download failed", url=url, error=str(e), exc_info=True)
return
if data is None:
LOG.error("Download produced no data", url=url)
return
if len(data) > MAX_FILE_SIZE_BYTES:
LOG.warning(
"Direct download exceeds size limit, discarding",
url=url,
size=len(data),
max_size=MAX_FILE_SIZE_BYTES,
)
return
with open(save_path, "wb") as f:
f.write(data)
elapsed_ms = (time.monotonic() - t0) * 1000
LOG.info(
"CDP download saved (direct HTTP)",
filename=filename,
size=len(data),
duration_ms=round(elapsed_ms, 1),
save_path=str(save_path),
download_index=self._download_index,
method=method,
)
async def disable(self) -> None:
"""Disable Fetch interception on all CDP sessions."""
"""Disable Fetch interception on all CDP sessions and clean up browser monitor."""
session_count = len(self._cdp_sessions)
for cdp_session in self._cdp_sessions:
try:
@@ -283,6 +464,16 @@ class CDPDownloadInterceptor:
except Exception:
pass
self._cdp_sessions.clear()
# Clean up browser-level download monitor session
if self._browser_session:
try:
await self._browser_session.detach()
except Exception:
pass
self._browser_session = None
self._browser_context = None
self._enabled = False
LOG.info(
"CDP Fetch interception disabled",
@@ -387,6 +578,15 @@ class CDPDownloadInterceptor:
response_headers = _parse_headers(raw_response_headers)
resource_type = event.get("resourceType", "")
LOG.debug(
"CDP Fetch response paused",
url=url,
resource_type=resource_type,
status_code=response_status,
content_type=response_headers.get("content-type", ""),
content_disposition=response_headers.get("content-disposition", ""),
)
if is_download_response(response_headers, response_status, resource_type):
LOG.info(
"CDP download response detected",
@@ -439,24 +639,10 @@ class CDPDownloadInterceptor:
await self._continue_response(cdp_session, request_id)
return
self._output_dir.mkdir(parents=True, exist_ok=True)
self._download_index += 1
content_length = _parse_content_length(headers)
content_type = headers.get("content-type", "").split(";")[0].strip()
filename = extract_filename(headers, url, self._download_index)
# Sanitize filename to prevent path traversal (e.g. "../../etc/evil")
filename = Path(filename).name
if not filename:
filename = f"download_{int(time.time())}_{self._download_index}"
save_path = self._output_dir / filename
# Deduplicate filename if a file with the same name already exists
if save_path.exists():
stem = Path(filename).stem
suffix = Path(filename).suffix
filename = f"{stem}_{self._download_index}{suffix}"
save_path = self._output_dir / filename
raw_filename = extract_filename(headers, url)
save_path, filename = self._resolve_save_path(raw_filename)
LOG.info(
"CDP download detected",
@@ -476,6 +662,13 @@ class CDPDownloadInterceptor:
await self._continue_response(cdp_session, request_id)
return
# Mark URL as handled BEFORE starting the (potentially slow) body extraction.
# This prevents the browser download monitor (_handle_browser_download) from
# racing to download the same URL while we're still streaming the body.
# We intentionally do NOT remove the URL on failure — if Fetch extraction fails,
# a direct HTTP re-download of the same URL would likely fail too.
self._downloaded_urls.add(url)
t0 = time.monotonic()
try:
@@ -497,7 +690,6 @@ class CDPDownloadInterceptor:
with open(save_path, "wb") as f:
f.write(data)
elapsed_ms = (time.monotonic() - t0) * 1000
LOG.info(
"CDP download saved",


@@ -1,6 +1,6 @@
"""Unit tests for CDPDownloadInterceptor pure functions and proxy auth handling."""
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -160,76 +160,132 @@ class TestIsDownloadResponse:
class TestExtractFilename:
"""Tests for extract_filename()."""
"""Tests for extract_filename().
extract_filename returns an empty string when no filename can be determined;
the caller (_resolve_save_path) is responsible for generating a fallback name.
"""
def test_rfc5987_filename_star(self) -> None:
headers = {"content-disposition": "attachment; filename*=UTF-8''my%20report%282024%29.pdf"}
result = extract_filename(headers, "https://example.com/download", 1)
result = extract_filename(headers, "https://example.com/download")
assert result == "my report(2024).pdf"
def test_regular_filename(self) -> None:
headers = {"content-disposition": 'attachment; filename="report.csv"'}
result = extract_filename(headers, "https://example.com/download", 1)
result = extract_filename(headers, "https://example.com/download")
assert result == "report.csv"
def test_unquoted_filename(self) -> None:
headers = {"content-disposition": "attachment; filename=report.csv"}
result = extract_filename(headers, "https://example.com/download", 1)
result = extract_filename(headers, "https://example.com/download")
assert result == "report.csv"
def test_filename_star_takes_priority(self) -> None:
headers = {
"content-disposition": "attachment; filename=\"fallback.csv\"; filename*=UTF-8''preferred.csv",
}
result = extract_filename(headers, "https://example.com/download", 1)
result = extract_filename(headers, "https://example.com/download")
assert result == "preferred.csv"
def test_url_path_fallback(self) -> None:
headers: dict[str, str] = {}
result = extract_filename(headers, "https://example.com/files/document.pdf", 1)
result = extract_filename(headers, "https://example.com/files/document.pdf")
assert result == "document.pdf"
def test_url_path_with_encoded_chars(self) -> None:
headers: dict[str, str] = {}
result = extract_filename(headers, "https://example.com/files/my%20report.xlsx", 1)
result = extract_filename(headers, "https://example.com/files/my%20report.xlsx")
assert result == "my report.xlsx"
def test_url_path_no_extension_uses_fallback(self) -> None:
def test_url_path_no_extension_returns_empty(self) -> None:
"""No extension in URL path and no Content-Disposition — returns empty string."""
headers: dict[str, str] = {}
result = extract_filename(headers, "https://example.com/download", 1)
assert result.startswith("download_")
result = extract_filename(headers, "https://example.com/download")
assert result == ""
def test_fallback_format(self) -> None:
headers: dict[str, str] = {}
before = int(time.time())
result = extract_filename(headers, "https://example.com/api/export", 42)
after = int(time.time())
# Should be download_{timestamp}_{index}
parts = result.split("_")
assert parts[0] == "download"
assert before <= int(parts[1]) <= after
assert parts[2] == "42"
def test_no_headers_no_url_returns_empty(self) -> None:
"""Completely empty inputs — returns empty string for _resolve_save_path to handle."""
result = extract_filename({}, "https://example.com/api/export")
assert result == ""
def test_empty_content_disposition(self) -> None:
headers = {"content-disposition": ""}
result = extract_filename(headers, "https://example.com/files/data.csv", 1)
result = extract_filename(headers, "https://example.com/files/data.csv")
assert result == "data.csv"
def test_content_disposition_inline(self) -> None:
"""inline disposition without filename should fall back to URL."""
headers = {"content-disposition": "inline"}
result = extract_filename(headers, "https://example.com/files/report.pdf", 1)
result = extract_filename(headers, "https://example.com/files/report.pdf")
assert result == "report.pdf"
def test_path_traversal_stripped(self) -> None:
"""Path traversal in filename should be sanitized to just the filename part."""
def test_path_traversal_returned_raw(self) -> None:
"""extract_filename returns raw name; sanitization is done in _resolve_save_path."""
headers = {"content-disposition": 'attachment; filename="../../etc/cron.d/evil"'}
result = extract_filename(headers, "https://example.com/download", 1)
# extract_filename returns the raw name; sanitization is done in _handle_download.
# But verify the raw output so tests document the behavior.
result = extract_filename(headers, "https://example.com/download")
assert result == "../../etc/cron.d/evil"
class TestResolveSavePath:
"""Tests for CDPDownloadInterceptor._resolve_save_path()."""
def _make_interceptor(self, tmp_path: Path) -> CDPDownloadInterceptor:
interceptor = CDPDownloadInterceptor(output_dir=str(tmp_path))
return interceptor
def test_normal_filename(self, tmp_path: Path) -> None:
interceptor = self._make_interceptor(tmp_path)
save_path, filename = interceptor._resolve_save_path("report.pdf")
assert filename == "report.pdf"
assert save_path == tmp_path / "report.pdf"
def test_empty_filename_gets_uuid_fallback(self, tmp_path: Path) -> None:
"""Empty filename should generate a download_{uuid} fallback."""
interceptor = self._make_interceptor(tmp_path)
save_path, filename = interceptor._resolve_save_path("")
assert filename.startswith("download_")
assert len(filename) > len("download_")
assert save_path == tmp_path / filename
def test_default_param_empty_string(self, tmp_path: Path) -> None:
"""Calling without arguments should also trigger fallback."""
interceptor = self._make_interceptor(tmp_path)
_, filename = interceptor._resolve_save_path()
assert filename.startswith("download_")
def test_path_traversal_sanitized(self, tmp_path: Path) -> None:
"""Path traversal components should be stripped — only the final name is kept."""
interceptor = self._make_interceptor(tmp_path)
save_path, filename = interceptor._resolve_save_path("../../etc/cron.d/evil")
assert filename == "evil"
assert save_path == tmp_path / "evil"
def test_increments_download_index(self, tmp_path: Path) -> None:
interceptor = self._make_interceptor(tmp_path)
assert interceptor._download_index == 0
interceptor._resolve_save_path("a.pdf")
assert interceptor._download_index == 1
interceptor._resolve_save_path("b.pdf")
assert interceptor._download_index == 2
def test_collision_warning_logged(self, tmp_path: Path) -> None:
"""Existing file with the same name should warn but still return the path."""
interceptor = self._make_interceptor(tmp_path)
# Create a file that will collide
(tmp_path / "report.pdf").write_bytes(b"existing")
save_path, filename = interceptor._resolve_save_path("report.pdf")
assert filename == "report.pdf"
assert save_path == tmp_path / "report.pdf"
def test_creates_output_dir_if_missing(self, tmp_path: Path) -> None:
nested = tmp_path / "sub" / "dir"
interceptor = CDPDownloadInterceptor(output_dir=str(nested))
save_path, _ = interceptor._resolve_save_path("file.txt")
assert nested.exists()
assert save_path.parent == nested
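A sketch of how the dedup between Fetch interception and the browser download monitor could be exercised in this test module, assuming the project's async test runner; the pytest.mark.asyncio marker, the class name, and the test name are assumptions, not part of this commit (AsyncMock and Path are already imported above):

class TestHandleBrowserDownloadDedup:
    """Sketch: URLs already captured via Fetch interception are skipped by the monitor."""

    @pytest.mark.asyncio
    async def test_url_already_downloaded_is_skipped(self, tmp_path: Path) -> None:
        interceptor = CDPDownloadInterceptor(output_dir=str(tmp_path))
        # Simulate the Fetch path having already saved this URL.
        interceptor._downloaded_urls.add("https://example.com/report.pdf")
        interceptor._download_url_directly = AsyncMock()  # would only be awaited for new URLs
        await interceptor._handle_browser_download(
            {"url": "https://example.com/report.pdf", "suggestedFilename": "report.pdf"}
        )
        interceptor._download_url_directly.assert_not_awaited()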
class TestCDPDownloadInterceptorProxyAuth:
"""Tests for CDP proxy authentication handling (Fetch.authRequired + continueWithAuth)."""