mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-04-28 03:30:23 +00:00
Restore remote document fetch compatibility for public sites after the CVE-2026-4308 SSRF hardening. The initial security fix correctly blocked non-public destinations, but it also changed the outbound request fingerprint for `document_query` remote fetches. Some public sites, including https://nvd.nist.gov/vuln/detail/CVE-2026-4308, used for testing, responded with HTTP 403 to the default `requests` user agent even though they remained safe and publicly routable. This change keeps the centralized SSRF protections in place while restoring the previous request compatibility behavior by sending the configured `USER_AGENT` header, falling back to the prior `@mixedbread-ai/unstructured` value. What is fixed: - public URLs such as `https://nvd.nist.gov/vuln/detail/CVE-2026-4308` no longer fail with site-specific HTTP 403 due to request fingerprint changes introduced by the SSRF mitigation
202 lines
6.6 KiB
Python
202 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import ipaddress
|
|
import os
|
|
import socket
|
|
import struct
|
|
from urllib.parse import urljoin, urlparse
|
|
|
|
import requests
|
|
|
|
|
|
# URL schemes a remote fetch is allowed to use; anything else is rejected
# during URL validation.
SAFE_HTTP_SCHEMES = frozenset(("http", "https"))

# Default ``timeout`` handed to ``requests``; a two-item tuple is interpreted
# by requests as (connect timeout, read timeout) in seconds.
DEFAULT_FETCH_TIMEOUT = (3.05, 10.0)

# User-Agent sent when neither the USER_AGENT nor the user_agent environment
# variable provides a usable value.
DEFAULT_HTTP_USER_AGENT = "@mixedbread-ai/unstructured"
|
|
|
|
|
|
@dataclass(frozen=True)
class HttpFetchResult:
    """Immutable record describing one successfully fetched remote document."""

    # URL the body was finally read from (after any redirects were followed).
    url: str
    # Raw response body.
    content: bytes
    # Media type from the Content-Type header, lower-cased and without
    # parameters; None when the header is absent or empty.
    content_type: str | None
    # Text encoding reported by the HTTP client for the response, if any.
    encoding: str | None
|
|
|
|
|
|
class UnsafeUrlError(ValueError):
    """Signal that a remote URL was rejected as unsafe to fetch.

    Used for URLs that fail validation (unsupported scheme, missing host,
    embedded credentials, local hostnames) or whose hostname cannot be
    resolved or resolves to a non-public address. Being a ``ValueError``
    subclass, it is also caught by callers handling ``ValueError``.
    """
|
|
|
|
|
|
def _build_request_headers() -> dict[str, str]:
|
|
user_agent = (
|
|
os.getenv("USER_AGENT")
|
|
or os.getenv("user_agent")
|
|
or DEFAULT_HTTP_USER_AGENT
|
|
).strip()
|
|
return {"User-Agent": user_agent or DEFAULT_HTTP_USER_AGENT}
|
|
|
|
|
|
def _normalize_content_type(content_type: str | None) -> str | None:
|
|
if not content_type:
|
|
return None
|
|
return content_type.split(";", 1)[0].strip().lower() or None
|
|
|
|
|
|
def resolve_host_ips(hostname: str) -> tuple[ipaddress._BaseAddress, ...]:
    """Resolve *hostname* to its distinct IPv4/IPv6 addresses.

    Args:
        hostname: Host name or IP literal to look up.

    Returns:
        The resolved addresses, de-duplicated, in the order the resolver
        returned them.

    Raises:
        UnsafeUrlError: If the name cannot be resolved (including names the
            IDNA codec rejects) or resolution yields no addresses.
    """
    try:
        results = socket.getaddrinfo(
            hostname,
            None,
            family=socket.AF_UNSPEC,
            type=socket.SOCK_STREAM,
        )
    except (socket.gaierror, UnicodeError) as exc:
        # getaddrinfo IDNA-encodes the host first; labels that are empty or
        # too long raise UnicodeError rather than gaierror. Report both the
        # same way so callers only have to handle UnsafeUrlError.
        raise UnsafeUrlError(f"Unable to resolve hostname '{hostname}'") from exc

    # Keyed by the compressed textual form; dict ordering keeps first-seen order.
    unique: dict[str, ipaddress._BaseAddress] = {}
    for *_unused, sockaddr in results:
        # Drop any IPv6 zone/scope suffix ("fe80::1%eth0") before parsing.
        address = sockaddr[0].split("%", 1)[0]
        ip = ipaddress.ip_address(address)
        unique.setdefault(ip.compressed, ip)

    if not unique:
        raise UnsafeUrlError(f"Hostname '{hostname}' did not resolve to an IP address")

    return tuple(unique.values())
