Fix SSRF in document_query remote fetching (CVE-2026-4308)

Address CVE-2026-4308 in the document_query tool remote-fetch path.

The issue was originally reported by @YLChen-007.

This change replaces ad hoc remote document fetching with a centralized
safe-fetch flow that validates every remote URL before a network request
is made and before any fetched content reaches a parser. It blocks
localhost and non-public IPv4/IPv6 targets, validates every redirect
hop, disables implicit trust of proxy environment settings for this
path, and enforces a strict remote document size cap.

It also removes direct third-party loader access to attacker-controlled
URLs: remote content is now prefetched through the safe fetch flow, and
the HTML, text, PDF, image, and unstructured document handlers parse
only the locally held bytes or temporary files produced by that fetch.

Refs:
- CVE-2026-4308
- Report by @YLChen-007
This commit is contained in:
Alessandro 2026-04-12 02:00:01 +02:00
parent 071194281c
commit 6397acc092
2 changed files with 279 additions and 73 deletions

View file

@ -1,7 +1,6 @@
import mimetypes
import os
import asyncio
import aiohttp
import json
from helpers.vector_db import VectorDB
@ -13,8 +12,6 @@ from urllib.parse import urlparse
from typing import Callable, Sequence, List, Optional, Tuple
from datetime import datetime
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.pdf import PyMuPDFLoader
from langchain_community.document_transformers import MarkdownifyTransformer
from langchain_community.document_loaders.parsers.images import TesseractBlobParser
@ -24,12 +21,14 @@ from langchain.schema import SystemMessage, HumanMessage
from helpers.print_style import PrintStyle
from helpers import files, errors
from helpers.network import HttpFetchResult, fetch_public_http_resource
from agent import Agent
from langchain.text_splitter import RecursiveCharacterTextSplitter
DEFAULT_SEARCH_THRESHOLD = 0.5
MAX_REMOTE_DOCUMENT_BYTES = 50 * 1024 * 1024
class DocumentQueryStore:
@ -450,45 +449,19 @@ class DocumentQueryHelper:
scheme = url.scheme or "file"
mimetype, encoding = mimetypes.guess_type(document_uri)
mimetype = mimetype or "application/octet-stream"
remote_resource: HttpFetchResult | None = None
if mimetype == "application/octet-stream":
if url.scheme in ["http", "https"]:
response: aiohttp.ClientResponse | None = None
retries = 0
last_error = ""
while not response and retries < 3:
try:
async with aiohttp.ClientSession() as session:
response = await session.head(
document_uri,
timeout=aiohttp.ClientTimeout(total=2.0),
allow_redirects=True,
)
if response.status > 399:
raise Exception(response.status)
break
except Exception as e:
await asyncio.sleep(1)
last_error = str(e)
retries += 1
await self.agent.handle_intervention()
if not response:
raise ValueError(
f"DocumentQueryHelper::document_get_content: Document fetch error: {document_uri} ({last_error})"
)
mimetype = response.headers["content-type"]
if "content-length" in response.headers:
content_length = (
float(response.headers["content-length"]) / 1024 / 1024
) # MB
if content_length > 50.0:
raise ValueError(
f"Document content length exceeds max. 50MB: {content_length} MB ({document_uri})"
)
if mimetype and "; charset=" in mimetype:
mimetype = mimetype.split("; charset=")[0]
if scheme in ["http", "https"]:
remote_resource = await asyncio.to_thread(
fetch_public_http_resource,
document_uri,
max_bytes=MAX_REMOTE_DOCUMENT_BYTES,
)
if (
remote_resource.content_type
and remote_resource.content_type != "application/octet-stream"
):
mimetype = remote_resource.content_type
if scheme == "file":
try:
@ -515,16 +488,24 @@ class DocumentQueryHelper:
if not exists:
await self.agent.handle_intervention()
if mimetype.startswith("image/"):
document_content = self.handle_image_document(document_uri, scheme)
document_content = self.handle_image_document(
document_uri, scheme, remote_resource=remote_resource
)
elif mimetype == "text/html":
document_content = self.handle_html_document(document_uri, scheme)
document_content = self.handle_html_document(
document_uri, scheme, remote_resource=remote_resource
)
elif mimetype.startswith("text/") or mimetype == "application/json":
document_content = self.handle_text_document(document_uri, scheme)
document_content = self.handle_text_document(
document_uri, scheme, remote_resource=remote_resource
)
elif mimetype == "application/pdf":
document_content = self.handle_pdf_document(document_uri, scheme)
document_content = self.handle_pdf_document(
document_uri, scheme, remote_resource=remote_resource
)
else:
document_content = self.handle_unstructured_document(
document_uri, scheme
document_uri, scheme, remote_resource=remote_resource
)
if add_to_db:
self.progress_callback(f"Indexing document")
@ -550,13 +531,53 @@ class DocumentQueryHelper:
)
return document_content
def handle_image_document(self, document: str, scheme: str) -> str:
return self.handle_unstructured_document(document, scheme)
@staticmethod
def _decode_remote_text(remote_resource: HttpFetchResult) -> str:
    """Decode fetched bytes using the response-reported charset.

    Falls back to UTF-8 with replacement characters when the charset is
    unknown to Python or does not match the actual bytes.
    """
    charset = remote_resource.encoding or "utf-8"
    try:
        return remote_resource.content.decode(charset)
    except (LookupError, UnicodeDecodeError):
        # Bad or mismatched charset: never fail the whole document fetch
        # over decoding — substitute undecodable bytes instead.
        return remote_resource.content.decode("utf-8", errors="replace")
