free-claude-code/api/web_tools/egress.py
Alishahryar1 f3a7528d49
Some checks are pending
CI / checks (push) Waiting to run
Major refactor: API, providers, messaging, and Anthropic protocol
Consolidates the incremental refactor work into a single change set: modular web tools (api/web_tools), native Anthropic request building and SSE block policy, OpenAI conversion and error handling, provider transports and rate limiting, messaging handler and tree queue, safe logging, smoke tests, and broad test coverage.
2026-04-26 03:01:14 -07:00

99 lines
3.2 KiB
Python

"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
@dataclass(frozen=True, slots=True)
class WebFetchEgressPolicy:
"""Egress rules for user-influenced web_fetch URLs."""
allow_private_network_targets: bool
allowed_schemes: frozenset[str]
class WebFetchEgressViolation(ValueError):
    """SSRF-guard rejection of a web_fetch URL.

    Raised when a URL fails an egress policy check: a disallowed scheme, a
    missing host, a blocked host name, a non-public address, or a host that
    could not be resolved. Because it subclasses ``ValueError``, generic
    ``except ValueError`` handlers also catch it.
    """
def _port_for_url(parsed) -> int:
if parsed.port is not None:
return parsed.port
return 443 if (parsed.scheme or "").lower() == "https" else 80
def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
    """Resolve ``host``/``port`` into TCP stream ``getaddrinfo`` rows.

    Resolution is restricted to ``SOCK_STREAM`` / ``IPPROTO_TCP`` results.

    Raises:
        WebFetchEgressViolation: if ``socket.getaddrinfo`` raises ``OSError``
            (including ``socket.gaierror``); the original error is chained
            as the cause.
    """
    lookup_options = {"type": socket.SOCK_STREAM, "proto": socket.IPPROTO_TCP}
    try:
        rows = socket.getaddrinfo(host, port, **lookup_options)
    except OSError as lookup_error:
        message = f"Could not resolve host {host!r}: {lookup_error}"
        raise WebFetchEgressViolation(message) from lookup_error
    return rows
def get_validated_stream_addrinfos_for_egress(
    url: str, policy: WebFetchEgressPolicy
) -> list[tuple]:
    """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.

    Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
    server cannot rebind to a disallowed address between resolution and the TCP
    connect (used by :func:`api.web_tools.outbound._run_web_fetch`).

    Checks, in order:

    1. The lowercased URL scheme must be in ``policy.allowed_schemes``.
    2. The URL must name a host, and any explicit port must be valid.
    3. Unless ``policy.allow_private_network_targets`` is set:
       ``localhost`` / ``*.localhost`` / ``*.local`` names are rejected (a
       trailing DNS root dot is ignored), a literal IP host must be globally
       routable, and every resolved row must be an IPv4/IPv6 address that is
       globally routable. Rows that cannot be classified are rejected rather
       than skipped (fail closed), because every returned row is a potential
       connect target.

    Args:
        url: User-influenced URL to validate.
        policy: Egress rules to enforce.

    Returns:
        The ``socket.getaddrinfo`` rows ``(family, type, proto, canonname,
        sockaddr)`` for the URL's host and port, restricted to TCP streams.

    Raises:
        WebFetchEgressViolation: if any check fails or the host cannot be
            resolved.
    """
    parsed = urlparse(url)

    scheme = (parsed.scheme or "").lower()
    if scheme not in policy.allowed_schemes:
        raise WebFetchEgressViolation(
            f"URL scheme {scheme!r} is not allowed for web_fetch"
        )

    host = parsed.hostname
    if not host:
        raise WebFetchEgressViolation("web_fetch URL must include a host")

    try:
        port = _port_for_url(parsed)
    except ValueError as exc:
        # urllib raises a bare ValueError for non-numeric / out-of-range ports;
        # report it as a policy violation so callers handle it uniformly.
        raise WebFetchEgressViolation(
            f"web_fetch URL has an invalid port: {exc}"
        ) from exc

    if policy.allow_private_network_targets:
        return _stream_getaddrinfo_or_raise(host, port)

    # A fully-qualified name may end in the DNS root dot ("localhost.");
    # strip it so the name checks below cannot be sidestepped that way.
    host_name = host.lower().rstrip(".")
    if host_name == "localhost" or host_name.endswith(".localhost"):
        raise WebFetchEgressViolation(
            "localhost targets are not allowed for web_fetch"
        )
    if host_name.endswith(".local"):
        raise WebFetchEgressViolation(
            ".local hostnames are not allowed for web_fetch"
        )

    # Reject a literal non-public IP up front with a specific message.
    try:
        literal_ip = ipaddress.ip_address(host)
    except ValueError:
        literal_ip = None
    if literal_ip is not None and not literal_ip.is_global:
        raise WebFetchEgressViolation(
            f"Non-public IP host {host!r} is not allowed for web_fetch"
        )

    infos = _stream_getaddrinfo_or_raise(host, port)
    if not infos:
        raise WebFetchEgressViolation(
            f"Could not resolve host {host!r}: no addresses returned"
        )

    # Every row may be used for the pinned connect, so every row must pass.
    for family, *_, sockaddr in infos:
        if family not in (socket.AF_INET, socket.AF_INET6):
            raise WebFetchEgressViolation(
                f"Host {host!r} resolves to an unsupported address family "
                f"({family!r})"
            )
        address = sockaddr[0]
        try:
            resolved = ipaddress.ip_address(address)
        except ValueError as exc:
            # Previously skipped (fail open); an unparseable address is
            # unverifiable, so refuse it.
            raise WebFetchEgressViolation(
                f"Host {host!r} resolves to an unparseable address ({address!r})"
            ) from exc
        if not resolved.is_global:
            raise WebFetchEgressViolation(
                f"Host {host!r} resolves to a non-public address ({resolved})"
            )
    return infos
def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
    """Check that ``url`` may be fetched under ``policy``; return ``None`` if so.

    Runs the same scheme, host, and resolved-address validation as
    :func:`get_validated_stream_addrinfos_for_egress` and throws away the
    resolved ``getaddrinfo`` rows. Call that function directly when the rows
    are needed for connection pinning.

    Raises:
        ValueError: normally ``WebFetchEgressViolation``, when any egress
            check fails.
    """
    # Only the validation side effect matters here; the rows are discarded.
    _ = get_validated_stream_addrinfos_for_egress(url, policy)