|
|
|
|
|
|
def validate_public_http_url(url: str) -> tuple[ipaddress._BaseAddress, ...]:
    """Check that *url* is an http(s) URL whose host resolves only to public IPs.

    Args:
        url: Absolute URL to validate.

    Returns:
        The addresses the URL's hostname resolved to.

    Raises:
        UnsafeUrlError: If the scheme is not http/https, the hostname is
            missing, credentials are embedded, the host is ``localhost`` or a
            ``.localhost`` subdomain, resolution fails, or any resolved
            address is not globally routable.
    """
    parts = urlparse(url)

    if parts.scheme not in SAFE_HTTP_SCHEMES:
        raise UnsafeUrlError("Only http:// and https:// URLs are supported")
    if not parts.hostname:
        raise UnsafeUrlError("URL hostname is required")
    if parts.username or parts.password:
        raise UnsafeUrlError("URLs with embedded credentials are not allowed")

    # Normalize away a trailing root dot so "localhost." is treated like "localhost".
    host = parts.hostname.rstrip(".").lower()
    is_local_name = host == "localhost" or host.endswith(".localhost")
    if is_local_name:
        raise UnsafeUrlError(f"Blocked local hostname '{host}'")

    resolved = resolve_host_ips(host)

    # A single non-global answer is enough to reject the whole host.
    non_public: list[str] = []
    for ip in resolved:
        if not ip.is_global:
            non_public.append(str(ip))

    if not non_public:
        return resolved

    listing = ", ".join(non_public)
    raise UnsafeUrlError(
        f"Blocked non-public address resolution for '{host}': {listing}"
    )
|
|
|
|
|
|
def _read_body_with_limit(response: requests.Response, max_bytes: int, url: str) -> bytes:
    """Stream *response* into memory, failing once it exceeds *max_bytes*."""
    declared = response.headers.get("Content-Length")
    if declared:
        try:
            declared_length = int(declared)
        except ValueError:
            # Malformed header: fall back to counting the streamed bytes.
            declared_length = None
        if declared_length is not None and declared_length > max_bytes:
            raise ValueError(
                f"Remote document exceeds max size {max_bytes} bytes: {url}"
            )

    body = bytearray()
    for chunk in response.iter_content(chunk_size=64 * 1024):
        if not chunk:
            continue
        body.extend(chunk)
        # Content-Length may be absent or wrong, so enforce the cap while reading.
        if len(body) > max_bytes:
            raise ValueError(
                f"Remote document exceeds max size {max_bytes} bytes: {url}"
            )
    return bytes(body)


