mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Major refactor: API, providers, messaging, and Anthropic protocol
Some checks are pending
CI / checks (push) Waiting to run
Some checks are pending
CI / checks (push) Waiting to run
Consolidates the incremental refactor work into a single change set: modular web tools (api/web_tools), native Anthropic request building and SSE block policy, OpenAI conversion and error handling, provider transports and rate limiting, messaging handler and tree queue, safe logging, smoke tests, and broad test coverage.
This commit is contained in:
parent
b9ed704095
commit
f3a7528d49
139 changed files with 7460 additions and 2422 deletions
|
|
@ -1,6 +1,6 @@
|
|||
"""API layer for Claude Code Proxy."""
|
||||
|
||||
from .app import app, create_app
|
||||
from .app import create_app
|
||||
from .models import (
|
||||
MessagesRequest,
|
||||
MessagesResponse,
|
||||
|
|
@ -13,6 +13,5 @@ __all__ = [
|
|||
"MessagesResponse",
|
||||
"TokenCountRequest",
|
||||
"TokenCountResponse",
|
||||
"app",
|
||||
"create_app",
|
||||
]
|
||||
|
|
|
|||
85
api/app.py
85
api/app.py
|
|
@ -1,6 +1,6 @@
|
|||
"""FastAPI application factory and configuration."""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -16,13 +16,7 @@ from providers.exceptions import ProviderError
|
|||
|
||||
from .routes import router
|
||||
from .runtime import AppRuntime
|
||||
|
||||
# Opt-in to future behavior for python-telegram-bot
|
||||
os.environ["PTB_TIMEDELTA"] = "1"
|
||||
|
||||
# Configure logging first (before any module logs)
|
||||
_settings = get_settings()
|
||||
configure_logging(_settings.log_file)
|
||||
from .validation_log import summarize_request_validation_body
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -38,6 +32,11 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
settings = get_settings()
|
||||
configure_logging(
|
||||
settings.log_file, verbose_third_party=settings.log_raw_api_payloads
|
||||
)
|
||||
|
||||
app = FastAPI(
|
||||
title="Claude Code Proxy",
|
||||
version="2.0.0",
|
||||
|
|
@ -57,33 +56,7 @@ def create_app() -> FastAPI:
|
|||
except Exception as e:
|
||||
body = {"_json_error": type(e).__name__}
|
||||
|
||||
messages = body.get("messages") if isinstance(body, dict) else None
|
||||
message_summary: list[dict[str, Any]] = []
|
||||
if isinstance(messages, list):
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
message_summary.append({"message_kind": type(msg).__name__})
|
||||
continue
|
||||
content = msg.get("content")
|
||||
item: dict[str, Any] = {
|
||||
"role": msg.get("role"),
|
||||
"content_kind": type(content).__name__,
|
||||
}
|
||||
if isinstance(content, list):
|
||||
item["block_types"] = [
|
||||
block.get("type", "dict")
|
||||
if isinstance(block, dict)
|
||||
else type(block).__name__
|
||||
for block in content[:12]
|
||||
]
|
||||
item["block_keys"] = [
|
||||
sorted(str(key) for key in block)[:12]
|
||||
for block in content[:5]
|
||||
if isinstance(block, dict)
|
||||
]
|
||||
elif isinstance(content, str):
|
||||
item["content_length"] = len(content)
|
||||
message_summary.append(item)
|
||||
message_summary, tool_names = summarize_request_validation_body(body)
|
||||
|
||||
logger.debug(
|
||||
"Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
|
||||
|
|
@ -92,20 +65,27 @@ def create_app() -> FastAPI:
|
|||
[list(error.get("loc", ())) for error in exc.errors()],
|
||||
[str(error.get("type", "")) for error in exc.errors()],
|
||||
message_summary,
|
||||
[
|
||||
str(tool.get("name", ""))
|
||||
for tool in body.get("tools", [])
|
||||
if isinstance(body, dict)
|
||||
and isinstance(body.get("tools"), list)
|
||||
and isinstance(tool, dict)
|
||||
],
|
||||
tool_names,
|
||||
)
|
||||
return await request_validation_exception_handler(request, exc)
|
||||
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||
"""Handle provider-specific errors and return Anthropic format."""
|
||||
logger.error(f"Provider Error: {exc.error_type} - {exc.message}")
|
||||
err_settings = get_settings()
|
||||
if err_settings.log_api_error_tracebacks:
|
||||
logger.error(
|
||||
"Provider Error: error_type={} status_code={} message={}",
|
||||
exc.error_type,
|
||||
exc.status_code,
|
||||
exc.message,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"Provider Error: error_type={} status_code={}",
|
||||
exc.error_type,
|
||||
exc.status_code,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=exc.to_anthropic_format(),
|
||||
|
|
@ -114,10 +94,17 @@ def create_app() -> FastAPI:
|
|||
@app.exception_handler(Exception)
|
||||
async def general_error_handler(request: Request, exc: Exception):
|
||||
"""Handle general errors and return Anthropic format."""
|
||||
logger.error(f"General Error: {exc!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
settings = get_settings()
|
||||
if settings.log_api_error_tracebacks:
|
||||
logger.error("General Error: {}", exc)
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.error(
|
||||
"General Error: path={} method={} exc_type={}",
|
||||
request.url.path,
|
||||
request.method,
|
||||
type(exc).__name__,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
|
|
@ -130,7 +117,3 @@ def create_app() -> FastAPI:
|
|||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Default app instance for uvicorn
|
||||
app = create_app()
|
||||
|
|
|
|||
|
|
@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str:
|
|||
|
||||
return "<filepaths>\n</filepaths>"
|
||||
|
||||
except Exception:
|
||||
except ValueError:
|
||||
return "<filepaths>\n</filepaths>"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ from config.settings import Settings
|
|||
from config.settings import get_settings as _get_settings
|
||||
from core.anthropic import get_user_facing_error_message
|
||||
from providers.base import BaseProvider
|
||||
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
||||
from providers.exceptions import (
|
||||
AuthenticationError,
|
||||
ServiceUnavailableError,
|
||||
UnknownProviderTypeError,
|
||||
)
|
||||
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
|
||||
|
||||
# Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
|
||||
|
|
@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {}
|
|||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get application settings via dependency injection."""
|
||||
"""Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
|
||||
return _get_settings()
|
||||
|
||||
|
||||
|
|
@ -31,10 +35,9 @@ def resolve_provider(
|
|||
"""Resolve a provider using the app-scoped registry when ``app`` is set.
|
||||
|
||||
When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
|
||||
is always used. If the registry is missing (e.g. a test app without
|
||||
:class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
|
||||
is installed on ``app.state`` so the process cache is never mixed with
|
||||
per-request app identity.
|
||||
must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
|
||||
Callers that construct a bare ``FastAPI`` without lifespan must set
|
||||
``app.state.provider_registry`` explicitly.
|
||||
|
||||
When ``app`` is ``None`` (no HTTP context), uses the process-level
|
||||
:data:`_providers` cache only.
|
||||
|
|
@ -42,8 +45,10 @@ def resolve_provider(
|
|||
if app is not None:
|
||||
reg = getattr(app.state, "provider_registry", None)
|
||||
if reg is None:
|
||||
reg = ProviderRegistry()
|
||||
app.state.provider_registry = reg
|
||||
raise ServiceUnavailableError(
|
||||
"Provider registry is not configured. Ensure AppRuntime startup ran "
|
||||
"or assign app.state.provider_registry for test apps."
|
||||
)
|
||||
return _resolve_with_registry(reg, provider_type, settings)
|
||||
return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
|
||||
|
||||
|
|
@ -55,9 +60,10 @@ def _resolve_with_registry(
|
|||
try:
|
||||
provider = registry.get(provider_type, settings)
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=503, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
# Provider :class:`~providers.exceptions.AuthenticationError` messages are
|
||||
# curated configuration hints (env var names, docs links), not upstream noise.
|
||||
detail = str(e).strip() or get_user_facing_error_message(e)
|
||||
raise HTTPException(status_code=503, detail=detail) from e
|
||||
except UnknownProviderTypeError:
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: {}",
|
||||
|
|
@ -73,8 +79,9 @@ def _resolve_with_registry(
|
|||
def get_provider_for_type(provider_type: str) -> BaseProvider:
|
||||
"""Get or create a provider in the process-level cache (no ``app``/Request).
|
||||
|
||||
For server requests, use :func:`resolve_provider` with the active
|
||||
:attr:`request.app` so the app-scoped provider registry is used.
|
||||
HTTP route handlers should call :func:`resolve_provider` with the active
|
||||
:attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
|
||||
process-wide cache.
|
||||
"""
|
||||
return resolve_provider(provider_type, app=None, settings=get_settings())
|
||||
|
||||
|
|
|
|||
|
|
@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, st
|
|||
try:
|
||||
cmd_start = content.rfind("Command:") + len("Command:")
|
||||
return True, content[cmd_start:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
except TypeError:
|
||||
return False, ""
|
||||
|
||||
return False, ""
|
||||
|
||||
|
|
@ -121,19 +121,16 @@ def is_filepath_extraction_request(
|
|||
if not user_has_filepaths and not system_has_extract:
|
||||
return False, "", ""
|
||||
|
||||
try:
|
||||
cmd_start = content.find("Command:") + len("Command:")
|
||||
output_marker = content.find("Output:", cmd_start)
|
||||
if output_marker == -1:
|
||||
return False, "", ""
|
||||
|
||||
command = content[cmd_start:output_marker].strip()
|
||||
output = content[output_marker + len("Output:") :].strip()
|
||||
|
||||
for marker in ["<", "\n\n"]:
|
||||
if marker in output:
|
||||
output = output.split(marker)[0].strip()
|
||||
|
||||
return True, command, output
|
||||
except Exception:
|
||||
cmd_start = content.find("Command:") + len("Command:")
|
||||
output_marker = content.find("Output:", cmd_start)
|
||||
if output_marker == -1:
|
||||
return False, "", ""
|
||||
|
||||
command = content[cmd_start:output_marker].strip()
|
||||
output = content[output_marker + len("Output:") :].strip()
|
||||
|
||||
for marker in ["<", "\n\n"]:
|
||||
if marker in output:
|
||||
output = output.split(marker)[0].strip()
|
||||
|
||||
return True, command, output
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from enum import StrEnum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -15,41 +15,68 @@ class Role(StrEnum):
|
|||
system = "system"
|
||||
|
||||
|
||||
class ContentBlockText(BaseModel):
|
||||
class _AnthropicBlockBase(BaseModel):
|
||||
"""Pass through provider fields (e.g. ``cache_control``) for native transports."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ContentBlockText(_AnthropicBlockBase):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ContentBlockImage(BaseModel):
|
||||
class ContentBlockImage(_AnthropicBlockBase):
|
||||
type: Literal["image"]
|
||||
source: dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockToolUse(BaseModel):
|
||||
class ContentBlockToolUse(_AnthropicBlockBase):
|
||||
type: Literal["tool_use"]
|
||||
id: str
|
||||
name: str
|
||||
input: dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockToolResult(BaseModel):
|
||||
class ContentBlockToolResult(_AnthropicBlockBase):
|
||||
type: Literal["tool_result"]
|
||||
tool_use_id: str
|
||||
content: str | list[Any] | dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockThinking(BaseModel):
|
||||
class ContentBlockThinking(_AnthropicBlockBase):
|
||||
type: Literal["thinking"]
|
||||
thinking: str
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
class ContentBlockRedactedThinking(BaseModel):
|
||||
class ContentBlockRedactedThinking(_AnthropicBlockBase):
|
||||
type: Literal["redacted_thinking"]
|
||||
data: str
|
||||
|
||||
|
||||
class SystemContent(BaseModel):
|
||||
class ContentBlockServerToolUse(_AnthropicBlockBase):
|
||||
"""Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""
|
||||
|
||||
type: Literal["server_tool_use"]
|
||||
id: str
|
||||
name: str
|
||||
input: dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
|
||||
type: Literal["web_search_tool_result"]
|
||||
tool_use_id: str
|
||||
content: Any
|
||||
|
||||
|
||||
class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
|
||||
type: Literal["web_fetch_tool_result"]
|
||||
tool_use_id: str
|
||||
content: Any
|
||||
|
||||
|
||||
class SystemContent(_AnthropicBlockBase):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
|
@ -68,12 +95,15 @@ class Message(BaseModel):
|
|||
| ContentBlockToolResult
|
||||
| ContentBlockThinking
|
||||
| ContentBlockRedactedThinking
|
||||
| ContentBlockServerToolUse
|
||||
| ContentBlockWebSearchToolResult
|
||||
| ContentBlockWebFetchToolResult
|
||||
]
|
||||
)
|
||||
reasoning_content: str | None = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
class Tool(_AnthropicBlockBase):
|
||||
name: str
|
||||
# Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
|
||||
# may omit ``input_schema`` because the provider owns the schema.
|
||||
|
|
@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel):
|
|||
# Request Models
|
||||
# =============================================================================
|
||||
class MessagesRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
model: str
|
||||
# Internal routing / debug: accepted on parse but not serialized to providers.
|
||||
original_model: str | None = Field(default=None, exclude=True)
|
||||
resolved_provider_model: str | None = Field(default=None, exclude=True)
|
||||
max_tokens: int | None = None
|
||||
messages: list[Message]
|
||||
system: str | list[SystemContent] | None = None
|
||||
|
|
@ -105,13 +140,24 @@ class MessagesRequest(BaseModel):
|
|||
tools: list[Tool] | None = None
|
||||
tool_choice: dict[str, Any] | None = None
|
||||
thinking: ThinkingConfig | None = None
|
||||
# Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
|
||||
context_management: dict[str, Any] | None = None
|
||||
output_config: dict[str, Any] | None = None
|
||||
mcp_servers: list[dict[str, Any]] | None = None
|
||||
extra_body: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
model: str
|
||||
original_model: str | None = Field(default=None, exclude=True)
|
||||
resolved_provider_model: str | None = Field(default=None, exclude=True)
|
||||
messages: list[Message]
|
||||
system: str | list[SystemContent] | None = None
|
||||
tools: list[Tool] | None = None
|
||||
thinking: ThinkingConfig | None = None
|
||||
tool_choice: dict[str, Any] | None = None
|
||||
context_management: dict[str, Any] | None = None
|
||||
output_config: dict[str, Any] | None = None
|
||||
mcp_servers: list[dict[str, Any]] | None = None
|
||||
|
|
|
|||
|
|
@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0
|
|||
|
||||
|
||||
async def best_effort(
|
||||
name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S
|
||||
name: str,
|
||||
awaitable: Any,
|
||||
timeout_s: float = _SHUTDOWN_TIMEOUT_S,
|
||||
*,
|
||||
log_verbose_errors: bool = False,
|
||||
) -> None:
|
||||
"""Run a shutdown step with timeout; never raise to callers."""
|
||||
try:
|
||||
await asyncio.wait_for(awaitable, timeout=timeout_s)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
|
||||
logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
|
||||
except Exception as e:
|
||||
logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
|
||||
if log_verbose_errors:
|
||||
logger.warning(
|
||||
"Shutdown step failed: {}: {}: {}",
|
||||
name,
|
||||
type(e).__name__,
|
||||
e,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Shutdown step failed: {}: exc_type={}",
|
||||
name,
|
||||
type(e).__name__,
|
||||
)
|
||||
|
||||
|
||||
def warn_if_process_auth_token(settings: Settings) -> None:
|
||||
|
|
@ -73,20 +89,37 @@ class AppRuntime:
|
|||
self._publish_state()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
verbose = self.settings.log_api_error_tracebacks
|
||||
if self.message_handler is not None:
|
||||
try:
|
||||
self.message_handler.session_store.flush_pending_save()
|
||||
except Exception as e:
|
||||
logger.warning(f"Session store flush on shutdown: {e}")
|
||||
if verbose:
|
||||
logger.warning("Session store flush on shutdown: {}", e)
|
||||
else:
|
||||
logger.warning(
|
||||
"Session store flush on shutdown: exc_type={}",
|
||||
type(e).__name__,
|
||||
)
|
||||
|
||||
logger.info("Shutdown requested, cleaning up...")
|
||||
if self.messaging_platform:
|
||||
await best_effort("messaging_platform.stop", self.messaging_platform.stop())
|
||||
await best_effort(
|
||||
"messaging_platform.stop",
|
||||
self.messaging_platform.stop(),
|
||||
log_verbose_errors=verbose,
|
||||
)
|
||||
if self.cli_manager:
|
||||
await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
|
||||
await best_effort(
|
||||
"cli_manager.stop_all",
|
||||
self.cli_manager.stop_all(),
|
||||
log_verbose_errors=verbose,
|
||||
)
|
||||
if self._provider_registry is not None:
|
||||
await best_effort(
|
||||
"provider_registry.cleanup", self._provider_registry.cleanup()
|
||||
"provider_registry.cleanup",
|
||||
self._provider_registry.cleanup(),
|
||||
log_verbose_errors=verbose,
|
||||
)
|
||||
await self._shutdown_limiter()
|
||||
logger.info("Server shut down cleanly")
|
||||
|
|
@ -110,6 +143,10 @@ class AppRuntime:
|
|||
whisper_device=self.settings.whisper_device,
|
||||
hf_token=self.settings.hf_token,
|
||||
nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
|
||||
messaging_rate_limit=self.settings.messaging_rate_limit,
|
||||
messaging_rate_window=self.settings.messaging_rate_window,
|
||||
log_raw_messaging_content=self.settings.log_raw_messaging_content,
|
||||
log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -117,12 +154,24 @@ class AppRuntime:
|
|||
await self._start_message_handler()
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Messaging module import error: {e}")
|
||||
if self.settings.log_api_error_tracebacks:
|
||||
logger.warning("Messaging module import error: {}", e)
|
||||
else:
|
||||
logger.warning(
|
||||
"Messaging module import error: exc_type={}",
|
||||
type(e).__name__,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start messaging platform: {e}")
|
||||
import traceback
|
||||
if self.settings.log_api_error_tracebacks:
|
||||
logger.error("Failed to start messaging platform: {}", e)
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to start messaging platform: exc_type={}",
|
||||
type(e).__name__,
|
||||
)
|
||||
|
||||
async def _start_message_handler(self) -> None:
|
||||
from cli.manager import CLISessionManager
|
||||
|
|
@ -151,10 +200,13 @@ class AppRuntime:
|
|||
allowed_dirs=allowed_dirs,
|
||||
plans_directory=plans_directory,
|
||||
claude_bin=self.settings.claude_cli_bin,
|
||||
log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
|
||||
log_messaging_error_details=self.settings.log_messaging_error_details,
|
||||
)
|
||||
|
||||
session_store = SessionStore(
|
||||
storage_path=os.path.join(data_path, "sessions.json")
|
||||
storage_path=os.path.join(data_path, "sessions.json"),
|
||||
message_log_cap=self.settings.max_message_log_entries_per_chat,
|
||||
)
|
||||
platform = self.messaging_platform
|
||||
assert platform is not None
|
||||
|
|
@ -162,6 +214,11 @@ class AppRuntime:
|
|||
platform=platform,
|
||||
cli_manager=self.cli_manager,
|
||||
session_store=session_store,
|
||||
debug_platform_edits=self.settings.debug_platform_edits,
|
||||
debug_subagent_stack=self.settings.debug_subagent_stack,
|
||||
log_raw_messaging_content=self.settings.log_raw_messaging_content,
|
||||
log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
|
||||
log_messaging_error_details=self.settings.log_messaging_error_details,
|
||||
)
|
||||
self._restore_tree_state(session_store)
|
||||
|
||||
|
|
@ -201,18 +258,26 @@ class AppRuntime:
|
|||
self.app.state.cli_manager = self.cli_manager
|
||||
|
||||
async def _shutdown_limiter(self) -> None:
|
||||
verbose = self.settings.log_api_error_tracebacks
|
||||
try:
|
||||
from messaging.limiter import MessagingRateLimiter
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Rate limiter shutdown skipped (import failed): {}: {}",
|
||||
type(e).__name__,
|
||||
e,
|
||||
)
|
||||
if verbose:
|
||||
logger.debug(
|
||||
"Rate limiter shutdown skipped (import failed): {}: {}",
|
||||
type(e).__name__,
|
||||
e,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Rate limiter shutdown skipped (import failed): exc_type={}",
|
||||
type(e).__name__,
|
||||
)
|
||||
return
|
||||
|
||||
await best_effort(
|
||||
"MessagingRateLimiter.shutdown_instance",
|
||||
MessagingRateLimiter.shutdown_instance(),
|
||||
timeout_s=2.0,
|
||||
log_verbose_errors=verbose,
|
||||
)
|
||||
|
|
|
|||
132
api/services.py
132
api/services.py
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -13,6 +13,7 @@ from loguru import logger
|
|||
|
||||
from config.settings import Settings
|
||||
from core.anthropic import get_token_count, get_user_facing_error_message
|
||||
from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS
|
||||
from providers.base import BaseProvider
|
||||
from providers.exceptions import InvalidRequestError, ProviderError
|
||||
|
||||
|
|
@ -20,15 +21,67 @@ from .model_router import ModelRouter
|
|||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
from .web_server_tools import (
|
||||
from .web_tools.egress import WebFetchEgressPolicy
|
||||
from .web_tools.request import (
|
||||
is_web_server_tool_request,
|
||||
stream_web_server_tool_response,
|
||||
openai_chat_upstream_server_tool_error,
|
||||
)
|
||||
from .web_tools.streaming import stream_web_server_tool_response
|
||||
|
||||
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
|
||||
|
||||
ProviderGetter = Callable[[str], BaseProvider]
|
||||
|
||||
# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
|
||||
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"})
|
||||
|
||||
|
||||
def anthropic_sse_streaming_response(
|
||||
body: AsyncIterator[str],
|
||||
) -> StreamingResponse:
|
||||
"""Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
|
||||
return StreamingResponse(
|
||||
body,
|
||||
media_type="text/event-stream",
|
||||
headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
|
||||
"""HTTP status for uncaught non-provider failures (stable client contract)."""
|
||||
return 500
|
||||
|
||||
|
||||
def _log_unexpected_service_exception(
|
||||
settings: Settings,
|
||||
exc: BaseException,
|
||||
*,
|
||||
context: str,
|
||||
request_id: str | None = None,
|
||||
) -> None:
|
||||
"""Log service-layer failures without echoing exception text unless opted in."""
|
||||
if settings.log_api_error_tracebacks:
|
||||
if request_id is not None:
|
||||
logger.error("{} request_id={}: {}", context, request_id, exc)
|
||||
else:
|
||||
logger.error("{}: {}", context, exc)
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
if request_id is not None:
|
||||
logger.error(
|
||||
"{} request_id={} exc_type={}",
|
||||
context,
|
||||
request_id,
|
||||
type(exc).__name__,
|
||||
)
|
||||
else:
|
||||
logger.error("{} exc_type={}", context, type(exc).__name__)
|
||||
|
||||
|
||||
def _require_non_empty_messages(messages: list[Any]) -> None:
|
||||
if not messages:
|
||||
raise InvalidRequestError("messages cannot be empty")
|
||||
|
||||
|
||||
class ClaudeProxyService:
|
||||
"""Coordinate request optimization, model routing, token count, and providers."""
|
||||
|
|
@ -48,25 +101,35 @@ class ClaudeProxyService:
|
|||
def create_message(self, request_data: MessagesRequest) -> object:
|
||||
"""Create a message response or streaming response."""
|
||||
try:
|
||||
if not request_data.messages:
|
||||
raise InvalidRequestError("messages cannot be empty")
|
||||
_require_non_empty_messages(request_data.messages)
|
||||
|
||||
routed = self._model_router.resolve_messages_request(request_data)
|
||||
if is_web_server_tool_request(routed.request):
|
||||
if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
|
||||
tool_err = openai_chat_upstream_server_tool_error(
|
||||
routed.request,
|
||||
web_tools_enabled=self._settings.enable_web_server_tools,
|
||||
)
|
||||
if tool_err is not None:
|
||||
raise InvalidRequestError(tool_err)
|
||||
|
||||
if self._settings.enable_web_server_tools and is_web_server_tool_request(
|
||||
routed.request
|
||||
):
|
||||
input_tokens = self._token_counter(
|
||||
routed.request.messages, routed.request.system, routed.request.tools
|
||||
)
|
||||
logger.info("Optimization: Handling Anthropic web server tool")
|
||||
return StreamingResponse(
|
||||
egress = WebFetchEgressPolicy(
|
||||
allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
|
||||
allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
|
||||
)
|
||||
return anthropic_sse_streaming_response(
|
||||
stream_web_server_tool_response(
|
||||
routed.request, input_tokens=input_tokens
|
||||
routed.request,
|
||||
input_tokens=input_tokens,
|
||||
web_fetch_egress=egress,
|
||||
verbose_client_errors=self._settings.log_api_error_tracebacks,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
optimized = try_optimizations(routed.request, self._settings)
|
||||
|
|
@ -75,6 +138,10 @@ class ClaudeProxyService:
|
|||
logger.debug("No optimization matched, routing to provider")
|
||||
|
||||
provider = self._provider_getter(routed.resolved.provider_id)
|
||||
provider.preflight_stream(
|
||||
routed.request,
|
||||
thinking_enabled=routed.resolved.thinking_enabled,
|
||||
)
|
||||
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
logger.info(
|
||||
|
|
@ -83,34 +150,31 @@ class ClaudeProxyService:
|
|||
routed.request.model,
|
||||
len(routed.request.messages),
|
||||
)
|
||||
logger.debug(
|
||||
"FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
|
||||
)
|
||||
if self._settings.log_raw_api_payloads:
|
||||
logger.debug(
|
||||
"FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
|
||||
)
|
||||
|
||||
input_tokens = self._token_counter(
|
||||
routed.request.messages, routed.request.system, routed.request.tools
|
||||
)
|
||||
return StreamingResponse(
|
||||
return anthropic_sse_streaming_response(
|
||||
provider.stream_response(
|
||||
routed.request,
|
||||
input_tokens=input_tokens,
|
||||
request_id=request_id,
|
||||
thinking_enabled=routed.resolved.thinking_enabled,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except ProviderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
|
||||
_log_unexpected_service_exception(
|
||||
self._settings, e, context="CREATE_MESSAGE_ERROR"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=getattr(e, "status_code", 500),
|
||||
status_code=_http_status_for_unexpected_service_exception(e),
|
||||
detail=get_user_facing_error_message(e),
|
||||
) from e
|
||||
|
||||
|
|
@ -119,6 +183,7 @@ class ClaudeProxyService:
|
|||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
with logger.contextualize(request_id=request_id):
|
||||
try:
|
||||
_require_non_empty_messages(request_data.messages)
|
||||
routed = self._model_router.resolve_token_count_request(request_data)
|
||||
tokens = self._token_counter(
|
||||
routed.request.messages, routed.request.system, routed.request.tools
|
||||
|
|
@ -131,13 +196,16 @@ class ClaudeProxyService:
|
|||
tokens,
|
||||
)
|
||||
return TokenCountResponse(input_tokens=tokens)
|
||||
except ProviderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
|
||||
request_id,
|
||||
get_user_facing_error_message(e),
|
||||
traceback.format_exc(),
|
||||
_log_unexpected_service_exception(
|
||||
self._settings,
|
||||
e,
|
||||
context="COUNT_TOKENS_ERROR",
|
||||
request_id=request_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=get_user_facing_error_message(e)
|
||||
status_code=_http_status_for_unexpected_service_exception(e),
|
||||
detail=get_user_facing_error_message(e),
|
||||
) from e
|
||||
|
|
|
|||
48
api/validation_log.py
Normal file
48
api/validation_log.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def summarize_request_validation_body(
|
||||
body: Any,
|
||||
) -> tuple[list[dict[str, Any]], list[str]]:
|
||||
"""Return message shape summary and tool name list for debug logs."""
|
||||
messages = body.get("messages") if isinstance(body, dict) else None
|
||||
message_summary: list[dict[str, Any]] = []
|
||||
if isinstance(messages, list):
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
message_summary.append({"message_kind": type(msg).__name__})
|
||||
continue
|
||||
content = msg.get("content")
|
||||
item: dict[str, Any] = {
|
||||
"role": msg.get("role"),
|
||||
"content_kind": type(content).__name__,
|
||||
}
|
||||
if isinstance(content, list):
|
||||
item["block_types"] = [
|
||||
block.get("type", "dict")
|
||||
if isinstance(block, dict)
|
||||
else type(block).__name__
|
||||
for block in content[:12]
|
||||
]
|
||||
item["block_keys"] = [
|
||||
sorted(str(key) for key in block)[:12]
|
||||
for block in content[:5]
|
||||
if isinstance(block, dict)
|
||||
]
|
||||
elif isinstance(content, str):
|
||||
item["content_length"] = len(content)
|
||||
message_summary.append(item)
|
||||
|
||||
tool_names: list[str] = []
|
||||
if isinstance(body, dict) and isinstance(body.get("tools"), list):
|
||||
tool_names = [
|
||||
str(tool.get("name", ""))
|
||||
for tool in body["tools"]
|
||||
if isinstance(tool, dict)
|
||||
]
|
||||
|
||||
return message_summary, tool_names
|
||||
|
|
@ -1,331 +1,22 @@
|
|||
"""Local handlers for Anthropic web server tools.
|
||||
|
||||
OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
|
||||
web tools are server-side: the API response itself must include the tool result.
|
||||
"""
|
||||
"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from .models.anthropic import MessagesRequest
|
||||
from api.web_tools.egress import (
|
||||
WebFetchEgressPolicy,
|
||||
WebFetchEgressViolation,
|
||||
enforce_web_fetch_egress,
|
||||
)
|
||||
from api.web_tools.request import is_web_server_tool_request
|
||||
from api.web_tools.streaming import stream_web_server_tool_response
|
||||
|
||||
_REQUEST_TIMEOUT_S = 20.0
|
||||
_MAX_SEARCH_RESULTS = 10
|
||||
_MAX_FETCH_CHARS = 24_000
|
||||
|
||||
|
||||
class _SearchResultParser(HTMLParser):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.results: list[dict[str, str]] = []
|
||||
self._href: str | None = None
|
||||
self._title_parts: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag != "a":
|
||||
return
|
||||
href = dict(attrs).get("href")
|
||||
if not href or "uddg=" not in href:
|
||||
return
|
||||
parsed = urlparse(href)
|
||||
query = parse_qs(parsed.query)
|
||||
uddg = query.get("uddg", [""])[0]
|
||||
if not uddg:
|
||||
return
|
||||
self._href = unquote(uddg)
|
||||
self._title_parts = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._href is not None:
|
||||
self._title_parts.append(data)
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag != "a" or self._href is None:
|
||||
return
|
||||
title = " ".join("".join(self._title_parts).split())
|
||||
if title and not any(result["url"] == self._href for result in self.results):
|
||||
self.results.append({"title": html.unescape(title), "url": self._href})
|
||||
self._href = None
|
||||
self._title_parts = []
|
||||
|
||||
|
||||
class _HTMLTextParser(HTMLParser):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.title = ""
|
||||
self.text_parts: list[str] = []
|
||||
self._in_title = False
|
||||
self._skip_depth = 0
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag in {"script", "style", "noscript"}:
|
||||
self._skip_depth += 1
|
||||
elif tag == "title":
|
||||
self._in_title = True
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag in {"script", "style", "noscript"} and self._skip_depth:
|
||||
self._skip_depth -= 1
|
||||
elif tag == "title":
|
||||
self._in_title = False
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
text = " ".join(data.split())
|
||||
if not text:
|
||||
return
|
||||
if self._in_title:
|
||||
self.title = f"{self.title} {text}".strip()
|
||||
elif not self._skip_depth:
|
||||
self.text_parts.append(text)
|
||||
|
||||
|
||||
def _format_event(event_type: str, data: dict[str, Any]) -> str:
|
||||
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _content_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(str(item.get("text", "")))
|
||||
else:
|
||||
parts.append(str(getattr(item, "text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _request_text(request: MessagesRequest) -> str:
|
||||
return "\n".join(_content_text(message.content) for message in request.messages)
|
||||
|
||||
|
||||
def _web_tool_name(request: MessagesRequest) -> str | None:
|
||||
for tool in request.tools or []:
|
||||
name = tool.name
|
||||
tool_type = tool.type or ""
|
||||
if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def is_web_server_tool_request(request: MessagesRequest) -> bool:
|
||||
return _web_tool_name(request) in {"web_search", "web_fetch"}
|
||||
|
||||
|
||||
def _extract_query(text: str) -> str:
|
||||
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip().strip("\"'")
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _extract_url(text: str) -> str:
|
||||
match = re.search(r"https?://\S+", text)
|
||||
return match.group(0).rstrip(").,]") if match else text.strip()
|
||||
|
||||
|
||||
async def _run_web_search(query: str) -> list[dict[str, str]]:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT_S,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"https://lite.duckduckgo.com/lite/",
|
||||
params={"q": query},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
parser = _SearchResultParser()
|
||||
parser.feed(response.text)
|
||||
return parser.results[:_MAX_SEARCH_RESULTS]
|
||||
|
||||
|
||||
async def _run_web_fetch(url: str) -> dict[str, str]:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT_S,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
|
||||
) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get("content-type", "text/plain")
|
||||
title = url
|
||||
data = response.text
|
||||
if "html" in content_type.lower():
|
||||
parser = _HTMLTextParser()
|
||||
parser.feed(response.text)
|
||||
title = parser.title or url
|
||||
data = "\n".join(parser.text_parts)
|
||||
return {
|
||||
"url": str(response.url),
|
||||
"title": title,
|
||||
"media_type": "text/plain",
|
||||
"data": data[:_MAX_FETCH_CHARS],
|
||||
}
|
||||
|
||||
|
||||
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
|
||||
if not results:
|
||||
return f"No web search results found for: {query}"
|
||||
lines = [f"Search results for: {query}"]
|
||||
for index, result in enumerate(results, start=1):
|
||||
lines.append(f"{index}. {result['title']}\n{result['url']}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def stream_web_server_tool_response(
|
||||
request: MessagesRequest, input_tokens: int
|
||||
) -> AsyncIterator[str]:
|
||||
tool_name = _web_tool_name(request)
|
||||
if tool_name is None:
|
||||
return
|
||||
|
||||
text = _request_text(request)
|
||||
message_id = f"msg_{uuid.uuid4()}"
|
||||
tool_id = f"srvtoolu_{uuid.uuid4().hex}"
|
||||
output_tokens = 1
|
||||
usage_key = (
|
||||
"web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
|
||||
)
|
||||
tool_input = (
|
||||
{"query": _extract_query(text)}
|
||||
if tool_name == "web_search"
|
||||
else {"url": _extract_url(text)}
|
||||
)
|
||||
|
||||
yield _format_event(
|
||||
"message_start",
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": request.model,
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": input_tokens, "output_tokens": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "server_tool_use",
|
||||
"id": tool_id,
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 0}
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "web_search":
|
||||
query = str(tool_input["query"])
|
||||
results = await _run_web_search(query)
|
||||
result_content: Any = [
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"title": result["title"],
|
||||
"url": result["url"],
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
summary = _search_summary(query, results)
|
||||
result_block_type = "web_search_tool_result"
|
||||
else:
|
||||
fetched = await _run_web_fetch(str(tool_input["url"]))
|
||||
result_content = {
|
||||
"type": "web_fetch_result",
|
||||
"url": fetched["url"],
|
||||
"content": {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": "text",
|
||||
"media_type": fetched["media_type"],
|
||||
"data": fetched["data"],
|
||||
},
|
||||
"title": fetched["title"],
|
||||
"citations": {"enabled": True},
|
||||
},
|
||||
"retrieved_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
summary = fetched["data"][:_MAX_FETCH_CHARS]
|
||||
result_block_type = "web_fetch_tool_result"
|
||||
except Exception as error:
|
||||
result_block_type = (
|
||||
"web_search_tool_result"
|
||||
if tool_name == "web_search"
|
||||
else "web_fetch_tool_result"
|
||||
)
|
||||
error_type = (
|
||||
"web_search_tool_result_error"
|
||||
if tool_name == "web_search"
|
||||
else "web_fetch_tool_error"
|
||||
)
|
||||
result_content = {"type": error_type, "error_code": "unavailable"}
|
||||
summary = f"{tool_name} failed: {type(error).__name__}"
|
||||
|
||||
output_tokens = max(1, len(summary) // 4)
|
||||
|
||||
yield _format_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": result_block_type,
|
||||
"tool_use_id": tool_id,
|
||||
"content": result_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 1}
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 2,
|
||||
"content_block": {"type": "text", "text": summary},
|
||||
},
|
||||
)
|
||||
yield _format_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 2}
|
||||
)
|
||||
yield _format_event(
|
||||
"message_delta",
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
|
||||
"usage": {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"server_tool_use": {usage_key: 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield _format_event("message_stop", {"type": "message_stop"})
|
||||
__all__ = [
|
||||
"WebFetchEgressPolicy",
|
||||
"WebFetchEgressViolation",
|
||||
"enforce_web_fetch_egress",
|
||||
"httpx",
|
||||
"is_web_server_tool_request",
|
||||
"stream_web_server_tool_response",
|
||||
]
|
||||
|
|
|
|||
17
api/web_tools/__init__.py
Normal file
17
api/web_tools/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""
|
||||
|
||||
from .egress import (
|
||||
WebFetchEgressPolicy,
|
||||
WebFetchEgressViolation,
|
||||
enforce_web_fetch_egress,
|
||||
)
|
||||
from .request import is_web_server_tool_request
|
||||
from .streaming import stream_web_server_tool_response
|
||||
|
||||
__all__ = [
|
||||
"WebFetchEgressPolicy",
|
||||
"WebFetchEgressViolation",
|
||||
"enforce_web_fetch_egress",
|
||||
"is_web_server_tool_request",
|
||||
"stream_web_server_tool_response",
|
||||
]
|
||||
15
api/web_tools/constants.py
Normal file
15
api/web_tools/constants.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""Limits and defaults for outbound web server tool HTTP."""
|
||||
|
||||
_REQUEST_TIMEOUT_S = 20.0
|
||||
_MAX_SEARCH_RESULTS = 10
|
||||
_MAX_FETCH_CHARS = 24_000
|
||||
# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
|
||||
_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
|
||||
# Drain at most this many bytes from redirect responses before following Location.
|
||||
_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
|
||||
_MAX_WEB_FETCH_REDIRECTS = 10
|
||||
_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
|
||||
|
||||
_WEB_TOOL_HTTP_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
|
||||
}
|
||||
99
api/web_tools/egress.py
Normal file
99
api/web_tools/egress.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WebFetchEgressPolicy:
|
||||
"""Egress rules for user-influenced web_fetch URLs."""
|
||||
|
||||
allow_private_network_targets: bool
|
||||
allowed_schemes: frozenset[str]
|
||||
|
||||
|
||||
class WebFetchEgressViolation(ValueError):
|
||||
"""Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""
|
||||
|
||||
|
||||
def _port_for_url(parsed) -> int:
|
||||
if parsed.port is not None:
|
||||
return parsed.port
|
||||
return 443 if (parsed.scheme or "").lower() == "https" else 80
|
||||
|
||||
|
||||
def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
|
||||
try:
|
||||
return socket.getaddrinfo(
|
||||
host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
|
||||
)
|
||||
except OSError as exc:
|
||||
raise WebFetchEgressViolation(
|
||||
f"Could not resolve host {host!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
|
||||
def get_validated_stream_addrinfos_for_egress(
|
||||
url: str, policy: WebFetchEgressPolicy
|
||||
) -> list[tuple]:
|
||||
"""Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.
|
||||
|
||||
Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
|
||||
server cannot rebind to a disallowed address between resolution and the TCP
|
||||
connect (used by :func:`api.web_tools.outbound._run_web_fetch`).
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
scheme = (parsed.scheme or "").lower()
|
||||
if scheme not in policy.allowed_schemes:
|
||||
raise WebFetchEgressViolation(
|
||||
f"URL scheme {scheme!r} is not allowed for web_fetch"
|
||||
)
|
||||
|
||||
host = parsed.hostname
|
||||
if host is None or host == "":
|
||||
raise WebFetchEgressViolation("web_fetch URL must include a host")
|
||||
|
||||
port = _port_for_url(parsed)
|
||||
|
||||
if policy.allow_private_network_targets:
|
||||
return _stream_getaddrinfo_or_raise(host, port)
|
||||
|
||||
host_lower = host.lower()
|
||||
if host_lower == "localhost" or host_lower.endswith(".localhost"):
|
||||
raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
|
||||
if host_lower.endswith(".local"):
|
||||
raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")
|
||||
|
||||
try:
|
||||
parsed_ip = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
parsed_ip = None
|
||||
|
||||
if parsed_ip is not None:
|
||||
if not parsed_ip.is_global:
|
||||
raise WebFetchEgressViolation(
|
||||
f"Non-public IP host {host!r} is not allowed for web_fetch"
|
||||
)
|
||||
return _stream_getaddrinfo_or_raise(host, port)
|
||||
|
||||
infos = _stream_getaddrinfo_or_raise(host, port)
|
||||
for *_, sockaddr in infos:
|
||||
addr = sockaddr[0]
|
||||
try:
|
||||
resolved = ipaddress.ip_address(addr)
|
||||
except ValueError:
|
||||
continue
|
||||
if not resolved.is_global:
|
||||
raise WebFetchEgressViolation(
|
||||
f"Host {host!r} resolves to a non-public address ({resolved})"
|
||||
)
|
||||
return infos
|
||||
|
||||
|
||||
def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
|
||||
"""Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
|
||||
get_validated_stream_addrinfos_for_egress(url, policy)
|
||||
278
api/web_tools/outbound.py
Normal file
278
api/web_tools/outbound.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from collections.abc import AsyncIterator
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
from aiohttp.abc import AbstractResolver, ResolveResult
|
||||
from loguru import logger
|
||||
|
||||
from . import constants
|
||||
from .constants import (
|
||||
_MAX_FETCH_CHARS,
|
||||
_MAX_SEARCH_RESULTS,
|
||||
_REDIRECT_RESPONSE_BODY_CAP_BYTES,
|
||||
_REQUEST_TIMEOUT_S,
|
||||
_WEB_FETCH_REDIRECT_STATUSES,
|
||||
_WEB_TOOL_HTTP_HEADERS,
|
||||
)
|
||||
from .egress import (
|
||||
WebFetchEgressPolicy,
|
||||
WebFetchEgressViolation,
|
||||
get_validated_stream_addrinfos_for_egress,
|
||||
)
|
||||
from .parsers import HTMLTextParser, SearchResultParser
|
||||
|
||||
|
||||
def _safe_public_host_for_logs(url: str) -> str:
|
||||
host = urlparse(url).hostname or ""
|
||||
return host[:253]
|
||||
|
||||
|
||||
def _log_web_tool_failure(
|
||||
tool_name: str,
|
||||
error: BaseException,
|
||||
*,
|
||||
fetch_url: str | None = None,
|
||||
) -> None:
|
||||
exc_type = type(error).__name__
|
||||
if isinstance(error, WebFetchEgressViolation):
|
||||
host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
|
||||
logger.warning(
|
||||
"web_tool_egress_rejected tool={} exc_type={} host={!r}",
|
||||
tool_name,
|
||||
exc_type,
|
||||
host,
|
||||
)
|
||||
return
|
||||
if tool_name == "web_fetch" and fetch_url:
|
||||
logger.warning(
|
||||
"web_tool_failure tool={} exc_type={} host={!r}",
|
||||
tool_name,
|
||||
exc_type,
|
||||
_safe_public_host_for_logs(fetch_url),
|
||||
)
|
||||
else:
|
||||
logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)
|
||||
|
||||
|
||||
def _web_tool_client_error_summary(
|
||||
tool_name: str,
|
||||
error: BaseException,
|
||||
*,
|
||||
verbose: bool,
|
||||
) -> str:
|
||||
if verbose:
|
||||
return f"{tool_name} failed: {type(error).__name__}"
|
||||
return "Web tool request failed."
|
||||
|
||||
|
||||
async def _iter_response_body_under_cap(
|
||||
response: httpx.Response, max_bytes: int
|
||||
) -> AsyncIterator[bytes]:
|
||||
if max_bytes <= 0:
|
||||
return
|
||||
received = 0
|
||||
async for chunk in response.aiter_bytes(chunk_size=65_536):
|
||||
if received >= max_bytes:
|
||||
break
|
||||
remaining = max_bytes - received
|
||||
if len(chunk) <= remaining:
|
||||
received += len(chunk)
|
||||
yield chunk
|
||||
if received >= max_bytes:
|
||||
break
|
||||
else:
|
||||
yield chunk[:remaining]
|
||||
break
|
||||
|
||||
|
||||
async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
|
||||
async for _ in _iter_response_body_under_cap(response, max_bytes):
|
||||
pass
|
||||
|
||||
|
||||
async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
|
||||
return b"".join(
|
||||
[piece async for piece in _iter_response_body_under_cap(response, max_bytes)]
|
||||
)
|
||||
|
||||
|
||||
_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
|
||||
_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
|
||||
|
||||
|
||||
def getaddrinfo_rows_to_resolve_results(
|
||||
host: str, addrinfos: list[tuple]
|
||||
) -> list[ResolveResult]:
|
||||
"""Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
|
||||
out: list[ResolveResult] = []
|
||||
for family, _type, proto, _canon, sockaddr in addrinfos:
|
||||
if family == socket.AF_INET6:
|
||||
if len(sockaddr) < 3:
|
||||
continue
|
||||
if sockaddr[3]:
|
||||
resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
|
||||
else:
|
||||
resolved_host, port = sockaddr[:2]
|
||||
else:
|
||||
assert family == socket.AF_INET, family
|
||||
resolved_host, port = sockaddr[0], sockaddr[1]
|
||||
resolved_host = str(resolved_host)
|
||||
port = int(port)
|
||||
out.append(
|
||||
ResolveResult(
|
||||
hostname=host,
|
||||
host=resolved_host,
|
||||
port=int(port),
|
||||
family=family,
|
||||
proto=proto,
|
||||
flags=_NUMERIC_RESOLVE_FLAGS,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class _PinnedEgressStaticResolver(AbstractResolver):
|
||||
"""Return only pre-validated :class:`ResolveResult` for the outbound request."""
|
||||
|
||||
def __init__(self, results: list[ResolveResult]) -> None:
|
||||
self._results = results
|
||||
|
||||
async def resolve(
|
||||
self, host: str, port: int = 0, family: int = socket.AF_INET
|
||||
) -> list[ResolveResult]:
|
||||
return self._results
|
||||
|
||||
async def close(self) -> None: # pragma: no cover - aiohttp contract
|
||||
return
|
||||
|
||||
|
||||
async def _read_aiohttp_body_capped(
|
||||
response: aiohttp.ClientResponse, max_bytes: int
|
||||
) -> bytes:
|
||||
received = 0
|
||||
parts: list[bytes] = []
|
||||
async for chunk in response.content.iter_chunked(65_536):
|
||||
if received >= max_bytes:
|
||||
break
|
||||
remaining = max_bytes - received
|
||||
if len(chunk) <= remaining:
|
||||
received += len(chunk)
|
||||
parts.append(chunk)
|
||||
else:
|
||||
parts.append(chunk[:remaining])
|
||||
break
|
||||
return b"".join(parts)
|
||||
|
||||
|
||||
async def _drain_aiohttp_body_capped(
|
||||
response: aiohttp.ClientResponse, max_bytes: int
|
||||
) -> None:
|
||||
if max_bytes <= 0:
|
||||
return
|
||||
received = 0
|
||||
async for chunk in response.content.iter_chunked(65_536):
|
||||
received += len(chunk)
|
||||
if received >= max_bytes:
|
||||
break
|
||||
|
||||
|
||||
async def _run_web_search(query: str) -> list[dict[str, str]]:
|
||||
async with (
|
||||
httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT_S,
|
||||
follow_redirects=True,
|
||||
headers=_WEB_TOOL_HTTP_HEADERS,
|
||||
) as client,
|
||||
client.stream(
|
||||
"GET",
|
||||
"https://lite.duckduckgo.com/lite/",
|
||||
params={"q": query},
|
||||
) as response,
|
||||
):
|
||||
response.raise_for_status()
|
||||
body_bytes = await _read_response_body_capped(
|
||||
response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
|
||||
)
|
||||
text = body_bytes.decode("utf-8", errors="replace")
|
||||
parser = SearchResultParser()
|
||||
parser.feed(text)
|
||||
return parser.results[:_MAX_SEARCH_RESULTS]
|
||||
|
||||
|
||||
async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
|
||||
"""Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses."""
|
||||
current_url = url
|
||||
redirect_hops = 0
|
||||
timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)
|
||||
|
||||
while True:
|
||||
addr_infos = await asyncio.to_thread(
|
||||
get_validated_stream_addrinfos_for_egress, current_url, egress
|
||||
)
|
||||
host = urlparse(current_url).hostname or ""
|
||||
results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
|
||||
resolver = _PinnedEgressStaticResolver(results)
|
||||
connector = TCPConnector(
|
||||
resolver=resolver,
|
||||
force_close=True,
|
||||
)
|
||||
try:
|
||||
async with (
|
||||
ClientSession(
|
||||
timeout=timeout,
|
||||
headers=_WEB_TOOL_HTTP_HEADERS,
|
||||
connector=connector,
|
||||
) as session,
|
||||
session.get(current_url, allow_redirects=False) as response,
|
||||
):
|
||||
if response.status in _WEB_FETCH_REDIRECT_STATUSES:
|
||||
await _drain_aiohttp_body_capped(
|
||||
response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
|
||||
)
|
||||
if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
|
||||
raise WebFetchEgressViolation(
|
||||
"web_fetch exceeded maximum redirects "
|
||||
f"({constants._MAX_WEB_FETCH_REDIRECTS})"
|
||||
)
|
||||
location = response.headers.get("location")
|
||||
if not location or not location.strip():
|
||||
raise WebFetchEgressViolation(
|
||||
"web_fetch redirect response missing Location header"
|
||||
)
|
||||
current_url = urljoin(str(response.url), location.strip())
|
||||
redirect_hops += 1
|
||||
continue
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "text/plain")
|
||||
final_url = str(response.url)
|
||||
encoding = response.get_encoding() or "utf-8"
|
||||
body_bytes = await _read_aiohttp_body_capped(
|
||||
response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
|
||||
)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
break
|
||||
|
||||
text = body_bytes.decode(encoding, errors="replace")
|
||||
title = final_url
|
||||
data = text
|
||||
if "html" in content_type.lower():
|
||||
parser = HTMLTextParser()
|
||||
parser.feed(text)
|
||||
title = parser.title or final_url
|
||||
data = "\n".join(parser.text_parts)
|
||||
return {
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"media_type": "text/plain",
|
||||
"data": data[:_MAX_FETCH_CHARS],
|
||||
}
|
||||
104
api/web_tools/parsers.py
Normal file
104
api/web_tools/parsers.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""HTML parsing for web_search / web_fetch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import re
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
|
||||
class SearchResultParser(HTMLParser):
|
||||
"""DuckDuckGo lite HTML: extract result links and titles."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.results: list[dict[str, str]] = []
|
||||
self._href: str | None = None
|
||||
self._title_parts: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag != "a":
|
||||
return
|
||||
href = dict(attrs).get("href")
|
||||
if not href or "uddg=" not in href:
|
||||
return
|
||||
parsed = urlparse(href)
|
||||
query = parse_qs(parsed.query)
|
||||
uddg = query.get("uddg", [""])[0]
|
||||
if not uddg:
|
||||
return
|
||||
self._href = unquote(uddg)
|
||||
self._title_parts = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._href is not None:
|
||||
self._title_parts.append(data)
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag != "a" or self._href is None:
|
||||
return
|
||||
title = " ".join("".join(self._title_parts).split())
|
||||
if title and not any(result["url"] == self._href for result in self.results):
|
||||
self.results.append({"title": html.unescape(title), "url": self._href})
|
||||
self._href = None
|
||||
self._title_parts = []
|
||||
|
||||
|
||||
class HTMLTextParser(HTMLParser):
|
||||
"""Strip scripts/styles and collect visible text + title for fetch previews."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.title = ""
|
||||
self.text_parts: list[str] = []
|
||||
self._in_title = False
|
||||
self._skip_depth = 0
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag in {"script", "style", "noscript"}:
|
||||
self._skip_depth += 1
|
||||
elif tag == "title":
|
||||
self._in_title = True
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag in {"script", "style", "noscript"} and self._skip_depth:
|
||||
self._skip_depth -= 1
|
||||
elif tag == "title":
|
||||
self._in_title = False
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
text = " ".join(data.split())
|
||||
if not text:
|
||||
return
|
||||
if self._in_title:
|
||||
self.title = f"{self.title} {text}".strip()
|
||||
elif not self._skip_depth:
|
||||
self.text_parts.append(text)
|
||||
|
||||
|
||||
def content_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(str(item.get("text", "")))
|
||||
else:
|
||||
parts.append(str(getattr(item, "text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def extract_query(text: str) -> str:
|
||||
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip().strip("\"'")
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_url(text: str) -> str:
|
||||
match = re.search(r"https?://\S+", text)
|
||||
return match.group(0).rstrip(").,]") if match else text.strip()
|
||||
87
api/web_tools/request.py
Normal file
87
api/web_tools/request.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Detect forced Anthropic web server tool requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from api.models.anthropic import MessagesRequest, Tool
|
||||
|
||||
|
||||
def request_text(request: MessagesRequest) -> str:
|
||||
"""Join all user/assistant message content into one string for tool input parsing."""
|
||||
from .parsers import content_text
|
||||
|
||||
return "\n".join(content_text(message.content) for message in request.messages)
|
||||
|
||||
|
||||
def forced_tool_turn_text(request: MessagesRequest) -> str:
|
||||
"""Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
|
||||
if not request.messages:
|
||||
return ""
|
||||
|
||||
from .parsers import content_text
|
||||
|
||||
for message in reversed(request.messages):
|
||||
if message.role == "user":
|
||||
return content_text(message.content)
|
||||
return ""
|
||||
|
||||
|
||||
def forced_server_tool_name(request: MessagesRequest) -> str | None:
|
||||
"""Return web_search or web_fetch only when tool_choice forces that server tool."""
|
||||
tc = request.tool_choice
|
||||
if not isinstance(tc, dict):
|
||||
return None
|
||||
if tc.get("type") != "tool":
|
||||
return None
|
||||
name = tc.get("name")
|
||||
if name in {"web_search", "web_fetch"}:
|
||||
return str(name)
|
||||
return None
|
||||
|
||||
|
||||
def has_tool_named(request: MessagesRequest, name: str) -> bool:
|
||||
return any(tool.name == name for tool in request.tools or [])
|
||||
|
||||
|
||||
def is_web_server_tool_request(request: MessagesRequest) -> bool:
|
||||
"""True when the client forces a web server tool via tool_choice (not merely listed)."""
|
||||
forced = forced_server_tool_name(request)
|
||||
if forced is None:
|
||||
return False
|
||||
return has_tool_named(request, forced)
|
||||
|
||||
|
||||
def is_anthropic_server_tool_definition(tool: Tool) -> bool:
|
||||
"""Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
|
||||
name = (tool.name or "").strip()
|
||||
if name in ("web_search", "web_fetch"):
|
||||
return True
|
||||
typ = tool.type
|
||||
if isinstance(typ, str):
|
||||
return typ.startswith("web_search") or typ.startswith("web_fetch")
|
||||
return False
|
||||
|
||||
|
||||
def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
|
||||
"""True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
|
||||
return any(is_anthropic_server_tool_definition(t) for t in (request.tools or []))
|
||||
|
||||
|
||||
def openai_chat_upstream_server_tool_error(
|
||||
request: MessagesRequest, *, web_tools_enabled: bool
|
||||
) -> str | None:
|
||||
"""Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
|
||||
forced = forced_server_tool_name(request)
|
||||
if forced and not web_tools_enabled:
|
||||
return (
|
||||
f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
|
||||
"disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic "
|
||||
"Messages transport (e.g. open_router, ollama, lmstudio)."
|
||||
)
|
||||
if not forced and has_listed_anthropic_server_tools(request):
|
||||
return (
|
||||
"OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools "
|
||||
"(web_search / web_fetch) without the local web server tool handler. Use a native "
|
||||
"Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
|
||||
"tool_choice, or remove these tools from the request."
|
||||
)
|
||||
return None
|
||||
206
api/web_tools/streaming.py
Normal file
206
api/web_tools/streaming.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""SSE streaming for local web_search / web_fetch server tool results."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from api.models.anthropic import MessagesRequest
|
||||
from core.anthropic.server_tool_sse import (
|
||||
SERVER_TOOL_USE,
|
||||
WEB_FETCH_TOOL_ERROR,
|
||||
WEB_FETCH_TOOL_RESULT,
|
||||
WEB_SEARCH_TOOL_RESULT,
|
||||
WEB_SEARCH_TOOL_RESULT_ERROR,
|
||||
)
|
||||
from core.anthropic.sse import format_sse_event
|
||||
|
||||
from . import outbound
|
||||
from .constants import _MAX_FETCH_CHARS
|
||||
from .egress import WebFetchEgressPolicy
|
||||
from .parsers import extract_query, extract_url
|
||||
from .request import (
|
||||
forced_server_tool_name,
|
||||
forced_tool_turn_text,
|
||||
has_tool_named,
|
||||
)
|
||||
|
||||
|
||||
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
|
||||
if not results:
|
||||
return f"No web search results found for: {query}"
|
||||
lines = [f"Search results for: {query}"]
|
||||
for index, result in enumerate(results, start=1):
|
||||
lines.append(f"{index}. {result['title']}\n{result['url']}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def stream_web_server_tool_response(
|
||||
request: MessagesRequest,
|
||||
input_tokens: int,
|
||||
*,
|
||||
web_fetch_egress: WebFetchEgressPolicy,
|
||||
verbose_client_errors: bool = False,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).
|
||||
|
||||
When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path — not a full
|
||||
hosted Anthropic citation or encrypted-content pipeline.
|
||||
"""
|
||||
tool_name = forced_server_tool_name(request)
|
||||
if tool_name is None or not has_tool_named(request, tool_name):
|
||||
return
|
||||
|
||||
text = forced_tool_turn_text(request)
|
||||
message_id = f"msg_{uuid.uuid4()}"
|
||||
tool_id = f"srvtoolu_{uuid.uuid4().hex}"
|
||||
usage_key = (
|
||||
"web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
|
||||
)
|
||||
tool_input = (
|
||||
{"query": extract_query(text)}
|
||||
if tool_name == "web_search"
|
||||
else {"url": extract_url(text)}
|
||||
)
|
||||
_result_block_for_tool = {
|
||||
"web_search": WEB_SEARCH_TOOL_RESULT,
|
||||
"web_fetch": WEB_FETCH_TOOL_RESULT,
|
||||
}
|
||||
_error_payload_type_for_tool = {
|
||||
"web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
|
||||
"web_fetch": WEB_FETCH_TOOL_ERROR,
|
||||
}
|
||||
|
||||
yield format_sse_event(
|
||||
"message_start",
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": request.model,
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": input_tokens, "output_tokens": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": SERVER_TOOL_USE,
|
||||
"id": tool_id,
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 0}
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "web_search":
|
||||
query = str(tool_input["query"])
|
||||
results = await outbound._run_web_search(query)
|
||||
result_content: Any = [
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"title": result["title"],
|
||||
"url": result["url"],
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
summary = _search_summary(query, results)
|
||||
result_block_type = WEB_SEARCH_TOOL_RESULT
|
||||
else:
|
||||
fetched = await outbound._run_web_fetch(
|
||||
str(tool_input["url"]), web_fetch_egress
|
||||
)
|
||||
result_content = {
|
||||
"type": "web_fetch_result",
|
||||
"url": fetched["url"],
|
||||
"content": {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": "text",
|
||||
"media_type": fetched["media_type"],
|
||||
"data": fetched["data"],
|
||||
},
|
||||
"title": fetched["title"],
|
||||
"citations": {"enabled": True},
|
||||
},
|
||||
"retrieved_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
summary = fetched["data"][:_MAX_FETCH_CHARS]
|
||||
result_block_type = WEB_FETCH_TOOL_RESULT
|
||||
except Exception as error:
|
||||
fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
|
||||
outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
|
||||
result_block_type = _result_block_for_tool[tool_name]
|
||||
result_content = {
|
||||
"type": _error_payload_type_for_tool[tool_name],
|
||||
"error_code": "unavailable",
|
||||
}
|
||||
summary = outbound._web_tool_client_error_summary(
|
||||
tool_name, error, verbose=verbose_client_errors
|
||||
)
|
||||
|
||||
output_tokens = max(1, len(summary) // 4)
|
||||
|
||||
yield format_sse_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": result_block_type,
|
||||
"tool_use_id": tool_id,
|
||||
"content": result_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 1}
|
||||
)
|
||||
# Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
|
||||
# not eager `text` on `content_block_start`).
|
||||
yield format_sse_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 2,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_delta",
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 2,
|
||||
"delta": {"type": "text_delta", "text": summary},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 2}
|
||||
)
|
||||
yield format_sse_event(
|
||||
"message_delta",
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
|
||||
"usage": {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"server_tool_use": {usage_key: 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event("message_stop", {"type": "message_stop"})
|
||||
Loading…
Add table
Add a link
Reference in a new issue