def handle_html_document(self, document: str, scheme: str) -> str:
@staticmethod
def _get_temp_file_suffix(
    document: str, remote_resource: HttpFetchResult | None = None
) -> str:
    """Choose a suffix for a temp copy of *document*.

    Preference order: the extension found in the URL path (or raw string),
    then an extension guessed from the fetched Content-Type, then ".bin".
    """
    candidate = urlparse(document).path or document
    extension = os.path.splitext(candidate)[1]
    if extension:
        return extension
    content_type = remote_resource.content_type if remote_resource else None
    if content_type:
        guessed = mimetypes.guess_extension(content_type, strict=False)
        if guessed:
            return guessed
    # No usable hint anywhere — fall back to a generic binary suffix.
    return ".bin"
def handle_image_document(
    self,
    document: str,
    scheme: str,
    remote_resource: HttpFetchResult | None = None,
) -> str:
    """Extract text from an image document.

    Images share the unstructured-document pipeline; any prefetched
    remote bytes are forwarded to it unchanged.
    """
    return self.handle_unstructured_document(
        document, scheme, remote_resource=remote_resource
    )
def handle_html_document(
self,
document: str,
scheme: str,
remote_resource: HttpFetchResult | None = None,
) -> str:
if scheme in ["http", "https"]:
loader = AsyncHtmlLoader(web_path=document)
parts: list[Document] = loader.load()
if remote_resource is None:
raise ValueError("Missing prefetched remote HTML content")
html_content = self._decode_remote_text(remote_resource)
parts = [Document(page_content=html_content, metadata={"source": document})]
elif scheme == "file":
# Use RFC file operations instead of TextLoader
file_content_bytes = files.read_file_bin(document)
@ -573,10 +594,19 @@ class DocumentQueryHelper:
]
)
def handle_text_document(self, document: str, scheme: str) -> str:
def handle_text_document(
self,
document: str,
scheme: str,
remote_resource: HttpFetchResult | None = None,
) -> str:
if scheme in ["http", "https"]:
loader = AsyncHtmlLoader(web_path=document)
elements: list[Document] = loader.load()
if remote_resource is None:
raise ValueError("Missing prefetched remote text content")
file_content = self._decode_remote_text(remote_resource)
elements = [
Document(page_content=file_content, metadata={"source": document})
]
elif scheme == "file":
# Use RFC file operations instead of TextLoader
file_content_bytes = files.read_file_bin(document)
@ -590,7 +620,12 @@ class DocumentQueryHelper:
return "\n".join([element.page_content for element in elements])
def handle_pdf_document(self, document: str, scheme: str) -> str:
def handle_pdf_document(
self,
document: str,
scheme: str,
remote_resource: HttpFetchResult | None = None,
) -> str:
temp_file_path = ""
if scheme == "file":
# Use RFC file operations to read the PDF file as binary
@ -602,17 +637,12 @@ class DocumentQueryHelper:
temp_file.write(file_content_bytes)
temp_file_path = temp_file.name
elif scheme in ["http", "https"]:
# download the file from the web url to a temporary file using python libraries for downloading
import requests
import tempfile
if remote_resource is None:
raise ValueError("Missing prefetched remote PDF content")
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
response = requests.get(document, timeout=10.0)
if response.status_code != 200:
raise ValueError(
f"DocumentQueryHelper::handle_pdf_document: Failed to download PDF from {document}: {response.status_code}"
)
temp_file.write(response.content)
temp_file.write(remote_resource.content)
temp_file_path = temp_file.name
else:
raise ValueError(f"Unsupported scheme: {scheme}")
@ -658,18 +688,35 @@ class DocumentQueryHelper:
finally:
os.unlink(temp_file_path)
def handle_unstructured_document(self, document: str, scheme: str) -> str:
def handle_unstructured_document(
self,
document: str,
scheme: str,
remote_resource: HttpFetchResult | None = None,
) -> str:
elements: list[Document] = []
if scheme in ["http", "https"]:
# loader = UnstructuredURLLoader(urls=[document], mode="single")
loader = UnstructuredLoader(
web_url=document,
mode="single",
partition_via_api=False,
# chunking_strategy="by_page",
strategy="hi_res",
)
elements = loader.load()
if remote_resource is None:
raise ValueError("Missing prefetched remote document content")
import tempfile
temp_file_path = ""
suffix = self._get_temp_file_suffix(document, remote_resource)
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
temp_file.write(remote_resource.content)
temp_file_path = temp_file.name
try:
loader = UnstructuredLoader(
file_path=temp_file_path,
mode="single",
partition_via_api=False,
# chunking_strategy="by_page",
strategy="hi_res",
)
elements = loader.load()
finally:
os.unlink(temp_file_path)
elif scheme == "file":
# Use RFC file operations to read the file as binary
file_content_bytes = files.read_file_bin(document)