def fetch_public_http_resource(
    url: str,
    *,
    max_bytes: int,
    max_redirects: int = 5,
    timeout: tuple[float, float] = DEFAULT_FETCH_TIMEOUT,
) -> HttpFetchResult:
    """Download *url* after verifying it, and every redirect hop, is public.

    Redirects are followed manually so each target is passed through
    ``validate_public_http_url`` before it is requested.

    Args:
        url: Absolute http(s) URL to fetch.
        max_bytes: Maximum allowed body size; larger responses are rejected.
        max_redirects: Maximum number of redirects to follow.
        timeout: ``timeout`` value passed to ``requests``.

    Returns:
        An ``HttpFetchResult`` with the final URL, body bytes, normalized
        content type, and the encoding reported by ``requests``.

    Raises:
        UnsafeUrlError: If the URL or a redirect target fails validation.
        ValueError: On HTTP status >= 400, a redirect without ``Location``,
            too many redirects, an oversized body, or any ``requests`` error.

    NOTE(review): validation resolves the hostname separately from the
    connection ``requests`` opens afterwards, so the two lookups are not
    guaranteed to return the same address.
    """
    current_url = url
    # The headers only depend on the environment, so build them once.
    headers = _build_request_headers()

    # Context manager closes the session's connection pool on every exit path.
    with requests.Session() as session:
        # Ignore proxy/netrc settings from the environment.
        session.trust_env = False

        for redirect_count in range(max_redirects + 1):
            validate_public_http_url(current_url)

            try:
                with session.get(
                    current_url,
                    stream=True,
                    allow_redirects=False,
                    headers=headers,
                    timeout=timeout,
                ) as response:
                    if 300 <= response.status_code < 400:
                        location = response.headers.get("Location")
                        if not location:
                            raise ValueError(
                                f"Remote URL redirect is missing a Location header: {current_url}"
                            )
                        if redirect_count >= max_redirects:
                            raise ValueError(
                                f"Remote URL exceeded redirect limit ({max_redirects}): {url}"
                            )
                        # Location may be relative; resolve against the current URL.
                        current_url = urljoin(current_url, location)
                        continue

                    if response.status_code >= 400:
                        raise ValueError(
                            f"Remote URL returned HTTP {response.status_code}: {current_url}"
                        )

                    content = _read_body_with_limit(response, max_bytes, current_url)
                    return HttpFetchResult(
                        url=current_url,
                        content=content,
                        content_type=_normalize_content_type(
                            response.headers.get("Content-Type")
                        ),
                        encoding=response.encoding,
                    )
            except requests.RequestException as exc:
                raise ValueError(
                    f"Remote document fetch failed for {current_url}: {exc}"
                ) from exc

    raise ValueError(f"Remote URL exceeded redirect limit ({max_redirects}): {url}")
|
|
|
|
|
|
def is_loopback_address(address: str) -> bool:
    """Check whether *address* resolves to a loopback interface.

    IP literals are classified directly: any IPv4 address in 127.0.0.0/8,
    the IPv6 loopback in any valid spelling, and IPv4-mapped IPv6 forms of
    an IPv4 loopback address all count as loopback. Anything else is treated
    as a hostname and resolved for both IPv4 and IPv6.

    Args:
        address: IP literal or hostname.

    Returns:
        ``True`` for a loopback literal, or for a hostname that resolves in
        both address families with every result being loopback; ``False``
        otherwise, including when resolution fails for either family.
    """

    def _is_loopback_ip(text: str) -> bool:
        # Strip any IPv6 zone suffix ("::1%lo") before parsing.
        ip = ipaddress.ip_address(text.split("%", 1)[0])
        # Judge "::ffff:a.b.c.d" by the embedded IPv4 address; done
        # explicitly so the result does not depend on the Python version.
        mapped = getattr(ip, "ipv4_mapped", None)
        if mapped is not None:
            ip = mapped
        return ip.is_loopback

    try:
        return _is_loopback_ip(address)
    except ValueError:
        pass  # Not an IP literal; fall through to hostname resolution.

    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            results = socket.getaddrinfo(address, None, family, socket.SOCK_STREAM)
        except (socket.gaierror, UnicodeError):
            # Unresolvable (or not encodable as a hostname): not loopback.
            return False
        for _fam, _type, _proto, _canonname, sockaddr in results:
            if not _is_loopback_ip(sockaddr[0]):
                return False
    return True
|