mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-04-28 03:30:23 +00:00
Fix SSRF in document_query remote fetching (CVE-2026-4308)
Address CVE-2026-4308 in the document_query tool remote-fetch path. The issue was originally reported by @YLChen-007. This change replaces ad hoc remote document fetching with a centralized safe-fetch flow that validates remote URLs before any network request is made and before the response is used for parsing. It blocks localhost and non-public IPv4/IPv6 targets, validates every redirect hop, disables implicit trust of proxy environment settings for this path, and enforces a strict remote-document size cap. It also removes direct third-party loader access to attacker-controlled URLs by prefetching remote content first and then parsing only trusted local bytes or temp files for HTML, text, PDF, image, and unstructured document handling. Refs: - CVE-2026-4308 - Report by @YLChen-007
This commit is contained in:
parent
071194281c
commit
6397acc092
2 changed files with 279 additions and 73 deletions
|
|
@ -1,7 +1,6 @@
|
|||
import mimetypes
|
||||
import os
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
from helpers.vector_db import VectorDB
|
||||
|
|
@ -13,8 +12,6 @@ from urllib.parse import urlparse
|
|||
from typing import Callable, Sequence, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_community.document_loaders import AsyncHtmlLoader
|
||||
from langchain_community.document_loaders.text import TextLoader
|
||||
from langchain_community.document_loaders.pdf import PyMuPDFLoader
|
||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||
from langchain_community.document_loaders.parsers.images import TesseractBlobParser
|
||||
|
|
@ -24,12 +21,14 @@ from langchain.schema import SystemMessage, HumanMessage
|
|||
|
||||
from helpers.print_style import PrintStyle
|
||||
from helpers import files, errors
|
||||
from helpers.network import HttpFetchResult, fetch_public_http_resource
|
||||
from agent import Agent
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
DEFAULT_SEARCH_THRESHOLD = 0.5
|
||||
MAX_REMOTE_DOCUMENT_BYTES = 50 * 1024 * 1024
|
||||
|
||||
|
||||
class DocumentQueryStore:
|
||||
|
|
@ -450,45 +449,19 @@ class DocumentQueryHelper:
|
|||
scheme = url.scheme or "file"
|
||||
mimetype, encoding = mimetypes.guess_type(document_uri)
|
||||
mimetype = mimetype or "application/octet-stream"
|
||||
remote_resource: HttpFetchResult | None = None
|
||||
|
||||
if mimetype == "application/octet-stream":
|
||||
if url.scheme in ["http", "https"]:
|
||||
response: aiohttp.ClientResponse | None = None
|
||||
retries = 0
|
||||
last_error = ""
|
||||
while not response and retries < 3:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
response = await session.head(
|
||||
document_uri,
|
||||
timeout=aiohttp.ClientTimeout(total=2.0),
|
||||
allow_redirects=True,
|
||||
)
|
||||
if response.status > 399:
|
||||
raise Exception(response.status)
|
||||
break
|
||||
except Exception as e:
|
||||
await asyncio.sleep(1)
|
||||
last_error = str(e)
|
||||
retries += 1
|
||||
await self.agent.handle_intervention()
|
||||
|
||||
if not response:
|
||||
raise ValueError(
|
||||
f"DocumentQueryHelper::document_get_content: Document fetch error: {document_uri} ({last_error})"
|
||||
)
|
||||
|
||||
mimetype = response.headers["content-type"]
|
||||
if "content-length" in response.headers:
|
||||
content_length = (
|
||||
float(response.headers["content-length"]) / 1024 / 1024
|
||||
) # MB
|
||||
if content_length > 50.0:
|
||||
raise ValueError(
|
||||
f"Document content length exceeds max. 50MB: {content_length} MB ({document_uri})"
|
||||
)
|
||||
if mimetype and "; charset=" in mimetype:
|
||||
mimetype = mimetype.split("; charset=")[0]
|
||||
if scheme in ["http", "https"]:
|
||||
remote_resource = await asyncio.to_thread(
|
||||
fetch_public_http_resource,
|
||||
document_uri,
|
||||
max_bytes=MAX_REMOTE_DOCUMENT_BYTES,
|
||||
)
|
||||
if (
|
||||
remote_resource.content_type
|
||||
and remote_resource.content_type != "application/octet-stream"
|
||||
):
|
||||
mimetype = remote_resource.content_type
|
||||
|
||||
if scheme == "file":
|
||||
try:
|
||||
|
|
@ -515,16 +488,24 @@ class DocumentQueryHelper:
|
|||
if not exists:
|
||||
await self.agent.handle_intervention()
|
||||
if mimetype.startswith("image/"):
|
||||
document_content = self.handle_image_document(document_uri, scheme)
|
||||
document_content = self.handle_image_document(
|
||||
document_uri, scheme, remote_resource=remote_resource
|
||||
)
|
||||
elif mimetype == "text/html":
|
||||
document_content = self.handle_html_document(document_uri, scheme)
|
||||
document_content = self.handle_html_document(
|
||||
document_uri, scheme, remote_resource=remote_resource
|
||||
)
|
||||
elif mimetype.startswith("text/") or mimetype == "application/json":
|
||||
document_content = self.handle_text_document(document_uri, scheme)
|
||||
document_content = self.handle_text_document(
|
||||
document_uri, scheme, remote_resource=remote_resource
|
||||
)
|
||||
elif mimetype == "application/pdf":
|
||||
document_content = self.handle_pdf_document(document_uri, scheme)
|
||||
document_content = self.handle_pdf_document(
|
||||
document_uri, scheme, remote_resource=remote_resource
|
||||
)
|
||||
else:
|
||||
document_content = self.handle_unstructured_document(
|
||||
document_uri, scheme
|
||||
document_uri, scheme, remote_resource=remote_resource
|
||||
)
|
||||
if add_to_db:
|
||||
self.progress_callback(f"Indexing document")
|
||||
|
|
@ -550,13 +531,53 @@ class DocumentQueryHelper:
|
|||
)
|
||||
return document_content
|
||||
|
||||
def handle_image_document(self, document: str, scheme: str) -> str:
|
||||
return self.handle_unstructured_document(document, scheme)
|
||||
@staticmethod
|
||||
def _decode_remote_text(remote_resource: HttpFetchResult) -> str:
|
||||
encoding = remote_resource.encoding or "utf-8"
|
||||
try:
|
||||
return remote_resource.content.decode(encoding)
|
||||
except (LookupError, UnicodeDecodeError):
|
||||
return remote_resource.content.decode("utf-8", errors="replace")
|
||||
|
||||
def handle_html_document(self, document: str, scheme: str) -> str:
|
||||
@staticmethod
|
||||
def _get_temp_file_suffix(
|
||||
document: str, remote_resource: HttpFetchResult | None = None
|
||||
) -> str:
|
||||
parsed = urlparse(document)
|
||||
_stem, ext = os.path.splitext(parsed.path or document)
|
||||
if ext:
|
||||
return ext
|
||||
|
||||
if remote_resource and remote_resource.content_type:
|
||||
guessed_ext = mimetypes.guess_extension(
|
||||
remote_resource.content_type, strict=False
|
||||
)
|
||||
if guessed_ext:
|
||||
return guessed_ext
|
||||
|
||||
return ".bin"
|
||||
|
||||
def handle_image_document(
    self,
    document: str,
    scheme: str,
    remote_resource: "HttpFetchResult | None" = None,
) -> str:
    """Extract text content from an image document.

    Images have no parseable markup of their own, so the work is
    delegated wholesale to the generic unstructured-document handler.
    # NOTE(review): assumes handle_unstructured_document performs OCR
    # for image files — confirm against its implementation.
    """
    return self.handle_unstructured_document(
        document, scheme, remote_resource=remote_resource
    )