View file

@ -1,5 +1,164 @@
from __future__ import annotations
from dataclasses import dataclass
import ipaddress
import socket
import struct
from urllib.parse import urljoin, urlparse
import requests
SAFE_HTTP_SCHEMES = frozenset({"http", "https"})
DEFAULT_FETCH_TIMEOUT = (3.05, 10.0)
@dataclass(frozen=True)
class HttpFetchResult:
    """Immutable result of a validated remote HTTP(S) fetch."""

    # Final URL after any validated redirects were followed.
    url: str
    # Raw response body; capped at the caller-supplied max_bytes.
    content: bytes
    # Media type from the Content-Type header, lowercased with
    # parameters stripped; None when the header was absent or empty.
    content_type: str | None
    # Text encoding reported by the HTTP response, if any.
    encoding: str | None
class UnsafeUrlError(ValueError):
    """Raised when a remote URL resolves to a non-public destination.

    Also raised for disallowed schemes, missing or unresolvable
    hostnames, embedded credentials, and blocked local hostnames
    (see validate_public_http_url / resolve_host_ips).
    """
def _normalize_content_type(content_type: str | None) -> str | None:
if not content_type:
return None
return content_type.split(";", 1)[0].strip().lower() or None
def resolve_host_ips(hostname: str) -> tuple[ipaddress._BaseAddress, ...]:
    """Resolve *hostname* to the unique IP addresses it maps to.

    Raises UnsafeUrlError when resolution fails or yields no addresses.
    """
    try:
        records = socket.getaddrinfo(
            hostname,
            None,
            family=socket.AF_UNSPEC,
            type=socket.SOCK_STREAM,
        )
    except socket.gaierror as exc:
        raise UnsafeUrlError(f"Unable to resolve hostname '{hostname}'") from exc
    # Deduplicate while preserving resolution order (dicts keep insertion order).
    unique_ips: dict[str, ipaddress._BaseAddress] = {}
    for record in records:
        raw_address = record[4][0]
        # Drop any IPv6 zone/scope id, e.g. "fe80::1%eth0" -> "fe80::1".
        raw_address = raw_address.split("%", 1)[0]
        parsed_ip = ipaddress.ip_address(raw_address)
        unique_ips.setdefault(parsed_ip.compressed, parsed_ip)
    if not unique_ips:
        raise UnsafeUrlError(f"Hostname '{hostname}' did not resolve to an IP address")
    return tuple(unique_ips.values())
def validate_public_http_url(url: str) -> tuple[ipaddress._BaseAddress, ...]:
    """Validate that *url* is a safe, public http(s) URL; return its IPs.

    Raises UnsafeUrlError on a disallowed scheme, a missing hostname,
    embedded credentials, a localhost name, or any DNS resolution to a
    non-global (private/loopback/link-local/reserved) address.
    """
    parts = urlparse(url)
    if parts.scheme not in SAFE_HTTP_SCHEMES:
        raise UnsafeUrlError("Only http:// and https:// URLs are supported")
    if not parts.hostname:
        raise UnsafeUrlError("URL hostname is required")
    if parts.username or parts.password:
        raise UnsafeUrlError("URLs with embedded credentials are not allowed")
    # Normalize: trailing dots are DNS-root markers, hostnames are case-insensitive.
    hostname = parts.hostname.rstrip(".").lower()
    if hostname == "localhost" or hostname.endswith(".localhost"):
        raise UnsafeUrlError(f"Blocked local hostname '{hostname}'")
    resolved = resolve_host_ips(hostname)
    non_public = [str(ip) for ip in resolved if not ip.is_global]
    if non_public:
        raise UnsafeUrlError(
            f"Blocked non-public address resolution for '{hostname}': {', '.join(non_public)}"
        )
    return resolved
def fetch_public_http_resource(
    url: str,
    *,
    max_bytes: int,
    max_redirects: int = 5,
    timeout: tuple[float, float] = DEFAULT_FETCH_TIMEOUT,
) -> HttpFetchResult:
    """Fetch a remote HTTP(S) resource with SSRF protections and a size cap.

    Every hop — the initial URL and each redirect target — is validated
    with validate_public_http_url before a request is sent. Proxy-related
    environment variables are ignored (trust_env=False). The body is
    streamed and the download aborts as soon as it exceeds *max_bytes*.

    Args:
        url: Remote URL to fetch.
        max_bytes: Hard cap on the downloaded body size, in bytes.
        max_redirects: Maximum number of redirect hops to follow.
        timeout: (connect, read) timeout pair passed to requests.

    Returns:
        HttpFetchResult with the final URL, body bytes, normalized
        content type, and the response-reported encoding.

    Raises:
        UnsafeUrlError: If any hop targets a non-public destination.
        ValueError: On HTTP errors, oversize bodies, missing redirect
            Location headers, redirect-limit overruns, or transport failures.
    """
    current_url = url
    # Bug fix: the session was previously never closed; close it
    # deterministically so connections are not leaked.
    with requests.Session() as session:
        # Never pick up proxy settings from the environment on this path.
        session.trust_env = False
        for redirect_count in range(max_redirects + 1):
            # Validate every hop, not just the first URL.
            # NOTE(review): validation and the actual request perform
            # separate DNS lookups, so DNS rebinding between them is not
            # prevented — confirm this is acceptable or pin the resolved IP.
            validate_public_http_url(current_url)
            try:
                with session.get(
                    current_url,
                    stream=True,
                    allow_redirects=False,  # redirects are followed (and validated) manually
                    timeout=timeout,
                ) as response:
                    if 300 <= response.status_code < 400:
                        location = response.headers.get("Location")
                        if not location:
                            raise ValueError(
                                f"Remote URL redirect is missing a Location header: {current_url}"
                            )
                        if redirect_count >= max_redirects:
                            raise ValueError(
                                f"Remote URL exceeded redirect limit ({max_redirects}): {url}"
                            )
                        current_url = urljoin(current_url, location)
                        continue
                    if response.status_code >= 400:
                        raise ValueError(
                            f"Remote URL returned HTTP {response.status_code}: {current_url}"
                        )
                    # Reject early when the server declares an oversize body;
                    # an unparseable Content-Length is simply ignored.
                    content_length = response.headers.get("Content-Length")
                    if content_length:
                        try:
                            declared_length = int(content_length)
                        except ValueError:
                            declared_length = None
                        if declared_length is not None and declared_length > max_bytes:
                            raise ValueError(
                                f"Remote document exceeds max size {max_bytes} bytes: {current_url}"
                            )
                    # Stream the body, enforcing the cap even when
                    # Content-Length is absent or lies.
                    body = bytearray()
                    for chunk in response.iter_content(chunk_size=64 * 1024):
                        if not chunk:
                            continue
                        body.extend(chunk)
                        if len(body) > max_bytes:
                            raise ValueError(
                                f"Remote document exceeds max size {max_bytes} bytes: {current_url}"
                            )
                    return HttpFetchResult(
                        url=current_url,
                        content=bytes(body),
                        content_type=_normalize_content_type(
                            response.headers.get("Content-Type")
                        ),
                        encoding=response.encoding,
                    )
            except requests.RequestException as exc:
                raise ValueError(
                    f"Remote document fetch failed for {current_url}: {exc}"
                ) from exc
    # Defensive guard: the loop always returns or raises before exhausting.
    raise ValueError(f"Remote URL exceeded redirect limit ({max_redirects}): {url}")
def is_loopback_address(address: str) -> bool: