diff --git a/api/image_get.py b/api/image_get.py index 8c0071201..26d027e35 100644 --- a/api/image_get.py +++ b/api/image_get.py @@ -1,5 +1,6 @@ import base64 import os +from pathlib import Path from urllib.parse import quote from helpers.api import ApiHandler, Request, Response, send_file from helpers import files, runtime @@ -7,6 +8,24 @@ import io from mimetypes import guess_type +IMAGE_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".gif", + ".bmp", + ".webp", + ".svg", + ".ico", + ".svgz", +) +SVG_EXTENSIONS = (".svg", ".svgz") +SVG_CONTENT_SECURITY_POLICY = ( + "sandbox; default-src 'none'; script-src 'none'; " + "img-src 'self' data:; style-src 'unsafe-inline'" +) + + class ImageGet(ApiHandler): @classmethod @@ -16,48 +35,35 @@ class ImageGet(ApiHandler): async def process(self, input: dict, request: Request) -> dict | Response: # input data path = input.get("path", request.args.get("path", "")) - metadata = ( - input.get("metadata", request.args.get("metadata", "false")).lower() - == "true" - ) if not path: raise ValueError("No path provided") - # no real need to check, we have the extension filter in place - # check if path is within base directory - # if runtime.is_development(): - # in_base = files.is_in_base_dir(files.fix_dev_path(path)) - # else: - # in_base = files.is_in_base_dir(path) - # if not in_base and not files.is_in_dir(path, "/root"): - # raise ValueError("Path is outside of allowed directory") - # get file extension and info file_ext = os.path.splitext(path)[1].lower() filename = os.path.basename(path) - # list of allowed image extensions - image_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".svg", ".ico", ".svgz"] - - # # If metadata is requested, return file information - # if metadata: - # return _get_file_metadata(path, filename, file_ext, image_extensions) - - if file_ext in image_extensions: + if file_ext in IMAGE_EXTENSIONS: + try: + local_path = _resolve_allowed_image_path(path) + except ValueError as exc: + return Response(str(exc), status=403, mimetype="text/plain") # in development environment, try to serve the image from local file system if exists, otherwise from docker if runtime.is_development(): - # Convert /a0/... Docker paths to local absolute paths - local_path = files.fix_dev_path(path) if files.exists(local_path): response = send_file(local_path) else: # Try fetching from Docker via RFC as fallback try: - if await runtime.call_development_function(files.exists, path): + remote_path = await runtime.call_development_function( + _resolve_allowed_image_path, path + ) + if await runtime.call_development_function( + files.exists, remote_path + ): b64_content = await runtime.call_development_function( - files.read_file_base64, path + files.read_file_base64, remote_path ) file_content = base64.b64decode(b64_content) mime_type, _ = guess_type(filename) @@ -74,21 +80,50 @@ class ImageGet(ApiHandler): except Exception: response = _send_fallback_icon("image") else: - if files.exists(path): - response = send_file(path) + if files.exists(local_path): + response = send_file(local_path) else: response = _send_fallback_icon("image") - # Add cache headers for better device sync performance - response.headers["Cache-Control"] = "public, max-age=3600" - response.headers["X-File-Type"] = "image" - response.headers["X-File-Name"] = quote(filename) + _set_image_headers(response, filename, file_ext) return response else: # Handle non-image files with fallback icons return _send_file_type_icon(file_ext, filename) +def _resolve_allowed_image_path(path: str) -> str: + """Resolve a requested image path and keep it inside Agent Zero's base dir.""" + + if runtime.is_development(): + candidate = Path(files.fix_dev_path(path)) + else: + candidate = Path(files.get_abs_path(path)) + + if not candidate.is_absolute(): + candidate = Path(files.get_base_dir()) / candidate + + base_dir = Path(files.get_base_dir()).resolve() + resolved = candidate.resolve(strict=False) + + try: + resolved.relative_to(base_dir) + except ValueError as exc: + raise ValueError("Path is outside of allowed directory") from exc + + return str(resolved) + + +def _set_image_headers(response: Response, filename: str, file_ext: str) -> None: + # Add cache headers for better device sync performance. + response.headers["Cache-Control"] = "public, max-age=3600" + response.headers["X-File-Type"] = "image" + response.headers["X-File-Name"] = quote(filename) + response.headers["X-Content-Type-Options"] = "nosniff" + if file_ext in SVG_EXTENSIONS: + response.headers["Content-Security-Policy"] = SVG_CONTENT_SECURITY_POLICY + + def _send_file_type_icon(file_ext, filename=None): """Return appropriate icon for file type""" diff --git a/helpers/api.py b/helpers/api.py index 8d4f64b08..9616528b7 100644 --- a/helpers/api.py +++ b/helpers/api.py @@ -16,7 +16,6 @@ from flask import ( url_for, ) from werkzeug.wrappers.response import Response as BaseResponse -from agent import AgentContext from helpers.print_style import PrintStyle from helpers.errors import format_error from helpers import files, cache diff --git a/helpers/runtime.py b/helpers/runtime.py index d56ca44fa..1d4b084e1 100644 --- a/helpers/runtime.py +++ b/helpers/runtime.py @@ -3,7 +3,7 @@ import inspect import secrets from pathlib import Path from typing import TypeVar, Callable, Awaitable, Union, overload, cast -from helpers import dotenv, rfc, settings, files +from helpers import dotenv, rfc, files import asyncio import threading import queue @@ -134,6 +134,8 @@ def _get_rfc_password() -> str: def _get_rfc_url() -> str: + # Delay import to avoid a circular import with helpers.settings. + from helpers import settings set = settings.get_settings() url = set["rfc_url"] if not "://" in url: diff --git a/tests/test_image_get_security.py b/tests/test_image_get_security.py new file mode 100644 index 000000000..88ca2d0be --- /dev/null +++ b/tests/test_image_get_security.py @@ -0,0 +1,151 @@ +import asyncio +import base64 +import sys +import threading +from pathlib import Path + +from flask import Flask, request + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from api import image_get + + +def _patch_base_dir(monkeypatch, base_dir: Path, *, development: bool = False) -> None: + base_dir.mkdir(parents=True, exist_ok=True) + + def fake_get_abs_path(*parts: str) -> str: + if len(parts) == 1 and Path(str(parts[0])).is_absolute(): + return str(Path(str(parts[0]))) + return str(base_dir.joinpath(*(str(part) for part in parts))) + + monkeypatch.setattr(image_get.files, "get_base_dir", lambda: str(base_dir)) + monkeypatch.setattr(image_get.files, "get_abs_path", fake_get_abs_path) + monkeypatch.setattr(image_get.runtime, "is_development", lambda: development) + + +async def _request_image(path: str): + app = Flask("test_image_get_security") + handler = image_get.ImageGet(app, threading.Lock()) + with app.test_request_context("/api/image_get"): + return await handler.process({"path": path}, request) + + +def test_image_get_serves_images_inside_base_dir(tmp_path, monkeypatch): + base_dir = tmp_path / "a0" + _patch_base_dir(monkeypatch, base_dir) + image_path = base_dir / "usr" / "uploads" / "safe.png" + image_path.parent.mkdir(parents=True) + image_path.write_bytes(b"\x89PNG\r\n\x1a\n") + + response = asyncio.run(_request_image(str(image_path))) + + assert response.status_code == 200 + assert response.headers["X-File-Type"] == "image" + assert response.headers["X-Content-Type-Options"] == "nosniff" + + +def test_image_get_blocks_image_paths_outside_base_dir(tmp_path, monkeypatch): + base_dir = tmp_path / "a0" + _patch_base_dir(monkeypatch, base_dir) + outside_image = tmp_path / "outside.png" + outside_image.write_bytes(b"outside") + + response = asyncio.run(_request_image(str(outside_image))) + + assert response.status_code == 403 + assert response.get_data(as_text=True) == "Path is outside of allowed directory" + + +def test_image_get_blocks_symlink_escape_from_base_dir(tmp_path, monkeypatch): + base_dir = tmp_path / "a0" + _patch_base_dir(monkeypatch, base_dir) + outside_image = tmp_path / "secret.png" + outside_image.write_bytes(b"secret") + link_path = base_dir / "usr" / "uploads" / "linked.png" + link_path.parent.mkdir(parents=True) + link_path.symlink_to(outside_image) + + response = asyncio.run(_request_image(str(link_path))) + + assert response.status_code == 403 + + +def test_image_get_hardens_svg_responses(tmp_path, monkeypatch): + base_dir = tmp_path / "a0" + _patch_base_dir(monkeypatch, base_dir) + svg_path = base_dir / "usr" / "uploads" / "payload.svg" + svg_path.parent.mkdir(parents=True) + svg_path.write_text( + '', + encoding="utf-8", + ) + + response = asyncio.run(_request_image(str(svg_path))) + + assert response.status_code == 200 + assert response.headers["Content-Security-Policy"].startswith("sandbox;") + assert "script-src 'none'" in response.headers["Content-Security-Policy"] + assert response.headers["X-Content-Type-Options"] == "nosniff" + + +def test_image_get_development_fallback_validates_remote_path(tmp_path, monkeypatch): + base_dir = tmp_path / "a0" + _patch_base_dir(monkeypatch, base_dir, development=True) + calls = [] + + async def fake_call_development_function(func, *args, **kwargs): + calls.append(func.__name__) + if func is image_get._resolve_allowed_image_path: + return "/a0/usr/uploads/remote.png" + if func is image_get.files.exists: + return True + if func is image_get.files.read_file_base64: + return base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("ascii") + raise AssertionError(f"Unexpected remote call: {func.__name__}") + + monkeypatch.setattr( + image_get.runtime, + "call_development_function", + fake_call_development_function, + ) + + response = asyncio.run(_request_image("/a0/usr/uploads/remote.png")) + + assert response.status_code == 200 + assert response.headers["X-File-Type"] == "image" + assert calls == ["_resolve_allowed_image_path", "exists", "read_file_base64"] + + +def test_image_get_development_fallback_does_not_read_rejected_remote_path( + tmp_path, + monkeypatch, +): + base_dir = tmp_path / "a0" + _patch_base_dir(monkeypatch, base_dir, development=True) + calls = [] + + async def fake_call_development_function(func, *args, **kwargs): + calls.append(func.__name__) + if func is image_get._resolve_allowed_image_path: + raise ValueError("Path is outside of allowed directory") + raise AssertionError(f"Unexpected remote call after validation: {func.__name__}") + + monkeypatch.setattr( + image_get.runtime, + "call_development_function", + fake_call_development_function, + ) + monkeypatch.setattr( + image_get, + "_send_fallback_icon", + lambda _icon_name: image_get.Response("fallback", status=200), + ) + + response = asyncio.run(_request_image("/a0/usr/uploads/rejected.png")) + + assert response.status_code == 200 + assert response.get_data(as_text=True) == "fallback" + assert calls == ["_resolve_allowed_image_path"]