Major refactor: API, providers, messaging, and Anthropic protocol
Some checks are pending
CI / checks (push) Waiting to run

Consolidates the incremental refactor work into a single change set: modular web tools (api/web_tools), native Anthropic request building and SSE block policy, OpenAI conversion and error handling, provider transports and rate limiting, messaging handler and tree queue, safe logging, smoke tests, and broad test coverage.
This commit is contained in:
Alishahryar1 2026-04-26 02:55:10 -07:00
parent b9ed704095
commit f3a7528d49
139 changed files with 7460 additions and 2422 deletions

View file

@ -1,6 +1,6 @@
"""API layer for Claude Code Proxy."""
from .app import app, create_app
from .app import create_app
from .models import (
MessagesRequest,
MessagesResponse,
@ -13,6 +13,5 @@ __all__ = [
"MessagesResponse",
"TokenCountRequest",
"TokenCountResponse",
"app",
"create_app",
]

View file

@ -1,6 +1,6 @@
"""FastAPI application factory and configuration."""
import os
import traceback
from contextlib import asynccontextmanager
from typing import Any
@ -16,13 +16,7 @@ from providers.exceptions import ProviderError
from .routes import router
from .runtime import AppRuntime
# Opt-in to future behavior for python-telegram-bot
os.environ["PTB_TIMEDELTA"] = "1"
# Configure logging first (before any module logs)
_settings = get_settings()
configure_logging(_settings.log_file)
from .validation_log import summarize_request_validation_body
@asynccontextmanager
@ -38,6 +32,11 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
settings = get_settings()
configure_logging(
settings.log_file, verbose_third_party=settings.log_raw_api_payloads
)
app = FastAPI(
title="Claude Code Proxy",
version="2.0.0",
@ -57,33 +56,7 @@ def create_app() -> FastAPI:
except Exception as e:
body = {"_json_error": type(e).__name__}
messages = body.get("messages") if isinstance(body, dict) else None
message_summary: list[dict[str, Any]] = []
if isinstance(messages, list):
for msg in messages:
if not isinstance(msg, dict):
message_summary.append({"message_kind": type(msg).__name__})
continue
content = msg.get("content")
item: dict[str, Any] = {
"role": msg.get("role"),
"content_kind": type(content).__name__,
}
if isinstance(content, list):
item["block_types"] = [
block.get("type", "dict")
if isinstance(block, dict)
else type(block).__name__
for block in content[:12]
]
item["block_keys"] = [
sorted(str(key) for key in block)[:12]
for block in content[:5]
if isinstance(block, dict)
]
elif isinstance(content, str):
item["content_length"] = len(content)
message_summary.append(item)
message_summary, tool_names = summarize_request_validation_body(body)
logger.debug(
"Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
@ -92,20 +65,27 @@ def create_app() -> FastAPI:
[list(error.get("loc", ())) for error in exc.errors()],
[str(error.get("type", "")) for error in exc.errors()],
message_summary,
[
str(tool.get("name", ""))
for tool in body.get("tools", [])
if isinstance(body, dict)
and isinstance(body.get("tools"), list)
and isinstance(tool, dict)
],
tool_names,
)
return await request_validation_exception_handler(request, exc)
@app.exception_handler(ProviderError)
async def provider_error_handler(request: Request, exc: ProviderError):
"""Handle provider-specific errors and return Anthropic format."""
logger.error(f"Provider Error: {exc.error_type} - {exc.message}")
err_settings = get_settings()
if err_settings.log_api_error_tracebacks:
logger.error(
"Provider Error: error_type={} status_code={} message={}",
exc.error_type,
exc.status_code,
exc.message,
)
else:
logger.error(
"Provider Error: error_type={} status_code={}",
exc.error_type,
exc.status_code,
)
return JSONResponse(
status_code=exc.status_code,
content=exc.to_anthropic_format(),
@ -114,10 +94,17 @@ def create_app() -> FastAPI:
@app.exception_handler(Exception)
async def general_error_handler(request: Request, exc: Exception):
"""Handle general errors and return Anthropic format."""
logger.error(f"General Error: {exc!s}")
import traceback
logger.error(traceback.format_exc())
settings = get_settings()
if settings.log_api_error_tracebacks:
logger.error("General Error: {}", exc)
logger.error(traceback.format_exc())
else:
logger.error(
"General Error: path={} method={} exc_type={}",
request.url.path,
request.method,
type(exc).__name__,
)
return JSONResponse(
status_code=500,
content={
@ -130,7 +117,3 @@ def create_app() -> FastAPI:
)
return app
# Default app instance for uvicorn
app = create_app()

View file

@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str:
return "<filepaths>\n</filepaths>"
except Exception:
except ValueError:
return "<filepaths>\n</filepaths>"

View file

@ -8,7 +8,11 @@ from config.settings import Settings
from config.settings import get_settings as _get_settings
from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
from providers.exceptions import (
AuthenticationError,
ServiceUnavailableError,
UnknownProviderTypeError,
)
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {}
def get_settings() -> Settings:
"""Get application settings via dependency injection."""
"""Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
return _get_settings()
@ -31,10 +35,9 @@ def resolve_provider(
"""Resolve a provider using the app-scoped registry when ``app`` is set.
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
is always used. If the registry is missing (e.g. a test app without
:class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
is installed on ``app.state`` so the process cache is never mixed with
per-request app identity.
must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
Callers that construct a bare ``FastAPI`` without lifespan must set
``app.state.provider_registry`` explicitly.
When ``app`` is ``None`` (no HTTP context), uses the process-level
:data:`_providers` cache only.
@ -42,8 +45,10 @@ def resolve_provider(
if app is not None:
reg = getattr(app.state, "provider_registry", None)
if reg is None:
reg = ProviderRegistry()
app.state.provider_registry = reg
raise ServiceUnavailableError(
"Provider registry is not configured. Ensure AppRuntime startup ran "
"or assign app.state.provider_registry for test apps."
)
return _resolve_with_registry(reg, provider_type, settings)
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
@ -55,9 +60,10 @@ def _resolve_with_registry(
try:
provider = registry.get(provider_type, settings)
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
# Provider :class:`~providers.exceptions.AuthenticationError` messages are
# curated configuration hints (env var names, docs links), not upstream noise.
detail = str(e).strip() or get_user_facing_error_message(e)
raise HTTPException(status_code=503, detail=detail) from e
except UnknownProviderTypeError:
logger.error(
"Unknown provider_type: '{}'. Supported: {}",
@ -73,8 +79,9 @@ def _resolve_with_registry(
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider in the process-level cache (no ``app``/Request).
For server requests, use :func:`resolve_provider` with the active
:attr:`request.app` so the app-scoped provider registry is used.
HTTP route handlers should call :func:`resolve_provider` with the active
:attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
process-wide cache.
"""
return resolve_provider(provider_type, app=None, settings=get_settings())

View file

@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, st
try:
cmd_start = content.rfind("Command:") + len("Command:")
return True, content[cmd_start:].strip()
except Exception:
pass
except TypeError:
return False, ""
return False, ""
@ -121,19 +121,16 @@ def is_filepath_extraction_request(
if not user_has_filepaths and not system_has_extract:
return False, "", ""
try:
cmd_start = content.find("Command:") + len("Command:")
output_marker = content.find("Output:", cmd_start)
if output_marker == -1:
return False, "", ""
command = content[cmd_start:output_marker].strip()
output = content[output_marker + len("Output:") :].strip()
for marker in ["<", "\n\n"]:
if marker in output:
output = output.split(marker)[0].strip()
return True, command, output
except Exception:
cmd_start = content.find("Command:") + len("Command:")
output_marker = content.find("Output:", cmd_start)
if output_marker == -1:
return False, "", ""
command = content[cmd_start:output_marker].strip()
output = content[output_marker + len("Output:") :].strip()
for marker in ["<", "\n\n"]:
if marker in output:
output = output.split(marker)[0].strip()
return True, command, output

View file

@ -3,7 +3,7 @@
from enum import StrEnum
from typing import Any, Literal
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field
# =============================================================================
@ -15,41 +15,68 @@ class Role(StrEnum):
system = "system"
class ContentBlockText(BaseModel):
class _AnthropicBlockBase(BaseModel):
"""Pass through provider fields (e.g. ``cache_control``) for native transports."""
model_config = ConfigDict(extra="allow")
class ContentBlockText(_AnthropicBlockBase):
type: Literal["text"]
text: str
class ContentBlockImage(BaseModel):
class ContentBlockImage(_AnthropicBlockBase):
type: Literal["image"]
source: dict[str, Any]
class ContentBlockToolUse(BaseModel):
class ContentBlockToolUse(_AnthropicBlockBase):
type: Literal["tool_use"]
id: str
name: str
input: dict[str, Any]
class ContentBlockToolResult(BaseModel):
class ContentBlockToolResult(_AnthropicBlockBase):
type: Literal["tool_result"]
tool_use_id: str
content: str | list[Any] | dict[str, Any]
class ContentBlockThinking(BaseModel):
class ContentBlockThinking(_AnthropicBlockBase):
type: Literal["thinking"]
thinking: str
signature: str | None = None
class ContentBlockRedactedThinking(BaseModel):
class ContentBlockRedactedThinking(_AnthropicBlockBase):
type: Literal["redacted_thinking"]
data: str
class SystemContent(BaseModel):
class ContentBlockServerToolUse(_AnthropicBlockBase):
"""Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""
type: Literal["server_tool_use"]
id: str
name: str
input: dict[str, Any]
class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
type: Literal["web_search_tool_result"]
tool_use_id: str
content: Any
class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
type: Literal["web_fetch_tool_result"]
tool_use_id: str
content: Any
class SystemContent(_AnthropicBlockBase):
type: Literal["text"]
text: str
@ -68,12 +95,15 @@ class Message(BaseModel):
| ContentBlockToolResult
| ContentBlockThinking
| ContentBlockRedactedThinking
| ContentBlockServerToolUse
| ContentBlockWebSearchToolResult
| ContentBlockWebFetchToolResult
]
)
reasoning_content: str | None = None
class Tool(BaseModel):
class Tool(_AnthropicBlockBase):
name: str
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
# may omit ``input_schema`` because the provider owns the schema.
@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel):
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
model_config = ConfigDict(extra="allow")
model: str
# Internal routing / debug: accepted on parse but not serialized to providers.
original_model: str | None = Field(default=None, exclude=True)
resolved_provider_model: str | None = Field(default=None, exclude=True)
max_tokens: int | None = None
messages: list[Message]
system: str | list[SystemContent] | None = None
@ -105,13 +140,24 @@ class MessagesRequest(BaseModel):
tools: list[Tool] | None = None
tool_choice: dict[str, Any] | None = None
thinking: ThinkingConfig | None = None
# Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
context_management: dict[str, Any] | None = None
output_config: dict[str, Any] | None = None
mcp_servers: list[dict[str, Any]] | None = None
extra_body: dict[str, Any] | None = None
class TokenCountRequest(BaseModel):
model_config = ConfigDict(extra="allow")
model: str
original_model: str | None = Field(default=None, exclude=True)
resolved_provider_model: str | None = Field(default=None, exclude=True)
messages: list[Message]
system: str | list[SystemContent] | None = None
tools: list[Tool] | None = None
thinking: ThinkingConfig | None = None
tool_choice: dict[str, Any] | None = None
context_management: dict[str, Any] | None = None
output_config: dict[str, Any] | None = None
mcp_servers: list[dict[str, Any]] | None = None

View file

@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0
async def best_effort(
name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S
name: str,
awaitable: Any,
timeout_s: float = _SHUTDOWN_TIMEOUT_S,
*,
log_verbose_errors: bool = False,
) -> None:
"""Run a shutdown step with timeout; never raise to callers."""
try:
await asyncio.wait_for(awaitable, timeout=timeout_s)
except TimeoutError:
logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
except Exception as e:
logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
if log_verbose_errors:
logger.warning(
"Shutdown step failed: {}: {}: {}",
name,
type(e).__name__,
e,
)
else:
logger.warning(
"Shutdown step failed: {}: exc_type={}",
name,
type(e).__name__,
)
def warn_if_process_auth_token(settings: Settings) -> None:
@ -73,20 +89,37 @@ class AppRuntime:
self._publish_state()
async def shutdown(self) -> None:
verbose = self.settings.log_api_error_tracebacks
if self.message_handler is not None:
try:
self.message_handler.session_store.flush_pending_save()
except Exception as e:
logger.warning(f"Session store flush on shutdown: {e}")
if verbose:
logger.warning("Session store flush on shutdown: {}", e)
else:
logger.warning(
"Session store flush on shutdown: exc_type={}",
type(e).__name__,
)
logger.info("Shutdown requested, cleaning up...")
if self.messaging_platform:
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
await best_effort(
"messaging_platform.stop",
self.messaging_platform.stop(),
log_verbose_errors=verbose,
)
if self.cli_manager:
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
await best_effort(
"cli_manager.stop_all",
self.cli_manager.stop_all(),
log_verbose_errors=verbose,
)
if self._provider_registry is not None:
await best_effort(
"provider_registry.cleanup", self._provider_registry.cleanup()
"provider_registry.cleanup",
self._provider_registry.cleanup(),
log_verbose_errors=verbose,
)
await self._shutdown_limiter()
logger.info("Server shut down cleanly")
@ -110,6 +143,10 @@ class AppRuntime:
whisper_device=self.settings.whisper_device,
hf_token=self.settings.hf_token,
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
messaging_rate_limit=self.settings.messaging_rate_limit,
messaging_rate_window=self.settings.messaging_rate_window,
log_raw_messaging_content=self.settings.log_raw_messaging_content,
log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
),
)
@ -117,12 +154,24 @@ class AppRuntime:
await self._start_message_handler()
except ImportError as e:
logger.warning(f"Messaging module import error: {e}")
if self.settings.log_api_error_tracebacks:
logger.warning("Messaging module import error: {}", e)
else:
logger.warning(
"Messaging module import error: exc_type={}",
type(e).__name__,
)
except Exception as e:
logger.error(f"Failed to start messaging platform: {e}")
import traceback
if self.settings.log_api_error_tracebacks:
logger.error("Failed to start messaging platform: {}", e)
import traceback
logger.error(traceback.format_exc())
logger.error(traceback.format_exc())
else:
logger.error(
"Failed to start messaging platform: exc_type={}",
type(e).__name__,
)
async def _start_message_handler(self) -> None:
from cli.manager import CLISessionManager
@ -151,10 +200,13 @@ class AppRuntime:
allowed_dirs=allowed_dirs,
plans_directory=plans_directory,
claude_bin=self.settings.claude_cli_bin,
log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
log_messaging_error_details=self.settings.log_messaging_error_details,
)
session_store = SessionStore(
storage_path=os.path.join(data_path, "sessions.json")
storage_path=os.path.join(data_path, "sessions.json"),
message_log_cap=self.settings.max_message_log_entries_per_chat,
)
platform = self.messaging_platform
assert platform is not None
@ -162,6 +214,11 @@ class AppRuntime:
platform=platform,
cli_manager=self.cli_manager,
session_store=session_store,
debug_platform_edits=self.settings.debug_platform_edits,
debug_subagent_stack=self.settings.debug_subagent_stack,
log_raw_messaging_content=self.settings.log_raw_messaging_content,
log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
log_messaging_error_details=self.settings.log_messaging_error_details,
)
self._restore_tree_state(session_store)
@ -201,18 +258,26 @@ class AppRuntime:
self.app.state.cli_manager = self.cli_manager
async def _shutdown_limiter(self) -> None:
verbose = self.settings.log_api_error_tracebacks
try:
from messaging.limiter import MessagingRateLimiter
except Exception as e:
logger.debug(
"Rate limiter shutdown skipped (import failed): {}: {}",
type(e).__name__,
e,
)
if verbose:
logger.debug(
"Rate limiter shutdown skipped (import failed): {}: {}",
type(e).__name__,
e,
)
else:
logger.debug(
"Rate limiter shutdown skipped (import failed): exc_type={}",
type(e).__name__,
)
return
await best_effort(
"MessagingRateLimiter.shutdown_instance",
MessagingRateLimiter.shutdown_instance(),
timeout_s=2.0,
log_verbose_errors=verbose,
)

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import traceback
import uuid
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import Any
from fastapi import HTTPException
@ -13,6 +13,7 @@ from loguru import logger
from config.settings import Settings
from core.anthropic import get_token_count, get_user_facing_error_message
from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS
from providers.base import BaseProvider
from providers.exceptions import InvalidRequestError, ProviderError
@ -20,15 +21,67 @@ from .model_router import ModelRouter
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
from .web_server_tools import (
from .web_tools.egress import WebFetchEgressPolicy
from .web_tools.request import (
is_web_server_tool_request,
stream_web_server_tool_response,
openai_chat_upstream_server_tool_error,
)
from .web_tools.streaming import stream_web_server_tool_response
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
ProviderGetter = Callable[[str], BaseProvider]
# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"})
def anthropic_sse_streaming_response(
    body: AsyncIterator[str],
) -> StreamingResponse:
    """Return a :class:`StreamingResponse` for Anthropic-style SSE streams.

    Applies the event-stream media type plus the shared Anthropic SSE
    header set (no proxy buffering, no caching, keep-alive).
    """
    sse_kwargs = {
        "media_type": "text/event-stream",
        "headers": ANTHROPIC_SSE_RESPONSE_HEADERS,
    }
    return StreamingResponse(body, **sse_kwargs)
def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
    """HTTP status for uncaught non-provider failures (stable client contract)."""
    # Deliberately ignores the exception: unexpected service-layer failures
    # always map to a plain 500 so clients see a stable status code rather
    # than whatever status_code attribute the exception happens to carry.
    return 500
def _log_unexpected_service_exception(
    settings: Settings,
    exc: BaseException,
    *,
    context: str,
    request_id: str | None = None,
) -> None:
    """Log service-layer failures without echoing exception text unless opted in.

    When ``settings.log_api_error_tracebacks`` is set, the exception message
    and a full traceback are logged; otherwise only the context, optional
    request id, and exception type name are emitted.
    """
    verbose = settings.log_api_error_tracebacks
    if verbose:
        if request_id is None:
            logger.error("{}: {}", context, exc)
        else:
            logger.error("{} request_id={}: {}", context, request_id, exc)
        logger.error(traceback.format_exc())
    elif request_id is None:
        logger.error("{} exc_type={}", context, type(exc).__name__)
    else:
        logger.error(
            "{} request_id={} exc_type={}",
            context,
            request_id,
            type(exc).__name__,
        )
def _require_non_empty_messages(messages: list[Any]) -> None:
if not messages:
raise InvalidRequestError("messages cannot be empty")
class ClaudeProxyService:
"""Coordinate request optimization, model routing, token count, and providers."""
@ -48,25 +101,35 @@ class ClaudeProxyService:
def create_message(self, request_data: MessagesRequest) -> object:
"""Create a message response or streaming response."""
try:
if not request_data.messages:
raise InvalidRequestError("messages cannot be empty")
_require_non_empty_messages(request_data.messages)
routed = self._model_router.resolve_messages_request(request_data)
if is_web_server_tool_request(routed.request):
if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
tool_err = openai_chat_upstream_server_tool_error(
routed.request,
web_tools_enabled=self._settings.enable_web_server_tools,
)
if tool_err is not None:
raise InvalidRequestError(tool_err)
if self._settings.enable_web_server_tools and is_web_server_tool_request(
routed.request
):
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
logger.info("Optimization: Handling Anthropic web server tool")
return StreamingResponse(
egress = WebFetchEgressPolicy(
allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
)
return anthropic_sse_streaming_response(
stream_web_server_tool_response(
routed.request, input_tokens=input_tokens
routed.request,
input_tokens=input_tokens,
web_fetch_egress=egress,
verbose_client_errors=self._settings.log_api_error_tracebacks,
),
media_type="text/event-stream",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
optimized = try_optimizations(routed.request, self._settings)
@ -75,6 +138,10 @@ class ClaudeProxyService:
logger.debug("No optimization matched, routing to provider")
provider = self._provider_getter(routed.resolved.provider_id)
provider.preflight_stream(
routed.request,
thinking_enabled=routed.resolved.thinking_enabled,
)
request_id = f"req_{uuid.uuid4().hex[:12]}"
logger.info(
@ -83,34 +150,31 @@ class ClaudeProxyService:
routed.request.model,
len(routed.request.messages),
)
logger.debug(
"FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
)
if self._settings.log_raw_api_payloads:
logger.debug(
"FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
)
input_tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
)
return StreamingResponse(
return anthropic_sse_streaming_response(
provider.stream_response(
routed.request,
input_tokens=input_tokens,
request_id=request_id,
thinking_enabled=routed.resolved.thinking_enabled,
),
media_type="text/event-stream",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
except ProviderError:
raise
except Exception as e:
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
_log_unexpected_service_exception(
self._settings, e, context="CREATE_MESSAGE_ERROR"
)
raise HTTPException(
status_code=getattr(e, "status_code", 500),
status_code=_http_status_for_unexpected_service_exception(e),
detail=get_user_facing_error_message(e),
) from e
@ -119,6 +183,7 @@ class ClaudeProxyService:
request_id = f"req_{uuid.uuid4().hex[:12]}"
with logger.contextualize(request_id=request_id):
try:
_require_non_empty_messages(request_data.messages)
routed = self._model_router.resolve_token_count_request(request_data)
tokens = self._token_counter(
routed.request.messages, routed.request.system, routed.request.tools
@ -131,13 +196,16 @@ class ClaudeProxyService:
tokens,
)
return TokenCountResponse(input_tokens=tokens)
except ProviderError:
raise
except Exception as e:
logger.error(
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
request_id,
get_user_facing_error_message(e),
traceback.format_exc(),
_log_unexpected_service_exception(
self._settings,
e,
context="COUNT_TOKENS_ERROR",
request_id=request_id,
)
raise HTTPException(
status_code=500, detail=get_user_facing_error_message(e)
status_code=_http_status_for_unexpected_service_exception(e),
detail=get_user_facing_error_message(e),
) from e

48
api/validation_log.py Normal file
View file

@ -0,0 +1,48 @@
"""Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""
from __future__ import annotations
from typing import Any
def summarize_request_validation_body(
    body: Any,
) -> tuple[list[dict[str, Any]], list[str]]:
    """Return message shape summary and tool name list for debug logs.

    Only structural metadata is emitted (roles, block type/key names,
    content lengths) -- never raw message text -- so the result is safe
    to include in validation-failure log lines.
    """
    message_summary: list[dict[str, Any]] = []
    raw_messages = body.get("messages") if isinstance(body, dict) else None
    for msg in raw_messages if isinstance(raw_messages, list) else []:
        if not isinstance(msg, dict):
            # Non-dict entries: record only the Python type name.
            message_summary.append({"message_kind": type(msg).__name__})
            continue
        content = msg.get("content")
        summary: dict[str, Any] = {
            "role": msg.get("role"),
            "content_kind": type(content).__name__,
        }
        if isinstance(content, str):
            summary["content_length"] = len(content)
        elif isinstance(content, list):
            # Cap block inspection (12 types, 5 key lists) to bound log size.
            summary["block_types"] = [
                blk.get("type", "dict") if isinstance(blk, dict) else type(blk).__name__
                for blk in content[:12]
            ]
            summary["block_keys"] = [
                sorted(str(key) for key in blk)[:12]
                for blk in content[:5]
                if isinstance(blk, dict)
            ]
        message_summary.append(summary)

    raw_tools = body.get("tools") if isinstance(body, dict) else None
    tool_names = (
        [str(tool.get("name", "")) for tool in raw_tools if isinstance(tool, dict)]
        if isinstance(raw_tools, list)
        else []
    )
    return message_summary, tool_names

View file

@ -1,331 +1,22 @@
"""Local handlers for Anthropic web server tools.
OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
web tools are server-side: the API response itself must include the tool result.
"""
"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""
from __future__ import annotations
import html
import json
import re
import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from html.parser import HTMLParser
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse
import httpx
from .models.anthropic import MessagesRequest
from api.web_tools.egress import (
WebFetchEgressPolicy,
WebFetchEgressViolation,
enforce_web_fetch_egress,
)
from api.web_tools.request import is_web_server_tool_request
from api.web_tools.streaming import stream_web_server_tool_response
_REQUEST_TIMEOUT_S = 20.0
_MAX_SEARCH_RESULTS = 10
_MAX_FETCH_CHARS = 24_000
class _SearchResultParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.results: list[dict[str, str]] = []
self._href: str | None = None
self._title_parts: list[str] = []
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag != "a":
return
href = dict(attrs).get("href")
if not href or "uddg=" not in href:
return
parsed = urlparse(href)
query = parse_qs(parsed.query)
uddg = query.get("uddg", [""])[0]
if not uddg:
return
self._href = unquote(uddg)
self._title_parts = []
def handle_data(self, data: str) -> None:
if self._href is not None:
self._title_parts.append(data)
def handle_endtag(self, tag: str) -> None:
if tag != "a" or self._href is None:
return
title = " ".join("".join(self._title_parts).split())
if title and not any(result["url"] == self._href for result in self.results):
self.results.append({"title": html.unescape(title), "url": self._href})
self._href = None
self._title_parts = []
class _HTMLTextParser(HTMLParser):
def __init__(self) -> None:
super().__init__()
self.title = ""
self.text_parts: list[str] = []
self._in_title = False
self._skip_depth = 0
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag in {"script", "style", "noscript"}:
self._skip_depth += 1
elif tag == "title":
self._in_title = True
def handle_endtag(self, tag: str) -> None:
if tag in {"script", "style", "noscript"} and self._skip_depth:
self._skip_depth -= 1
elif tag == "title":
self._in_title = False
def handle_data(self, data: str) -> None:
text = " ".join(data.split())
if not text:
return
if self._in_title:
self.title = f"{self.title} {text}".strip()
elif not self._skip_depth:
self.text_parts.append(text)
def _format_event(event_type: str, data: dict[str, Any]) -> str:
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
def _content_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
parts.append(str(item.get("text", "")))
else:
parts.append(str(getattr(item, "text", "")))
return "\n".join(part for part in parts if part)
return str(content)
def _request_text(request: MessagesRequest) -> str:
    """Concatenate the flattened text of every message in the request."""
    flattened = [_content_text(message.content) for message in request.messages]
    return "\n".join(flattened)
def _web_tool_name(request: MessagesRequest) -> str | None:
for tool in request.tools or []:
name = tool.name
tool_type = tool.type or ""
if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
return name
return None
def is_web_server_tool_request(request: MessagesRequest) -> bool:
    """True when the request declares the ``web_search`` or ``web_fetch`` server tool."""
    return _web_tool_name(request) in ("web_search", "web_fetch")
def _extract_query(text: str) -> str:
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip().strip("\"'")
return text.strip()
def _extract_url(text: str) -> str:
match = re.search(r"https?://\S+", text)
return match.group(0).rstrip(").,]") if match else text.strip()
async def _run_web_search(query: str) -> list[dict[str, str]]:
    """Search DuckDuckGo Lite for *query*.

    Returns up to ``_MAX_SEARCH_RESULTS`` hits parsed by
    :class:`_SearchResultParser` as ``{"title": ..., "url": ...}`` dicts.
    """
    async with httpx.AsyncClient(
        timeout=_REQUEST_TIMEOUT_S,
        follow_redirects=True,
        headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
    ) as http:
        reply = await http.get(
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
        )
        reply.raise_for_status()
        result_parser = _SearchResultParser()
        result_parser.feed(reply.text)
        return result_parser.results[:_MAX_SEARCH_RESULTS]
async def _run_web_fetch(url: str) -> dict[str, str]:
    """Fetch *url* and return final URL, title, media type, and text payload.

    HTML responses are reduced to their ``<title>`` and visible text via
    :class:`_HTMLTextParser`; other content types pass through verbatim.
    The payload is truncated to ``_MAX_FETCH_CHARS`` characters.
    """
    async with httpx.AsyncClient(
        timeout=_REQUEST_TIMEOUT_S,
        follow_redirects=True,
        headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
    ) as http:
        reply = await http.get(url)
        reply.raise_for_status()
        title, payload = url, reply.text
        if "html" in reply.headers.get("content-type", "text/plain").lower():
            extractor = _HTMLTextParser()
            extractor.feed(reply.text)
            title = extractor.title or url
            payload = "\n".join(extractor.text_parts)
        return {
            "url": str(reply.url),
            "title": title,
            # Extracted text is always shipped as plain text, whatever the source type.
            "media_type": "text/plain",
            "data": payload[:_MAX_FETCH_CHARS],
        }
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
if not results:
return f"No web search results found for: {query}"
lines = [f"Search results for: {query}"]
for index, result in enumerate(results, start=1):
lines.append(f"{index}. {result['title']}\n{result['url']}")
return "\n\n".join(lines)
async def stream_web_server_tool_response(
    request: MessagesRequest, input_tokens: int
) -> AsyncIterator[str]:
    """Stream an Anthropic-shaped SSE turn that executes web_search / web_fetch locally.

    Event order: message_start -> block 0 (server_tool_use) -> block 1 (tool result
    or error payload) -> block 2 (text summary) -> message_delta -> message_stop.
    Yields formatted SSE event strings; yields nothing if no web tool is requested.
    """
    tool_name = _web_tool_name(request)
    if tool_name is None:
        return
    text = _request_text(request)
    message_id = f"msg_{uuid.uuid4()}"
    tool_id = f"srvtoolu_{uuid.uuid4().hex}"
    # Placeholder; recomputed from the summary length after the tool runs.
    output_tokens = 1
    usage_key = (
        "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
    )
    # Tool input is parsed heuristically out of the message text.
    tool_input = (
        {"query": _extract_query(text)}
        if tool_name == "web_search"
        else {"url": _extract_url(text)}
    )
    yield _format_event(
        "message_start",
        {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
            },
        },
    )
    # Block 0: the synthesized server_tool_use call.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 0,
            "content_block": {
                "type": "server_tool_use",
                "id": tool_id,
                "name": tool_name,
                "input": tool_input,
            },
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 0}
    )
    try:
        if tool_name == "web_search":
            query = str(tool_input["query"])
            results = await _run_web_search(query)
            result_content: Any = [
                {
                    "type": "web_search_result",
                    "title": result["title"],
                    "url": result["url"],
                }
                for result in results
            ]
            summary = _search_summary(query, results)
            result_block_type = "web_search_tool_result"
        else:
            fetched = await _run_web_fetch(str(tool_input["url"]))
            result_content = {
                "type": "web_fetch_result",
                "url": fetched["url"],
                "content": {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": fetched["media_type"],
                        "data": fetched["data"],
                    },
                    "title": fetched["title"],
                    "citations": {"enabled": True},
                },
                "retrieved_at": datetime.now(UTC).isoformat(),
            }
            summary = fetched["data"][:_MAX_FETCH_CHARS]
            result_block_type = "web_fetch_tool_result"
    except Exception as error:
        # Any failure is reported in-band as a tool-result error payload, never raised.
        result_block_type = (
            "web_search_tool_result"
            if tool_name == "web_search"
            else "web_fetch_tool_result"
        )
        error_type = (
            "web_search_tool_result_error"
            if tool_name == "web_search"
            else "web_fetch_tool_error"
        )
        result_content = {"type": error_type, "error_code": "unavailable"}
        summary = f"{tool_name} failed: {type(error).__name__}"
    # Rough token estimate: ~4 chars per token, never below 1.
    output_tokens = max(1, len(summary) // 4)
    # Block 1: the tool result (or error payload) tied back to the tool_use id.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 1,
            "content_block": {
                "type": result_block_type,
                "tool_use_id": tool_id,
                "content": result_content,
            },
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 1}
    )
    # Block 2: the model-facing text summary.
    # NOTE(review): the text is emitted eagerly on content_block_start with no
    # text_delta events — confirm clients render eager text on block start.
    yield _format_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 2,
            "content_block": {"type": "text", "text": summary},
        },
    )
    yield _format_event(
        "content_block_stop", {"type": "content_block_stop", "index": 2}
    )
    yield _format_event(
        "message_delta",
        {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "server_tool_use": {usage_key: 1},
            },
        },
    )
    yield _format_event("message_stop", {"type": "message_stop"})
# Public API of this module.
# NOTE(review): "httpx" exports the imported module object itself — looks
# deliberate (e.g. for test monkeypatching); confirm before removing.
__all__ = [
    "WebFetchEgressPolicy",
    "WebFetchEgressViolation",
    "enforce_web_fetch_egress",
    "httpx",
    "is_web_server_tool_request",
    "stream_web_server_tool_response",
]

17
api/web_tools/__init__.py Normal file
View file

@ -0,0 +1,17 @@
"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""
from .egress import (
WebFetchEgressPolicy,
WebFetchEgressViolation,
enforce_web_fetch_egress,
)
from .request import is_web_server_tool_request
from .streaming import stream_web_server_tool_response
# Package surface re-exported from the submodules: egress policy types plus
# the request-detection and streaming entry points.
__all__ = [
    "WebFetchEgressPolicy",
    "WebFetchEgressViolation",
    "enforce_web_fetch_egress",
    "is_web_server_tool_request",
    "stream_web_server_tool_response",
]

View file

@ -0,0 +1,15 @@
"""Limits and defaults for outbound web server tool HTTP."""
# Total timeout (seconds) for any single outbound web tool HTTP request.
_REQUEST_TIMEOUT_S = 20.0
# Maximum number of search hits surfaced to the model per query.
_MAX_SEARCH_RESULTS = 10
# Character cap on fetched-page text returned to the model.
_MAX_FETCH_CHARS = 24_000
# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
# Drain at most this many bytes from redirect responses before following Location.
_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
# Maximum redirect hops followed before the fetch is rejected.
_MAX_WEB_FETCH_REDIRECTS = 10
# HTTP statuses treated as redirects when following Location manually.
_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
# Default headers sent on every outbound web tool request.
_WEB_TOOL_HTTP_HEADERS = {
    "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
}

99
api/web_tools/egress.py Normal file
View file

@ -0,0 +1,99 @@
"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
@dataclass(frozen=True, slots=True)
class WebFetchEgressPolicy:
    """Egress rules for user-influenced web_fetch URLs."""

    # When True, all private/loopback target checks are skipped (trusted deployments).
    allow_private_network_targets: bool
    # URL schemes permitted for outbound fetches (compared lowercase).
    allowed_schemes: frozenset[str]
class WebFetchEgressViolation(ValueError):
    """Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""

    # Subclasses ValueError so callers can treat it as an input-validation failure.
def _port_for_url(parsed) -> int:
if parsed.port is not None:
return parsed.port
return 443 if (parsed.scheme or "").lower() == "https" else 80
def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
    """Resolve (host, port) for TCP streams, mapping OS errors to egress violations."""
    try:
        infos = socket.getaddrinfo(
            host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
        )
    except OSError as exc:
        # Resolution failure is reported as a policy violation so callers have
        # a single exception type to handle.
        raise WebFetchEgressViolation(
            f"Could not resolve host {host!r}: {exc}"
        ) from exc
    return infos
def get_validated_stream_addrinfos_for_egress(
    url: str, policy: WebFetchEgressPolicy
) -> list[tuple]:
    """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.

    Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
    server cannot rebind to a disallowed address between resolution and the TCP
    connect (used by :func:`api.web_tools.outbound._run_web_fetch`).

    Raises:
        WebFetchEgressViolation: on a disallowed scheme, missing host, blocked
            hostname, non-public IP, or resolution failure.
    """
    parsed = urlparse(url)
    scheme = (parsed.scheme or "").lower()
    # Scheme allow-list first: rejects file://, ftp://, etc. outright.
    if scheme not in policy.allowed_schemes:
        raise WebFetchEgressViolation(
            f"URL scheme {scheme!r} is not allowed for web_fetch"
        )
    host = parsed.hostname
    if host is None or host == "":
        raise WebFetchEgressViolation("web_fetch URL must include a host")
    port = _port_for_url(parsed)
    # Trusted deployments may opt out of every private-network check.
    if policy.allow_private_network_targets:
        return _stream_getaddrinfo_or_raise(host, port)
    host_lower = host.lower()
    # Name-based guards for hosts resolution alone would not reliably catch.
    if host_lower == "localhost" or host_lower.endswith(".localhost"):
        raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
    if host_lower.endswith(".local"):
        raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")
    try:
        parsed_ip = ipaddress.ip_address(host)
    except ValueError:
        # Host is a DNS name, not an IP literal.
        parsed_ip = None
    if parsed_ip is not None:
        # IP literal: is_global excludes loopback, RFC1918, link-local, etc.
        if not parsed_ip.is_global:
            raise WebFetchEgressViolation(
                f"Non-public IP host {host!r} is not allowed for web_fetch"
            )
        return _stream_getaddrinfo_or_raise(host, port)
    infos = _stream_getaddrinfo_or_raise(host, port)
    # DNS name: every resolved address must be public (blocks rebinding to
    # internal ranges); unparsable entries are skipped rather than rejected.
    for *_, sockaddr in infos:
        addr = sockaddr[0]
        try:
            resolved = ipaddress.ip_address(addr)
        except ValueError:
            continue
        if not resolved.is_global:
            raise WebFetchEgressViolation(
                f"Host {host!r} resolves to a non-public address ({resolved})"
            )
    return infos
def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
    """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
    # Only the validation side effect matters here; the addrinfo rows are discarded.
    _ = get_validated_stream_addrinfos_for_egress(url, policy)

278
api/web_tools/outbound.py Normal file
View file

@ -0,0 +1,278 @@
"""Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""
from __future__ import annotations
import asyncio
import socket
from collections.abc import AsyncIterator
from urllib.parse import urljoin, urlparse
import aiohttp
import httpx
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from aiohttp.abc import AbstractResolver, ResolveResult
from loguru import logger
from . import constants
from .constants import (
_MAX_FETCH_CHARS,
_MAX_SEARCH_RESULTS,
_REDIRECT_RESPONSE_BODY_CAP_BYTES,
_REQUEST_TIMEOUT_S,
_WEB_FETCH_REDIRECT_STATUSES,
_WEB_TOOL_HTTP_HEADERS,
)
from .egress import (
WebFetchEgressPolicy,
WebFetchEgressViolation,
get_validated_stream_addrinfos_for_egress,
)
from .parsers import HTMLTextParser, SearchResultParser
def _safe_public_host_for_logs(url: str) -> str:
host = urlparse(url).hostname or ""
return host[:253]
def _log_web_tool_failure(
    tool_name: str,
    error: BaseException,
    *,
    fetch_url: str | None = None,
) -> None:
    """Log a web tool failure without leaking full URLs (host + exception type only)."""
    exc_type = type(error).__name__
    host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
    if isinstance(error, WebFetchEgressViolation):
        # Egress rejections get their own event name so they can be alerted on.
        logger.warning(
            "web_tool_egress_rejected tool={} exc_type={} host={!r}",
            tool_name,
            exc_type,
            host,
        )
        return
    if tool_name == "web_fetch" and fetch_url:
        logger.warning(
            "web_tool_failure tool={} exc_type={} host={!r}",
            tool_name,
            exc_type,
            host,
        )
        return
    logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)
def _web_tool_client_error_summary(
tool_name: str,
error: BaseException,
*,
verbose: bool,
) -> str:
if verbose:
return f"{tool_name} failed: {type(error).__name__}"
return "Web tool request failed."
async def _iter_response_body_under_cap(
response: httpx.Response, max_bytes: int
) -> AsyncIterator[bytes]:
if max_bytes <= 0:
return
received = 0
async for chunk in response.aiter_bytes(chunk_size=65_536):
if received >= max_bytes:
break
remaining = max_bytes - received
if len(chunk) <= remaining:
received += len(chunk)
yield chunk
if received >= max_bytes:
break
else:
yield chunk[:remaining]
break
async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
    """Consume and discard up to *max_bytes* of the response body."""
    async for _chunk in _iter_response_body_under_cap(response, max_bytes):
        continue
async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
    """Read at most *max_bytes* of the response body into memory."""
    pieces: list[bytes] = []
    async for piece in _iter_response_body_under_cap(response, max_bytes):
        pieces.append(piece)
    return b"".join(pieces)
# Pinned addresses are already numeric; these getaddrinfo flags mark them as such.
_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# getnameinfo flags: render sockaddrs back to numeric host/port strings (no reverse DNS).
_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
def getaddrinfo_rows_to_resolve_results(
    host: str, addrinfos: list[tuple]
) -> list[ResolveResult]:
    """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
    out: list[ResolveResult] = []
    for family, _type, proto, _canon, sockaddr in addrinfos:
        if family == socket.AF_INET6:
            if len(sockaddr) < 3:
                # NOTE(review): appears to mirror aiohttp's ThreadedResolver guard for
                # builds without full IPv6 support — confirm against aiohttp source.
                continue
            if sockaddr[3]:
                # Non-zero scope_id: render via getnameinfo so the numeric host
                # string keeps its %scope suffix.
                resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
            else:
                resolved_host, port = sockaddr[:2]
        else:
            assert family == socket.AF_INET, family
            resolved_host, port = sockaddr[0], sockaddr[1]
        resolved_host = str(resolved_host)
        port = int(port)
        out.append(
            ResolveResult(
                hostname=host,
                host=resolved_host,
                port=int(port),
                family=family,
                proto=proto,
                flags=_NUMERIC_RESOLVE_FLAGS,
            )
        )
    return out
class _PinnedEgressStaticResolver(AbstractResolver):
    """Return only pre-validated :class:`ResolveResult` for the outbound request."""

    def __init__(self, results: list[ResolveResult]) -> None:
        # Addresses already validated by the egress policy; nothing else is resolvable.
        self._results = results

    async def resolve(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> list[ResolveResult]:
        # Ignore the requested host/port: the connector may only ever connect
        # to the pinned addresses (DNS-rebinding defense).
        return self._results

    async def close(self) -> None:  # pragma: no cover - aiohttp contract
        return
async def _read_aiohttp_body_capped(
response: aiohttp.ClientResponse, max_bytes: int
) -> bytes:
received = 0
parts: list[bytes] = []
async for chunk in response.content.iter_chunked(65_536):
if received >= max_bytes:
break
remaining = max_bytes - received
if len(chunk) <= remaining:
received += len(chunk)
parts.append(chunk)
else:
parts.append(chunk[:remaining])
break
return b"".join(parts)
async def _drain_aiohttp_body_capped(
response: aiohttp.ClientResponse, max_bytes: int
) -> None:
if max_bytes <= 0:
return
received = 0
async for chunk in response.content.iter_chunked(65_536):
received += len(chunk)
if received >= max_bytes:
break
async def _run_web_search(query: str) -> list[dict[str, str]]:
    """Search DuckDuckGo lite for *query*; return up to _MAX_SEARCH_RESULTS {title, url} rows."""
    async with (
        httpx.AsyncClient(
            timeout=_REQUEST_TIMEOUT_S,
            follow_redirects=True,
            headers=_WEB_TOOL_HTTP_HEADERS,
        ) as client,
        client.stream(
            "GET",
            "https://lite.duckduckgo.com/lite/",
            params={"q": query},
        ) as response,
    ):
        response.raise_for_status()
        # Stream with a byte cap instead of .text so a huge page cannot exhaust memory.
        body_bytes = await _read_response_body_capped(
            response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
        )
    text = body_bytes.decode("utf-8", errors="replace")
    parser = SearchResultParser()
    parser.feed(text)
    return parser.results[:_MAX_SEARCH_RESULTS]
async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
    """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses.

    Returns a dict with the final ``url``, a ``title``, ``media_type`` (always
    text/plain), and the truncated ``data`` body text.

    Raises:
        WebFetchEgressViolation: when a hop violates the egress policy, exceeds
            the redirect limit, or a redirect lacks a Location header.
    """
    current_url = url
    redirect_hops = 0
    timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)
    while True:
        # Re-validate and re-resolve on every hop: a redirect may point anywhere.
        # Resolution is blocking, so run it off the event loop.
        addr_infos = await asyncio.to_thread(
            get_validated_stream_addrinfos_for_egress, current_url, egress
        )
        host = urlparse(current_url).hostname or ""
        results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
        resolver = _PinnedEgressStaticResolver(results)
        # force_close: no connection reuse across hops, so each connect goes
        # through the pinned resolver.
        connector = TCPConnector(
            resolver=resolver,
            force_close=True,
        )
        try:
            async with (
                ClientSession(
                    timeout=timeout,
                    headers=_WEB_TOOL_HTTP_HEADERS,
                    connector=connector,
                ) as session,
                session.get(current_url, allow_redirects=False) as response,
            ):
                if response.status in _WEB_FETCH_REDIRECT_STATUSES:
                    # Drain a bounded amount so the connection can close cleanly.
                    await _drain_aiohttp_body_capped(
                        response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
                    )
                    if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
                        raise WebFetchEgressViolation(
                            "web_fetch exceeded maximum redirects "
                            f"({constants._MAX_WEB_FETCH_REDIRECTS})"
                        )
                    location = response.headers.get("location")
                    if not location or not location.strip():
                        raise WebFetchEgressViolation(
                            "web_fetch redirect response missing Location header"
                        )
                    # Relative Locations are resolved against the response URL.
                    current_url = urljoin(str(response.url), location.strip())
                    redirect_hops += 1
                    continue
                response.raise_for_status()
                content_type = response.headers.get("content-type", "text/plain")
                final_url = str(response.url)
                encoding = response.get_encoding() or "utf-8"
                # Byte-capped read bounds memory regardless of Content-Length.
                body_bytes = await _read_aiohttp_body_capped(
                    response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
                )
        finally:
            await connector.close()
        break
    text = body_bytes.decode(encoding, errors="replace")
    title = final_url
    data = text
    if "html" in content_type.lower():
        parser = HTMLTextParser()
        parser.feed(text)
        title = parser.title or final_url
        data = "\n".join(parser.text_parts)
    return {
        "url": final_url,
        "title": title,
        "media_type": "text/plain",
        "data": data[:_MAX_FETCH_CHARS],
    }

104
api/web_tools/parsers.py Normal file
View file

@ -0,0 +1,104 @@
"""HTML parsing for web_search / web_fetch."""
from __future__ import annotations
import html
import re
from html.parser import HTMLParser
from typing import Any
from urllib.parse import parse_qs, unquote, urlparse
class SearchResultParser(HTMLParser):
    """DuckDuckGo lite HTML: extract result links and titles."""

    def __init__(self) -> None:
        super().__init__()
        # Ordered, de-duplicated {"title", "url"} rows.
        self.results: list[dict[str, str]] = []
        # URL of the anchor currently being read, if it is a result link.
        self._href: str | None = None
        self._title_parts: list[str] = []

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        if tag != "a":
            return
        href = dict(attrs).get("href")
        # Result links wrap the real target in a "uddg" query parameter.
        if not href or "uddg=" not in href:
            return
        wrapped = parse_qs(urlparse(href).query).get("uddg", [""])[0]
        if not wrapped:
            return
        self._href = unquote(wrapped)
        self._title_parts = []

    def handle_data(self, data: str) -> None:
        if self._href is None:
            return
        self._title_parts.append(data)

    def handle_endtag(self, tag: str) -> None:
        if tag != "a" or self._href is None:
            return
        title = " ".join("".join(self._title_parts).split())
        already_seen = any(entry["url"] == self._href for entry in self.results)
        if title and not already_seen:
            self.results.append({"title": html.unescape(title), "url": self._href})
        self._href = None
        self._title_parts = []
class HTMLTextParser(HTMLParser):
    """Strip scripts/styles and collect visible text + title for fetch previews."""

    def __init__(self) -> None:
        super().__init__()
        self.title = ""
        self.text_parts: list[str] = []
        self._in_title = False
        # Nesting depth inside script/style/noscript; text is dropped while > 0.
        self._skip_depth = 0

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        if tag == "title":
            self._in_title = True
        elif tag in ("script", "style", "noscript"):
            self._skip_depth += 1

    def handle_endtag(self, tag: str) -> None:
        if tag == "title":
            self._in_title = False
        elif tag in ("script", "style", "noscript") and self._skip_depth:
            self._skip_depth -= 1

    def handle_data(self, data: str) -> None:
        collapsed = " ".join(data.split())
        if not collapsed:
            return
        if self._in_title:
            # Titles may arrive in several data events; join with single spaces.
            self.title = f"{self.title} {collapsed}".strip()
            return
        if self._skip_depth == 0:
            self.text_parts.append(collapsed)
def content_text(content: Any) -> str:
    """Flatten Anthropic message content (plain string or block list) into text."""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    pieces: list[str] = []
    for block in content:
        if isinstance(block, dict):
            pieces.append(str(block.get("text", "")))
        else:
            # Pydantic-style block objects expose .text; default to "" otherwise.
            pieces.append(str(getattr(block, "text", "")))
    return "\n".join(piece for piece in pieces if piece)
def extract_query(text: str) -> str:
    """Pull the search query out of free-form text; a ``query: ...`` marker wins."""
    found = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
    if found is None:
        return text.strip()
    return found.group(1).strip().strip("\"'")
def extract_url(text: str) -> str:
    """Return the first http(s) URL in *text*, trimming trailing punctuation."""
    found = re.search(r"https?://\S+", text)
    if found is None:
        return text.strip()
    return found.group(0).rstrip(").,]")

87
api/web_tools/request.py Normal file
View file

@ -0,0 +1,87 @@
"""Detect forced Anthropic web server tool requests."""
from __future__ import annotations
from api.models.anthropic import MessagesRequest, Tool
def request_text(request: MessagesRequest) -> str:
    """Join all user/assistant message content into one string for tool input parsing."""
    from .parsers import content_text

    pieces = [content_text(message.content) for message in request.messages]
    return "\n".join(pieces)
def forced_tool_turn_text(request: MessagesRequest) -> str:
    """Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
    if not request.messages:
        return ""
    from .parsers import content_text

    # Walk backwards to find the most recent user turn.
    latest_user = next(
        (message for message in reversed(request.messages) if message.role == "user"),
        None,
    )
    if latest_user is None:
        return ""
    return content_text(latest_user.content)
def forced_server_tool_name(request: MessagesRequest) -> str | None:
    """Return web_search or web_fetch only when tool_choice forces that server tool."""
    choice = request.tool_choice
    if not isinstance(choice, dict) or choice.get("type") != "tool":
        return None
    name = choice.get("name")
    return str(name) if name in {"web_search", "web_fetch"} else None
def has_tool_named(request: MessagesRequest, name: str) -> bool:
    """True when *name* appears among the request's declared tools."""
    declared = request.tools or []
    return any(candidate.name == name for candidate in declared)
def is_web_server_tool_request(request: MessagesRequest) -> bool:
    """True when the client forces a web server tool via tool_choice (not merely listed)."""
    forced = forced_server_tool_name(request)
    return forced is not None and has_tool_named(request, forced)
def is_anthropic_server_tool_definition(tool: Tool) -> bool:
    """Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
    if (tool.name or "").strip() in {"web_search", "web_fetch"}:
        return True
    tool_type = tool.type
    if not isinstance(tool_type, str):
        return False
    # Versioned types like "web_search_20250305" also count.
    return tool_type.startswith(("web_search", "web_fetch"))
def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
    """True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
    for tool in request.tools or []:
        if is_anthropic_server_tool_definition(tool):
            return True
    return False
def openai_chat_upstream_server_tool_error(
    request: MessagesRequest, *, web_tools_enabled: bool
) -> str | None:
    """Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
    forced = forced_server_tool_name(request)
    if forced is not None:
        if web_tools_enabled:
            # Forced and locally handled: nothing to report.
            return None
        return (
            f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
            "disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic "
            "Messages transport (e.g. open_router, ollama, lmstudio)."
        )
    if has_listed_anthropic_server_tools(request):
        return (
            "OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools "
            "(web_search / web_fetch) without the local web server tool handler. Use a native "
            "Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
            "tool_choice, or remove these tools from the request."
        )
    return None

206
api/web_tools/streaming.py Normal file
View file

@ -0,0 +1,206 @@
"""SSE streaming for local web_search / web_fetch server tool results."""
from __future__ import annotations
import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any
from api.models.anthropic import MessagesRequest
from core.anthropic.server_tool_sse import (
SERVER_TOOL_USE,
WEB_FETCH_TOOL_ERROR,
WEB_FETCH_TOOL_RESULT,
WEB_SEARCH_TOOL_RESULT,
WEB_SEARCH_TOOL_RESULT_ERROR,
)
from core.anthropic.sse import format_sse_event
from . import outbound
from .constants import _MAX_FETCH_CHARS
from .egress import WebFetchEgressPolicy
from .parsers import extract_query, extract_url
from .request import (
forced_server_tool_name,
forced_tool_turn_text,
has_tool_named,
)
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
if not results:
return f"No web search results found for: {query}"
lines = [f"Search results for: {query}"]
for index, result in enumerate(results, start=1):
lines.append(f"{index}. {result['title']}\n{result['url']}")
return "\n\n".join(lines)
async def stream_web_server_tool_response(
    request: MessagesRequest,
    input_tokens: int,
    *,
    web_fetch_egress: WebFetchEgressPolicy,
    verbose_client_errors: bool = False,
) -> AsyncIterator[str]:
    """Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).

    When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path not a full
    hosted Anthropic citation or encrypted-content pipeline.

    Event order: message_start -> block 0 (server_tool_use) -> block 1 (tool result
    or error payload) -> block 2 (text summary streamed as text_delta) ->
    message_delta -> message_stop. Yields formatted SSE event strings; yields
    nothing unless tool_choice forces a listed web server tool.
    """
    tool_name = forced_server_tool_name(request)
    if tool_name is None or not has_tool_named(request, tool_name):
        return
    text = forced_tool_turn_text(request)
    message_id = f"msg_{uuid.uuid4()}"
    tool_id = f"srvtoolu_{uuid.uuid4().hex}"
    usage_key = (
        "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
    )
    # Tool input is parsed heuristically from the latest user turn's text.
    tool_input = (
        {"query": extract_query(text)}
        if tool_name == "web_search"
        else {"url": extract_url(text)}
    )
    _result_block_for_tool = {
        "web_search": WEB_SEARCH_TOOL_RESULT,
        "web_fetch": WEB_FETCH_TOOL_RESULT,
    }
    _error_payload_type_for_tool = {
        "web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
        "web_fetch": WEB_FETCH_TOOL_ERROR,
    }
    yield format_sse_event(
        "message_start",
        {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
            },
        },
    )
    # Block 0: the synthesized server_tool_use call.
    yield format_sse_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 0,
            "content_block": {
                "type": SERVER_TOOL_USE,
                "id": tool_id,
                "name": tool_name,
                "input": tool_input,
            },
        },
    )
    yield format_sse_event(
        "content_block_stop", {"type": "content_block_stop", "index": 0}
    )
    try:
        if tool_name == "web_search":
            query = str(tool_input["query"])
            results = await outbound._run_web_search(query)
            result_content: Any = [
                {
                    "type": "web_search_result",
                    "title": result["title"],
                    "url": result["url"],
                }
                for result in results
            ]
            summary = _search_summary(query, results)
            result_block_type = WEB_SEARCH_TOOL_RESULT
        else:
            # web_fetch goes through the egress-pinned outbound path.
            fetched = await outbound._run_web_fetch(
                str(tool_input["url"]), web_fetch_egress
            )
            result_content = {
                "type": "web_fetch_result",
                "url": fetched["url"],
                "content": {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": fetched["media_type"],
                        "data": fetched["data"],
                    },
                    "title": fetched["title"],
                    "citations": {"enabled": True},
                },
                "retrieved_at": datetime.now(UTC).isoformat(),
            }
            summary = fetched["data"][:_MAX_FETCH_CHARS]
            result_block_type = WEB_FETCH_TOOL_RESULT
    except Exception as error:
        # Failures are logged (host only) and reported in-band as a tool-result
        # error payload; the client sees a short summary unless verbose is on.
        fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
        outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
        result_block_type = _result_block_for_tool[tool_name]
        result_content = {
            "type": _error_payload_type_for_tool[tool_name],
            "error_code": "unavailable",
        }
        summary = outbound._web_tool_client_error_summary(
            tool_name, error, verbose=verbose_client_errors
        )
    # Rough token estimate: ~4 chars per token, never below 1.
    output_tokens = max(1, len(summary) // 4)
    # Block 1: the tool result (or error payload) tied back to the tool_use id.
    yield format_sse_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 1,
            "content_block": {
                "type": result_block_type,
                "tool_use_id": tool_id,
                "content": result_content,
            },
        },
    )
    yield format_sse_event(
        "content_block_stop", {"type": "content_block_stop", "index": 1}
    )
    # Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
    # not eager `text` on `content_block_start`).
    yield format_sse_event(
        "content_block_start",
        {
            "type": "content_block_start",
            "index": 2,
            "content_block": {"type": "text", "text": ""},
        },
    )
    yield format_sse_event(
        "content_block_delta",
        {
            "type": "content_block_delta",
            "index": 2,
            "delta": {"type": "text_delta", "text": summary},
        },
    )
    yield format_sse_event(
        "content_block_stop", {"type": "content_block_stop", "index": 2}
    )
    yield format_sse_event(
        "message_delta",
        {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "server_tool_use": {usage_key: 1},
            },
        },
    )
    yield format_sse_event("message_stop", {"type": "message_stop"})