|
||||
|
||||
def handle_html_document(
|
||||
self,
|
||||
document: str,
|
||||
scheme: str,
|
||||
remote_resource: HttpFetchResult | None = None,
|
||||
) -> str:
|
||||
if scheme in ["http", "https"]:
|
||||
loader = AsyncHtmlLoader(web_path=document)
|
||||
parts: list[Document] = loader.load()
|
||||
if remote_resource is None:
|
||||
raise ValueError("Missing prefetched remote HTML content")
|
||||
html_content = self._decode_remote_text(remote_resource)
|
||||
parts = [Document(page_content=html_content, metadata={"source": document})]
|
||||
elif scheme == "file":
|
||||
# Use RFC file operations instead of TextLoader
|
||||
file_content_bytes = files.read_file_bin(document)
|
||||
|
|
@ -573,10 +594,19 @@ class DocumentQueryHelper:
|
|||
]
|
||||
)
|
||||
|
||||
def handle_text_document(self, document: str, scheme: str) -> str:
|
||||
def handle_text_document(
|
||||
self,
|
||||
document: str,
|
||||
scheme: str,
|
||||
remote_resource: HttpFetchResult | None = None,
|
||||
) -> str:
|
||||
if scheme in ["http", "https"]:
|
||||
loader = AsyncHtmlLoader(web_path=document)
|
||||
elements: list[Document] = loader.load()
|
||||
if remote_resource is None:
|
||||
raise ValueError("Missing prefetched remote text content")
|
||||
file_content = self._decode_remote_text(remote_resource)
|
||||
elements = [
|
||||
Document(page_content=file_content, metadata={"source": document})
|
||||
]
|
||||
elif scheme == "file":
|
||||
# Use RFC file operations instead of TextLoader
|
||||
file_content_bytes = files.read_file_bin(document)
|
||||
|
|
@ -590,7 +620,12 @@ class DocumentQueryHelper:
|
|||
|
||||
return "\n".join([element.page_content for element in elements])
|
||||
|
||||
def handle_pdf_document(self, document: str, scheme: str) -> str:
|
||||
def handle_pdf_document(
|
||||
self,
|
||||
document: str,
|
||||
scheme: str,
|
||||
remote_resource: HttpFetchResult | None = None,
|
||||
) -> str:
|
||||
temp_file_path = ""
|
||||
if scheme == "file":
|
||||
# Use RFC file operations to read the PDF file as binary
|
||||
|
|
@ -602,17 +637,12 @@ class DocumentQueryHelper:
|
|||
temp_file.write(file_content_bytes)
|
||||
temp_file_path = temp_file.name
|
||||
elif scheme in ["http", "https"]:
|
||||
# download the file from the web url to a temporary file using python libraries for downloading
|
||||
import requests
|
||||
import tempfile
|
||||
|
||||
if remote_resource is None:
|
||||
raise ValueError("Missing prefetched remote PDF content")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
|
||||
response = requests.get(document, timeout=10.0)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"DocumentQueryHelper::handle_pdf_document: Failed to download PDF from {document}: {response.status_code}"
|
||||
)
|
||||
temp_file.write(response.content)
|
||||
temp_file.write(remote_resource.content)
|
||||
temp_file_path = temp_file.name
|
||||
else:
|
||||
raise ValueError(f"Unsupported scheme: {scheme}")
|
||||
|
|
@ -658,18 +688,35 @@ class DocumentQueryHelper:
|
|||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
def handle_unstructured_document(self, document: str, scheme: str) -> str:
|
||||
def handle_unstructured_document(
|
||||
self,
|
||||
document: str,
|
||||
scheme: str,
|
||||
remote_resource: HttpFetchResult | None = None,
|
||||
) -> str:
|
||||
elements: list[Document] = []
|
||||
if scheme in ["http", "https"]:
|
||||
# loader = UnstructuredURLLoader(urls=[document], mode="single")
|
||||
loader = UnstructuredLoader(
|
||||
web_url=document,
|
||||
mode="single",
|
||||
partition_via_api=False,
|
||||
# chunking_strategy="by_page",
|
||||
strategy="hi_res",
|
||||
)
|
||||
elements = loader.load()
|
||||
if remote_resource is None:
|
||||
raise ValueError("Missing prefetched remote document content")
|
||||
import tempfile
|
||||
|
||||
temp_file_path = ""
|
||||
suffix = self._get_temp_file_suffix(document, remote_resource)
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
temp_file.write(remote_resource.content)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
loader = UnstructuredLoader(
|
||||
file_path=temp_file_path,
|
||||
mode="single",
|
||||
partition_via_api=False,
|
||||
# chunking_strategy="by_page",
|
||||
strategy="hi_res",
|
||||
)
|
||||
elements = loader.load()
|
||||
finally:
|
||||
os.unlink(temp_file_path)
|
||||
elif scheme == "file":
|
||||
# Use RFC file operations to read the file as binary
|
||||
file_content_bytes = files.read_file_bin(document)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,164 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
SAFE_HTTP_SCHEMES = frozenset({"http", "https"})
|
||||
DEFAULT_FETCH_TIMEOUT = (3.05, 10.0)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class HttpFetchResult:
    """Immutable result of a validated remote HTTP fetch."""

    # Final URL after any redirects were followed.
    url: str
    # Raw response body (size-capped by the fetcher).
    content: bytes
    # Normalized media type (e.g. "text/html"), or None if not reported.
    content_type: str | None
    # Server-declared text encoding, if any.
    encoding: str | None
|
||||
|
||||
|
||||
class UnsafeUrlError(ValueError):
    """Raised when a remote URL resolves to a non-public destination.

    Subclasses ValueError, so callers that already catch ValueError on
    fetch failures also see these SSRF rejections.
    """
|
||||
|
||||
|
||||
def _normalize_content_type(content_type: str | None) -> str | None:
|
||||
if not content_type:
|
||||
return None
|
||||
return content_type.split(";", 1)[0].strip().lower() or None
|
||||
|
||||
|
||||
def resolve_host_ips(hostname: str) -> tuple[ipaddress._BaseAddress, ...]:
|
||||
try:
|
||||
results = socket.getaddrinfo(
|
||||
hostname,
|
||||
None,
|
||||
family=socket.AF_UNSPEC,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise UnsafeUrlError(f"Unable to resolve hostname '{hostname}'") from exc
|
||||
|
||||
ips: list[ipaddress._BaseAddress] = []
|
||||
seen: set[str] = set()
|
||||
for _family, _type, _proto, _canonname, sockaddr in results:
|
||||
address = sockaddr[0]
|
||||
if "%" in address:
|
||||
address = address.split("%", 1)[0]
|
||||
ip = ipaddress.ip_address(address)
|
||||
key = ip.compressed
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
ips.append(ip)
|
||||
|
||||
if not ips:
|
||||
raise UnsafeUrlError(f"Hostname '{hostname}' did not resolve to an IP address")
|
||||
|
||||
return tuple(ips)
|
||||
|
||||
|
||||
def validate_public_http_url(
    url: str,
) -> tuple[ipaddress.IPv4Address | ipaddress.IPv6Address, ...]:
    """Validate that *url* is an http(s) URL resolving only to public IPs.

    Fix: the return annotation previously used the private
    ``ipaddress._BaseAddress`` type; the public union is used instead.

    Rejects non-http(s) schemes, missing hostnames, embedded credentials,
    localhost-style hostnames, and any hostname whose DNS resolution
    includes a non-global (private, loopback, link-local, reserved, ...)
    address.

    Returns:
        The tuple of resolved addresses, so callers can reuse them.

    Raises:
        UnsafeUrlError: if any check fails.
    """
    parsed = urlparse(url)

    if parsed.scheme not in SAFE_HTTP_SCHEMES:
        raise UnsafeUrlError("Only http:// and https:// URLs are supported")
    if not parsed.hostname:
        raise UnsafeUrlError("URL hostname is required")
    if parsed.username or parsed.password:
        raise UnsafeUrlError("URLs with embedded credentials are not allowed")

    # Normalize before matching: drop a trailing FQDN dot and lowercase.
    hostname = parsed.hostname.rstrip(".").lower()
    if hostname == "localhost" or hostname.endswith(".localhost"):
        raise UnsafeUrlError(f"Blocked local hostname '{hostname}'")

    ips = resolve_host_ips(hostname)
    # ip.is_global is False for private, loopback, link-local, shared
    # (100.64/10) and reserved ranges — exactly the SSRF targets.
    blocked = [str(ip) for ip in ips if not ip.is_global]
    if blocked:
        raise UnsafeUrlError(
            f"Blocked non-public address resolution for '{hostname}': {', '.join(blocked)}"
        )

    return ips
|
||||
|
||||
|
||||
def fetch_public_http_resource(
    url: str,
    *,
    max_bytes: int,
    max_redirects: int = 5,
    timeout: tuple[float, float] = DEFAULT_FETCH_TIMEOUT,
) -> HttpFetchResult:
    """Download an http(s) resource while enforcing SSRF protections.

    Every redirect hop is validated with validate_public_http_url() before
    a request is sent to it, proxy-related environment variables are
    ignored, and the body is streamed so the *max_bytes* cap is enforced
    even when the server omits or misreports Content-Length.

    Args:
        url: The remote URL to fetch (http or https only).
        max_bytes: Hard upper bound on the downloaded body size.
        max_redirects: Maximum number of redirect hops to follow.
        timeout: (connect, read) timeout pair passed to requests.

    Returns:
        HttpFetchResult with the final URL, raw bytes, and the
        normalized content type / declared encoding.

    Raises:
        UnsafeUrlError: when any hop targets a non-public destination.
        ValueError: for HTTP errors, malformed or excessive redirects,
            oversized documents, or transport-level failures.
    """
    current_url = url
    session = requests.Session()
    # Do not honor HTTP(S)_PROXY / NO_PROXY env settings on this path.
    session.trust_env = False

    for redirect_count in range(max_redirects + 1):
        # Re-validate every hop so a redirect cannot escape to a private IP.
        # NOTE(review): validation resolves DNS separately from the actual
        # request, so a fast-rebinding DNS name could still race this check
        # (TOCTOU) — confirm whether that residual risk is accepted.
        validate_public_http_url(current_url)

        try:
            with session.get(
                current_url,
                stream=True,  # stream so the size cap applies while reading
                allow_redirects=False,  # redirects handled (and validated) manually
                timeout=timeout,
            ) as response:
                if 300 <= response.status_code < 400:
                    location = response.headers.get("Location")
                    if not location:
                        raise ValueError(
                            f"Remote URL redirect is missing a Location header: {current_url}"
                        )
                    if redirect_count >= max_redirects:
                        raise ValueError(
                            f"Remote URL exceeded redirect limit ({max_redirects}): {url}"
                        )
                    # Location may be relative; resolve against the current URL.
                    current_url = urljoin(current_url, location)
                    continue

                if response.status_code >= 400:
                    raise ValueError(
                        f"Remote URL returned HTTP {response.status_code}: {current_url}"
                    )

                # Fast-fail on an honest Content-Length; the streaming loop
                # below still enforces the cap for dishonest servers.
                content_length = response.headers.get("Content-Length")
                if content_length:
                    try:
                        declared_length = int(content_length)
                    except ValueError:
                        declared_length = None  # unparsable header: ignore it
                    if declared_length is not None and declared_length > max_bytes:
                        raise ValueError(
                            f"Remote document exceeds max size {max_bytes} bytes: {current_url}"
                        )

                body = bytearray()
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if not chunk:
                        continue
                    body.extend(chunk)
                    # Enforce the cap on actual bytes received, not headers.
                    if len(body) > max_bytes:
                        raise ValueError(
                            f"Remote document exceeds max size {max_bytes} bytes: {current_url}"
                        )

                return HttpFetchResult(
                    url=current_url,
                    content=bytes(body),
                    content_type=_normalize_content_type(
                        response.headers.get("Content-Type")
                    ),
                    encoding=response.encoding,
                )
        except requests.RequestException as exc:
            # Wrap transport errors so callers only need to catch ValueError.
            raise ValueError(
                f"Remote document fetch failed for {current_url}: {exc}"
            ) from exc

    # Defensive terminal: the loop body either returns or raises, but keep
    # an explicit error in case the redirect handling above ever changes.
    raise ValueError(f"Remote URL exceeded redirect limit ({max_redirects}): {url}")
|
||||
|
||||
|
||||
def is_loopback_address(address: str) -> bool:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue