mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-26 10:31:07 +00:00
Major refactor: API, providers, messaging, and Anthropic protocol
Some checks are pending
CI / checks (push) Waiting to run
Consolidates the incremental refactor work into a single change set: modular web tools (api/web_tools), native Anthropic request building and SSE block policy, OpenAI conversion and error handling, provider transports and rate limiting, messaging handler and tree queue, safe logging, smoke tests, and broad test coverage.
This commit is contained in:
parent
b9ed704095
commit
f3a7528d49
139 changed files with 7460 additions and 2422 deletions
.env.example (26 changed lines)

@@ -47,8 +47,8 @@ OPENROUTER_PROXY=""
 LMSTUDIO_PROXY=""
 LLAMACPP_PROXY=""

-PROVIDER_RATE_LIMIT=40
-PROVIDER_RATE_WINDOW=60
+PROVIDER_RATE_LIMIT=1
+PROVIDER_RATE_WINDOW=3
 PROVIDER_MAX_CONCURRENCY=5


@@ -102,3 +102,25 @@ ENABLE_NETWORK_PROBE_MOCK=true
 ENABLE_TITLE_GENERATION_SKIP=true
 ENABLE_SUGGESTION_MODE_SKIP=true
 ENABLE_FILEPATH_EXTRACTION_MOCK=true
+
+
+# Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default)
+ENABLE_WEB_SERVER_TOOLS=true
+WEB_FETCH_ALLOWED_SCHEMES=http,https
+WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false
+
+
+# Verbose diagnostics (avoid logging raw prompts / SSE bodies in production)
+DEBUG_PLATFORM_EDITS=false
+DEBUG_SUBAGENT_STACK=false
+# When true, also allows DEBUG-level httpx/httpcore/telegram log noise (not just payload logging).
+LOG_RAW_API_PAYLOADS=false
+LOG_RAW_SSE_EVENTS=false
+# When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data).
+LOG_API_ERROR_TRACEBACKS=false
+# When true, log message/transcription text previews in messaging adapters (may leak user content).
+LOG_RAW_MESSAGING_CONTENT=false
+# When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text.
+LOG_RAW_CLI_DIAGNOSTICS=false
+# When true, log full exception and CLI error message strings in messaging (may leak user content).
+LOG_MESSAGING_ERROR_DETAILS=false
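These egress knobs are consumed in `api/services.py` below via `settings.web_fetch_allowed_scheme_set()`. A minimal sketch of how that helper might normalize the comma-separated scheme list — the field names come from this diff, but the `Settings` shape shown here is an assumption, not the project's actual `config/settings.py`:

```python
# Hypothetical sketch only; the real model lives in config/settings.py.
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    enable_web_server_tools: bool = True
    web_fetch_allowed_schemes: str = "http,https"
    web_fetch_allow_private_networks: bool = False

    def web_fetch_allowed_scheme_set(self) -> frozenset[str]:
        # Normalize "http,https" -> frozenset({"http", "https"}).
        return frozenset(
            part.strip().lower()
            for part in self.web_fetch_allowed_schemes.split(",")
            if part.strip()
        )
```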
@@ -32,13 +32,14 @@
 - **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code.
 - **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue.
 - **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required.
+- **Maximum Test Coverage**: There should be maximum test coverage for everything, preferably live smoke test coverage to catch bugs early.

 ## COGNITIVE WORKFLOW

 1. **ANALYZE**: Read relevant files. Do not guess.
 2. **PLAN**: Map out the logic. Identify the root cause or required changes. Order changes by dependency.
 3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
-4. **VERIFY**: Run CI checks. Confirm the fix via logs or output.
+4. **VERIFY**: Run CI checks and relevant smoke tests. Confirm the fix via logs or output.
 5. **SPECIFICITY**: Do exactly as much as asked; nothing more, nothing less.
 6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.
PLAN.md (19 changed lines)

@@ -56,13 +56,20 @@ with **runtime composition** (e.g. `api.runtime` constructs `cli` and `messaging`
 **Contract highlights:** `api/` may import only `providers.base`, `providers.exceptions`,
 and `providers.registry` from the providers package (not per-adapter modules).
 `core/` stays free of `api`, `messaging`, `cli`, `providers`, `config`, and `smoke`.
-`messaging/` does not import `api`, `cli`, or `providers`. Neutral stream contract
-assertions for default CI live under `core/anthropic/stream_contracts.py`;
-`smoke.lib.sse` re-exports them for live smoke. Process-cached provider helpers
-(`api.dependencies.get_provider` / `get_provider_for_type`) exist for scripts and
-unit tests; production HTTP handlers must use `resolve_provider` with
+`messaging/` does not import `api`, `cli`, or `smoke`, and may import `providers`
+only via `providers.nvidia_nim.voice` (NVIDIA/Riva offline ASR). Stream contract
+helpers live in `core/anthropic/stream_contracts.py`; live smoke imports that
+module directly (no dedicated smoke SSE shim). NVIDIA NIM chat tuning uses the
+canonical `config.nim.NimSettings` model on `Settings`; `providers.registry`
+passes `settings.nim` into `NvidiaNimProvider` without a duplicate schema.
+Default upstream base URLs use a single constant per endpoint in
+`providers/defaults.py` (e.g. `NVIDIA_NIM_DEFAULT_BASE`). Process-cached provider
+helpers (`api.dependencies.get_provider` / `get_provider_for_type`) exist for
+scripts and unit tests; production HTTP handlers must use `resolve_provider` with
 `request.app` so the app-scoped `ProviderRegistry` is used. The `api` package
-`__all__` exposes HTTP models and `create_app` only (not those helpers).
+`__all__` exposes HTTP models and `create_app` only (not `app`, not those helpers).
+`api.app:create_app` is the ASGI factory (e.g. `uvicorn api.app:create_app --factory`);
+`server.py` still exposes `server:app` as a module-level instance for convenience.

 ## Target Boundaries
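A cheap way to keep import contracts like these honest is a test that imports the package fresh and asserts the forbidden packages never enter `sys.modules`. This is a sketch of the idea only; the repository's real enforcement, if any, is not shown in this diff:

```python
import importlib
import sys


def test_messaging_does_not_import_api_or_cli() -> None:
    # Drop any cached modules so the import graph is rebuilt from scratch.
    for name in list(sys.modules):
        if name == "messaging" or name.startswith("messaging."):
            del sys.modules[name]

    importlib.import_module("messaging")

    loaded = set(sys.modules)
    assert not any(m == "api" or m.startswith("api.") for m in loaded)
    assert not any(m == "cli" or m.startswith("cli.") for m in loaded)
```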
api/__init__.py

@@ -1,6 +1,6 @@
 """API layer for Claude Code Proxy."""

-from .app import app, create_app
+from .app import create_app
 from .models import (
     MessagesRequest,
     MessagesResponse,

@@ -13,6 +13,5 @@ __all__ = [
     "MessagesResponse",
     "TokenCountRequest",
     "TokenCountResponse",
-    "app",
     "create_app",
 ]
api/app.py (83 changed lines)

@@ -1,6 +1,6 @@
 """FastAPI application factory and configuration."""

-import os
-import traceback
 from contextlib import asynccontextmanager
 from typing import Any

@@ -16,13 +16,7 @@ from providers.exceptions import ProviderError

 from .routes import router
 from .runtime import AppRuntime

-# Opt-in to future behavior for python-telegram-bot
-os.environ["PTB_TIMEDELTA"] = "1"
-
-# Configure logging first (before any module logs)
-_settings = get_settings()
-configure_logging(_settings.log_file)
+from .validation_log import summarize_request_validation_body


 @asynccontextmanager

@@ -38,6 +32,11 @@ async def lifespan(app: FastAPI):

 def create_app() -> FastAPI:
     """Create and configure the FastAPI application."""
+    settings = get_settings()
+    configure_logging(
+        settings.log_file, verbose_third_party=settings.log_raw_api_payloads
+    )

     app = FastAPI(
         title="Claude Code Proxy",
         version="2.0.0",

@@ -57,33 +56,7 @@ def create_app() -> FastAPI:
         except Exception as e:
             body = {"_json_error": type(e).__name__}

-        messages = body.get("messages") if isinstance(body, dict) else None
-        message_summary: list[dict[str, Any]] = []
-        if isinstance(messages, list):
-            for msg in messages:
-                if not isinstance(msg, dict):
-                    message_summary.append({"message_kind": type(msg).__name__})
-                    continue
-                content = msg.get("content")
-                item: dict[str, Any] = {
-                    "role": msg.get("role"),
-                    "content_kind": type(content).__name__,
-                }
-                if isinstance(content, list):
-                    item["block_types"] = [
-                        block.get("type", "dict")
-                        if isinstance(block, dict)
-                        else type(block).__name__
-                        for block in content[:12]
-                    ]
-                    item["block_keys"] = [
-                        sorted(str(key) for key in block)[:12]
-                        for block in content[:5]
-                        if isinstance(block, dict)
-                    ]
-                elif isinstance(content, str):
-                    item["content_length"] = len(content)
-                message_summary.append(item)
+        message_summary, tool_names = summarize_request_validation_body(body)

         logger.debug(
             "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",

@@ -92,20 +65,27 @@ def create_app() -> FastAPI:
             [list(error.get("loc", ())) for error in exc.errors()],
             [str(error.get("type", "")) for error in exc.errors()],
             message_summary,
-            [
-                str(tool.get("name", ""))
-                for tool in body.get("tools", [])
-                if isinstance(body, dict)
-                and isinstance(body.get("tools"), list)
-                and isinstance(tool, dict)
-            ],
+            tool_names,
         )
         return await request_validation_exception_handler(request, exc)

     @app.exception_handler(ProviderError)
     async def provider_error_handler(request: Request, exc: ProviderError):
         """Handle provider-specific errors and return Anthropic format."""
-        logger.error(f"Provider Error: {exc.error_type} - {exc.message}")
+        err_settings = get_settings()
+        if err_settings.log_api_error_tracebacks:
+            logger.error(
+                "Provider Error: error_type={} status_code={} message={}",
+                exc.error_type,
+                exc.status_code,
+                exc.message,
+            )
+        else:
+            logger.error(
+                "Provider Error: error_type={} status_code={}",
+                exc.error_type,
+                exc.status_code,
+            )
         return JSONResponse(
             status_code=exc.status_code,
             content=exc.to_anthropic_format(),

@@ -114,10 +94,17 @@ def create_app() -> FastAPI:
     @app.exception_handler(Exception)
     async def general_error_handler(request: Request, exc: Exception):
         """Handle general errors and return Anthropic format."""
-        logger.error(f"General Error: {exc!s}")
+        import traceback
+
+        settings = get_settings()
+        if settings.log_api_error_tracebacks:
+            logger.error("General Error: {}", exc)
+            logger.error(traceback.format_exc())
+        else:
+            logger.error(
+                "General Error: path={} method={} exc_type={}",
+                request.url.path,
+                request.method,
+                type(exc).__name__,
+            )
         return JSONResponse(
             status_code=500,
             content={

@@ -130,7 +117,3 @@ def create_app() -> FastAPI:
     )

     return app
-
-
-# Default app instance for uvicorn
-app = create_app()
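With the module-level `app` gone, anything that needs an application object goes through the factory. A sketch of test usage, assuming only that `create_app` is importable as shown; the `/health` route is illustrative, not confirmed by this diff:

```python
from fastapi.testclient import TestClient

from api.app import create_app

# create_app() is now the only supported entry point; api.app no longer
# exposes a module-level `app` instance.
client = TestClient(create_app())
response = client.get("/health")  # hypothetical route
print(response.status_code)
```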
@@ -135,5 +135,5 @@ def extract_filepaths_from_command(command: str, output: str) -> str:

         return "<filepaths>\n</filepaths>"

-    except Exception:
+    except ValueError:
         return "<filepaths>\n</filepaths>"
api/dependencies.py

@@ -8,7 +8,11 @@ from config.settings import Settings
 from config.settings import get_settings as _get_settings
 from core.anthropic import get_user_facing_error_message
 from providers.base import BaseProvider
-from providers.exceptions import AuthenticationError, UnknownProviderTypeError
+from providers.exceptions import (
+    AuthenticationError,
+    ServiceUnavailableError,
+    UnknownProviderTypeError,
+)
 from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry

 # Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`

@@ -18,7 +22,7 @@ _providers: dict[str, BaseProvider] = {}


 def get_settings() -> Settings:
-    """Get application settings via dependency injection."""
+    """Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
     return _get_settings()


@@ -31,10 +35,9 @@ def resolve_provider(
     """Resolve a provider using the app-scoped registry when ``app`` is set.

     When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
-    is always used. If the registry is missing (e.g. a test app without
-    :class:`~api.runtime.AppRuntime` startup), a new :class:`ProviderRegistry`
-    is installed on ``app.state`` so the process cache is never mixed with
-    per-request app identity.
+    must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
+    Callers that construct a bare ``FastAPI`` without lifespan must set
+    ``app.state.provider_registry`` explicitly.

     When ``app`` is ``None`` (no HTTP context), uses the process-level
     :data:`_providers` cache only.

@@ -42,8 +45,10 @@ def resolve_provider(
     if app is not None:
         reg = getattr(app.state, "provider_registry", None)
         if reg is None:
-            reg = ProviderRegistry()
-            app.state.provider_registry = reg
+            raise ServiceUnavailableError(
+                "Provider registry is not configured. Ensure AppRuntime startup ran "
+                "or assign app.state.provider_registry for test apps."
+            )
         return _resolve_with_registry(reg, provider_type, settings)
     return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)

@@ -55,9 +60,10 @@ def _resolve_with_registry(
     try:
         provider = registry.get(provider_type, settings)
     except AuthenticationError as e:
-        raise HTTPException(
-            status_code=503, detail=get_user_facing_error_message(e)
-        ) from e
+        # Provider :class:`~providers.exceptions.AuthenticationError` messages are
+        # curated configuration hints (env var names, docs links), not upstream noise.
+        detail = str(e).strip() or get_user_facing_error_message(e)
+        raise HTTPException(status_code=503, detail=detail) from e
     except UnknownProviderTypeError:
         logger.error(
             "Unknown provider_type: '{}'. Supported: {}",

@@ -73,8 +79,9 @@ def _resolve_with_registry(
 def get_provider_for_type(provider_type: str) -> BaseProvider:
     """Get or create a provider in the process-level cache (no ``app``/Request).

-    For server requests, use :func:`resolve_provider` with the active
-    :attr:`request.app` so the app-scoped provider registry is used.
+    HTTP route handlers should call :func:`resolve_provider` with the active
+    :attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
+    process-wide cache.
     """
     return resolve_provider(provider_type, app=None, settings=get_settings())
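The docstring above spells out the new contract: no silent registry creation. A sketch of the test-app setup it asks for; the provider id `"nvidia_nim"` is taken from `_OPENAI_CHAT_UPSTREAM_IDS` elsewhere in this change set:

```python
from fastapi import FastAPI

from api.dependencies import resolve_provider
from config.settings import get_settings
from providers.registry import ProviderRegistry

# A bare FastAPI app without the AppRuntime lifespan: the registry must be
# installed by hand, otherwise resolve_provider raises ServiceUnavailableError.
app = FastAPI()
app.state.provider_registry = ProviderRegistry()

provider = resolve_provider("nvidia_nim", app=app, settings=get_settings())
```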
@@ -65,8 +65,8 @@ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
     try:
         cmd_start = content.rfind("Command:") + len("Command:")
         return True, content[cmd_start:].strip()
-    except Exception:
-        pass
+    except TypeError:
+        return False, ""

     return False, ""

@@ -121,7 +121,6 @@ def is_filepath_extraction_request(
     if not user_has_filepaths and not system_has_extract:
         return False, "", ""

-    try:
     cmd_start = content.find("Command:") + len("Command:")
     output_marker = content.find("Output:", cmd_start)
     if output_marker == -1:

@@ -135,5 +134,3 @@ def is_filepath_extraction_request(
     output = output.split(marker)[0].strip()

     return True, command, output
-    except Exception:
-        return False, "", ""
api/models/anthropic.py

@@ -3,7 +3,7 @@
 from enum import StrEnum
 from typing import Any, Literal

-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Field


 # =============================================================================

@@ -15,41 +15,68 @@ class Role(StrEnum):
     system = "system"


-class ContentBlockText(BaseModel):
+class _AnthropicBlockBase(BaseModel):
+    """Pass through provider fields (e.g. ``cache_control``) for native transports."""
+
+    model_config = ConfigDict(extra="allow")
+
+
+class ContentBlockText(_AnthropicBlockBase):
     type: Literal["text"]
     text: str


-class ContentBlockImage(BaseModel):
+class ContentBlockImage(_AnthropicBlockBase):
     type: Literal["image"]
     source: dict[str, Any]


-class ContentBlockToolUse(BaseModel):
+class ContentBlockToolUse(_AnthropicBlockBase):
     type: Literal["tool_use"]
     id: str
     name: str
     input: dict[str, Any]


-class ContentBlockToolResult(BaseModel):
+class ContentBlockToolResult(_AnthropicBlockBase):
     type: Literal["tool_result"]
     tool_use_id: str
     content: str | list[Any] | dict[str, Any]


-class ContentBlockThinking(BaseModel):
+class ContentBlockThinking(_AnthropicBlockBase):
     type: Literal["thinking"]
     thinking: str
     signature: str | None = None


-class ContentBlockRedactedThinking(BaseModel):
+class ContentBlockRedactedThinking(_AnthropicBlockBase):
     type: Literal["redacted_thinking"]
     data: str


-class SystemContent(BaseModel):
+class ContentBlockServerToolUse(_AnthropicBlockBase):
+    """Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""
+
+    type: Literal["server_tool_use"]
+    id: str
+    name: str
+    input: dict[str, Any]
+
+
+class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
+    type: Literal["web_search_tool_result"]
+    tool_use_id: str
+    content: Any
+
+
+class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
+    type: Literal["web_fetch_tool_result"]
+    tool_use_id: str
+    content: Any
+
+
+class SystemContent(_AnthropicBlockBase):
     type: Literal["text"]
     text: str

@@ -68,12 +95,15 @@ class Message(BaseModel):
             | ContentBlockToolResult
             | ContentBlockThinking
             | ContentBlockRedactedThinking
+            | ContentBlockServerToolUse
+            | ContentBlockWebSearchToolResult
+            | ContentBlockWebFetchToolResult
         ]
     )
     reasoning_content: str | None = None


-class Tool(BaseModel):
+class Tool(_AnthropicBlockBase):
     name: str
     # Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
     # may omit ``input_schema`` because the provider owns the schema.

@@ -92,7 +122,12 @@ class ThinkingConfig(BaseModel):
 # Request Models
 # =============================================================================
 class MessagesRequest(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
     model: str
+    # Internal routing / debug: accepted on parse but not serialized to providers.
+    original_model: str | None = Field(default=None, exclude=True)
+    resolved_provider_model: str | None = Field(default=None, exclude=True)
     max_tokens: int | None = None
     messages: list[Message]
     system: str | list[SystemContent] | None = None

@@ -105,13 +140,24 @@ class MessagesRequest(BaseModel):
     tools: list[Tool] | None = None
     tool_choice: dict[str, Any] | None = None
     thinking: ThinkingConfig | None = None
+    # Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
+    context_management: dict[str, Any] | None = None
+    output_config: dict[str, Any] | None = None
+    mcp_servers: list[dict[str, Any]] | None = None
+    extra_body: dict[str, Any] | None = None


 class TokenCountRequest(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
     model: str
+    original_model: str | None = Field(default=None, exclude=True)
+    resolved_provider_model: str | None = Field(default=None, exclude=True)
     messages: list[Message]
     system: str | list[SystemContent] | None = None
     tools: list[Tool] | None = None
     thinking: ThinkingConfig | None = None
     tool_choice: dict[str, Any] | None = None
+    context_management: dict[str, Any] | None = None
+    output_config: dict[str, Any] | None = None
+    mcp_servers: list[dict[str, Any]] | None = None
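The effect of `extra="allow"` on the block models: undeclared provider fields such as `cache_control` survive a validate/dump round trip instead of being dropped. A sketch; the module path `api.models.anthropic` is inferred from the `from .models.anthropic import ...` line in `api/services.py`:

```python
from api.models.anthropic import ContentBlockText

block = ContentBlockText.model_validate(
    {
        "type": "text",
        "text": "hello",
        "cache_control": {"type": "ephemeral"},  # not a declared field
    }
)
# extra="allow" keeps the field, so native transports can forward it verbatim.
assert block.model_dump()["cache_control"] == {"type": "ephemeral"}
```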
api/runtime.py

@@ -23,15 +23,31 @@ _SHUTDOWN_TIMEOUT_S = 5.0


 async def best_effort(
-    name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S
+    name: str,
+    awaitable: Any,
+    timeout_s: float = _SHUTDOWN_TIMEOUT_S,
+    *,
+    log_verbose_errors: bool = False,
 ) -> None:
     """Run a shutdown step with timeout; never raise to callers."""
     try:
         await asyncio.wait_for(awaitable, timeout=timeout_s)
     except TimeoutError:
-        logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
+        logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
     except Exception as e:
-        logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
+        if log_verbose_errors:
+            logger.warning(
+                "Shutdown step failed: {}: {}: {}",
+                name,
+                type(e).__name__,
+                e,
+            )
+        else:
+            logger.warning(
+                "Shutdown step failed: {}: exc_type={}",
+                name,
+                type(e).__name__,
+            )


 def warn_if_process_auth_token(settings: Settings) -> None:

@@ -73,20 +89,37 @@ class AppRuntime:
         self._publish_state()

     async def shutdown(self) -> None:
+        verbose = self.settings.log_api_error_tracebacks
         if self.message_handler is not None:
             try:
                 self.message_handler.session_store.flush_pending_save()
             except Exception as e:
-                logger.warning(f"Session store flush on shutdown: {e}")
+                if verbose:
+                    logger.warning("Session store flush on shutdown: {}", e)
+                else:
+                    logger.warning(
+                        "Session store flush on shutdown: exc_type={}",
+                        type(e).__name__,
+                    )

         logger.info("Shutdown requested, cleaning up...")
         if self.messaging_platform:
-            await best_effort("messaging_platform.stop", self.messaging_platform.stop())
+            await best_effort(
+                "messaging_platform.stop",
+                self.messaging_platform.stop(),
+                log_verbose_errors=verbose,
+            )
         if self.cli_manager:
-            await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
+            await best_effort(
+                "cli_manager.stop_all",
+                self.cli_manager.stop_all(),
+                log_verbose_errors=verbose,
+            )
         if self._provider_registry is not None:
             await best_effort(
-                "provider_registry.cleanup", self._provider_registry.cleanup()
+                "provider_registry.cleanup",
+                self._provider_registry.cleanup(),
+                log_verbose_errors=verbose,
             )
         await self._shutdown_limiter()
         logger.info("Server shut down cleanly")

@@ -110,6 +143,10 @@ class AppRuntime:
                 whisper_device=self.settings.whisper_device,
                 hf_token=self.settings.hf_token,
                 nvidia_nim_api_key=self.settings.nvidia_nim_api_key,
+                messaging_rate_limit=self.settings.messaging_rate_limit,
+                messaging_rate_window=self.settings.messaging_rate_window,
+                log_raw_messaging_content=self.settings.log_raw_messaging_content,
+                log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
             ),
         )

@@ -117,12 +154,24 @@ class AppRuntime:
             await self._start_message_handler()

         except ImportError as e:
-            logger.warning(f"Messaging module import error: {e}")
+            if self.settings.log_api_error_tracebacks:
+                logger.warning("Messaging module import error: {}", e)
+            else:
+                logger.warning(
+                    "Messaging module import error: exc_type={}",
+                    type(e).__name__,
+                )
         except Exception as e:
-            logger.error(f"Failed to start messaging platform: {e}")
+            if self.settings.log_api_error_tracebacks:
+                logger.error("Failed to start messaging platform: {}", e)
+                import traceback
+
+                logger.error(traceback.format_exc())
+            else:
+                logger.error(
+                    "Failed to start messaging platform: exc_type={}",
+                    type(e).__name__,
+                )

     async def _start_message_handler(self) -> None:
         from cli.manager import CLISessionManager

@@ -151,10 +200,13 @@ class AppRuntime:
             allowed_dirs=allowed_dirs,
             plans_directory=plans_directory,
             claude_bin=self.settings.claude_cli_bin,
+            log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
+            log_messaging_error_details=self.settings.log_messaging_error_details,
         )

         session_store = SessionStore(
-            storage_path=os.path.join(data_path, "sessions.json")
+            storage_path=os.path.join(data_path, "sessions.json"),
+            message_log_cap=self.settings.max_message_log_entries_per_chat,
         )
         platform = self.messaging_platform
         assert platform is not None

@@ -162,6 +214,11 @@ class AppRuntime:
             platform=platform,
             cli_manager=self.cli_manager,
             session_store=session_store,
+            debug_platform_edits=self.settings.debug_platform_edits,
+            debug_subagent_stack=self.settings.debug_subagent_stack,
+            log_raw_messaging_content=self.settings.log_raw_messaging_content,
+            log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
+            log_messaging_error_details=self.settings.log_messaging_error_details,
         )
         self._restore_tree_state(session_store)

@@ -201,18 +258,26 @@ class AppRuntime:
         self.app.state.cli_manager = self.cli_manager

     async def _shutdown_limiter(self) -> None:
+        verbose = self.settings.log_api_error_tracebacks
         try:
             from messaging.limiter import MessagingRateLimiter
         except Exception as e:
-            logger.debug(
-                "Rate limiter shutdown skipped (import failed): {}: {}",
-                type(e).__name__,
-                e,
-            )
+            if verbose:
+                logger.debug(
+                    "Rate limiter shutdown skipped (import failed): {}: {}",
+                    type(e).__name__,
+                    e,
+                )
+            else:
+                logger.debug(
+                    "Rate limiter shutdown skipped (import failed): exc_type={}",
+                    type(e).__name__,
+                )
             return

         await best_effort(
             "MessagingRateLimiter.shutdown_instance",
             MessagingRateLimiter.shutdown_instance(),
             timeout_s=2.0,
             log_verbose_errors=verbose,
         )
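Usage of the widened `best_effort` signature: the awaitable's failure is swallowed and logged, with exception text emitted only when `log_verbose_errors=True`. A sketch assuming `best_effort` is importable from `api.runtime` as shown in this diff:

```python
import asyncio

from api.runtime import best_effort


async def flaky_close() -> None:
    raise RuntimeError("socket already closed")


# Never raises: the RuntimeError is caught and logged as exc_type only,
# because log_verbose_errors is False.
asyncio.run(best_effort("demo.close", flaky_close(), log_verbose_errors=False))
```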
api/services.py (126 changed lines)

@@ -4,7 +4,7 @@ from __future__ import annotations

 import traceback
 import uuid
-from collections.abc import Callable
+from collections.abc import AsyncIterator, Callable
 from typing import Any

 from fastapi import HTTPException

@@ -13,6 +13,7 @@ from loguru import logger

 from config.settings import Settings
 from core.anthropic import get_token_count, get_user_facing_error_message
+from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS
 from providers.base import BaseProvider
 from providers.exceptions import InvalidRequestError, ProviderError

@@ -20,15 +21,67 @@ from .model_router import ModelRouter
 from .models.anthropic import MessagesRequest, TokenCountRequest
 from .models.responses import TokenCountResponse
 from .optimization_handlers import try_optimizations
-from .web_server_tools import (
-    is_web_server_tool_request,
-    stream_web_server_tool_response,
-)
+from .web_tools.egress import WebFetchEgressPolicy
+from .web_tools.request import (
+    is_web_server_tool_request,
+    openai_chat_upstream_server_tool_error,
+)
+from .web_tools.streaming import stream_web_server_tool_response

 TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
 ProviderGetter = Callable[[str], BaseProvider]

+# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
+_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "deepseek"})
+
+
+def anthropic_sse_streaming_response(
+    body: AsyncIterator[str],
+) -> StreamingResponse:
+    """Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
+    return StreamingResponse(
+        body,
+        media_type="text/event-stream",
+        headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
+    )
+
+
+def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
+    """HTTP status for uncaught non-provider failures (stable client contract)."""
+    return 500
+
+
+def _log_unexpected_service_exception(
+    settings: Settings,
+    exc: BaseException,
+    *,
+    context: str,
+    request_id: str | None = None,
+) -> None:
+    """Log service-layer failures without echoing exception text unless opted in."""
+    if settings.log_api_error_tracebacks:
+        if request_id is not None:
+            logger.error("{} request_id={}: {}", context, request_id, exc)
+        else:
+            logger.error("{}: {}", context, exc)
+        logger.error(traceback.format_exc())
+        return
+    if request_id is not None:
+        logger.error(
+            "{} request_id={} exc_type={}",
+            context,
+            request_id,
+            type(exc).__name__,
+        )
+    else:
+        logger.error("{} exc_type={}", context, type(exc).__name__)
+
+
+def _require_non_empty_messages(messages: list[Any]) -> None:
+    if not messages:
+        raise InvalidRequestError("messages cannot be empty")
+

 class ClaudeProxyService:
     """Coordinate request optimization, model routing, token count, and providers."""

@@ -48,25 +101,35 @@ class ClaudeProxyService:
     def create_message(self, request_data: MessagesRequest) -> object:
         """Create a message response or streaming response."""
         try:
-            if not request_data.messages:
-                raise InvalidRequestError("messages cannot be empty")
+            _require_non_empty_messages(request_data.messages)

             routed = self._model_router.resolve_messages_request(request_data)
-            if is_web_server_tool_request(routed.request):
+            if routed.resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
+                tool_err = openai_chat_upstream_server_tool_error(
+                    routed.request,
+                    web_tools_enabled=self._settings.enable_web_server_tools,
+                )
+                if tool_err is not None:
+                    raise InvalidRequestError(tool_err)
+
+            if self._settings.enable_web_server_tools and is_web_server_tool_request(
+                routed.request
+            ):
                 input_tokens = self._token_counter(
                     routed.request.messages, routed.request.system, routed.request.tools
                 )
                 logger.info("Optimization: Handling Anthropic web server tool")
-                return StreamingResponse(
+                egress = WebFetchEgressPolicy(
+                    allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
+                    allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
+                )
+                return anthropic_sse_streaming_response(
                     stream_web_server_tool_response(
-                        routed.request, input_tokens=input_tokens
+                        routed.request,
+                        input_tokens=input_tokens,
+                        web_fetch_egress=egress,
+                        verbose_client_errors=self._settings.log_api_error_tracebacks,
                     ),
-                    media_type="text/event-stream",
-                    headers={
-                        "X-Accel-Buffering": "no",
-                        "Cache-Control": "no-cache",
-                        "Connection": "keep-alive",
-                    },
                 )

             optimized = try_optimizations(routed.request, self._settings)

@@ -75,6 +138,10 @@ class ClaudeProxyService:
                 logger.debug("No optimization matched, routing to provider")

             provider = self._provider_getter(routed.resolved.provider_id)
+            provider.preflight_stream(
+                routed.request,
+                thinking_enabled=routed.resolved.thinking_enabled,
+            )

             request_id = f"req_{uuid.uuid4().hex[:12]}"
             logger.info(

@@ -83,6 +150,7 @@ class ClaudeProxyService:
                 routed.request.model,
                 len(routed.request.messages),
             )
+            if self._settings.log_raw_api_payloads:
                 logger.debug(
                     "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
                 )

@@ -90,27 +158,23 @@ class ClaudeProxyService:
             input_tokens = self._token_counter(
                 routed.request.messages, routed.request.system, routed.request.tools
             )
-            return StreamingResponse(
+            return anthropic_sse_streaming_response(
                 provider.stream_response(
                     routed.request,
                     input_tokens=input_tokens,
                     request_id=request_id,
                     thinking_enabled=routed.resolved.thinking_enabled,
                 ),
-                media_type="text/event-stream",
-                headers={
-                    "X-Accel-Buffering": "no",
-                    "Cache-Control": "no-cache",
-                    "Connection": "keep-alive",
-                },
             )

         except ProviderError:
             raise
         except Exception as e:
-            logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
+            _log_unexpected_service_exception(
+                self._settings, e, context="CREATE_MESSAGE_ERROR"
+            )
             raise HTTPException(
-                status_code=getattr(e, "status_code", 500),
+                status_code=_http_status_for_unexpected_service_exception(e),
                 detail=get_user_facing_error_message(e),
             ) from e

@@ -119,6 +183,7 @@ class ClaudeProxyService:
         request_id = f"req_{uuid.uuid4().hex[:12]}"
         with logger.contextualize(request_id=request_id):
             try:
+                _require_non_empty_messages(request_data.messages)
                 routed = self._model_router.resolve_token_count_request(request_data)
                 tokens = self._token_counter(
                     routed.request.messages, routed.request.system, routed.request.tools

@@ -131,13 +196,16 @@ class ClaudeProxyService:
                     tokens,
                 )
                 return TokenCountResponse(input_tokens=tokens)
             except ProviderError:
                 raise
             except Exception as e:
-                logger.error(
-                    "COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
-                    request_id,
-                    get_user_facing_error_message(e),
-                    traceback.format_exc(),
-                )
+                _log_unexpected_service_exception(
+                    self._settings,
+                    e,
+                    context="COUNT_TOKENS_ERROR",
+                    request_id=request_id,
+                )
                 raise HTTPException(
-                    status_code=500, detail=get_user_facing_error_message(e)
+                    status_code=_http_status_for_unexpected_service_exception(e),
+                    detail=get_user_facing_error_message(e),
                 ) from e
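The inline header dict that this change deletes is now shared as `ANTHROPIC_SSE_RESPONSE_HEADERS` in `core/anthropic/sse.py`. A sketch of the equivalent response construction; the constant's exact value is assumed to match the removed inline dict, which is not confirmed by this extract:

```python
from collections.abc import AsyncIterator

from fastapi.responses import StreamingResponse

# Assumed to match the inline dict removed above; the canonical value
# lives in core/anthropic/sse.py.
ANTHROPIC_SSE_RESPONSE_HEADERS = {
    "X-Accel-Buffering": "no",
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
}


async def demo_events() -> AsyncIterator[str]:
    yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n'


response = StreamingResponse(
    demo_events(),
    media_type="text/event-stream",
    headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
)
```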
api/validation_log.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+"""Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def summarize_request_validation_body(
+    body: Any,
+) -> tuple[list[dict[str, Any]], list[str]]:
+    """Return message shape summary and tool name list for debug logs."""
+    messages = body.get("messages") if isinstance(body, dict) else None
+    message_summary: list[dict[str, Any]] = []
+    if isinstance(messages, list):
+        for msg in messages:
+            if not isinstance(msg, dict):
+                message_summary.append({"message_kind": type(msg).__name__})
+                continue
+            content = msg.get("content")
+            item: dict[str, Any] = {
+                "role": msg.get("role"),
+                "content_kind": type(content).__name__,
+            }
+            if isinstance(content, list):
+                item["block_types"] = [
+                    block.get("type", "dict")
+                    if isinstance(block, dict)
+                    else type(block).__name__
+                    for block in content[:12]
+                ]
+                item["block_keys"] = [
+                    sorted(str(key) for key in block)[:12]
+                    for block in content[:5]
+                    if isinstance(block, dict)
+                ]
+            elif isinstance(content, str):
+                item["content_length"] = len(content)
+            message_summary.append(item)
+
+    tool_names: list[str] = []
+    if isinstance(body, dict) and isinstance(body.get("tools"), list):
+        tool_names = [
+            str(tool.get("name", ""))
+            for tool in body["tools"]
+            if isinstance(tool, dict)
+        ]
+
+    return message_summary, tool_names
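Example of what the new summarizer emits — shapes and lengths only, never message text:

```python
from api.validation_log import summarize_request_validation_body

body = {
    "messages": [
        {"role": "user", "content": "run the tests"},
        {"role": "assistant", "content": [{"type": "text", "text": "ok"}]},
    ],
    "tools": [{"name": "Bash", "input_schema": {}}],
}
summary, tool_names = summarize_request_validation_body(body)
# summary == [
#     {"role": "user", "content_kind": "str", "content_length": 13},
#     {"role": "assistant", "content_kind": "list",
#      "block_types": ["text"], "block_keys": [["text", "type"]]},
# ]
# tool_names == ["Bash"]
```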
api/web_server_tools.py

@@ -1,331 +1,22 @@
-"""Local handlers for Anthropic web server tools.
-
-OpenAI-compatible upstreams can emit regular function calls, but Anthropic's
-web tools are server-side: the API response itself must include the tool result.
-"""
+"""Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""

 from __future__ import annotations

-import html
-import json
-import re
-import uuid
-from collections.abc import AsyncIterator
-from datetime import UTC, datetime
-from html.parser import HTMLParser
-from typing import Any
-from urllib.parse import parse_qs, unquote, urlparse
-
 import httpx

-from .models.anthropic import MessagesRequest
+from api.web_tools.egress import (
+    WebFetchEgressPolicy,
+    WebFetchEgressViolation,
+    enforce_web_fetch_egress,
+)
+from api.web_tools.request import is_web_server_tool_request
+from api.web_tools.streaming import stream_web_server_tool_response

-_REQUEST_TIMEOUT_S = 20.0
-_MAX_SEARCH_RESULTS = 10
-_MAX_FETCH_CHARS = 24_000
+__all__ = [
+    "WebFetchEgressPolicy",
+    "WebFetchEgressViolation",
+    "enforce_web_fetch_egress",
+    "httpx",
+    "is_web_server_tool_request",
+    "stream_web_server_tool_response",
+]

-
-class _SearchResultParser(HTMLParser):
-    def __init__(self) -> None:
-        super().__init__()
-        self.results: list[dict[str, str]] = []
-        self._href: str | None = None
-        self._title_parts: list[str] = []
-
-    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
-        if tag != "a":
-            return
-        href = dict(attrs).get("href")
-        if not href or "uddg=" not in href:
-            return
-        parsed = urlparse(href)
-        query = parse_qs(parsed.query)
-        uddg = query.get("uddg", [""])[0]
-        if not uddg:
-            return
-        self._href = unquote(uddg)
-        self._title_parts = []
-
-    def handle_data(self, data: str) -> None:
-        if self._href is not None:
-            self._title_parts.append(data)
-
-    def handle_endtag(self, tag: str) -> None:
-        if tag != "a" or self._href is None:
-            return
-        title = " ".join("".join(self._title_parts).split())
-        if title and not any(result["url"] == self._href for result in self.results):
-            self.results.append({"title": html.unescape(title), "url": self._href})
-        self._href = None
-        self._title_parts = []
-
-
-class _HTMLTextParser(HTMLParser):
-    def __init__(self) -> None:
-        super().__init__()
-        self.title = ""
-        self.text_parts: list[str] = []
-        self._in_title = False
-        self._skip_depth = 0
-
-    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
-        if tag in {"script", "style", "noscript"}:
-            self._skip_depth += 1
-        elif tag == "title":
-            self._in_title = True
-
-    def handle_endtag(self, tag: str) -> None:
-        if tag in {"script", "style", "noscript"} and self._skip_depth:
-            self._skip_depth -= 1
-        elif tag == "title":
-            self._in_title = False
-
-    def handle_data(self, data: str) -> None:
-        text = " ".join(data.split())
-        if not text:
-            return
-        if self._in_title:
-            self.title = f"{self.title} {text}".strip()
-        elif not self._skip_depth:
-            self.text_parts.append(text)
-
-
-def _format_event(event_type: str, data: dict[str, Any]) -> str:
-    return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
-
-
-def _content_text(content: Any) -> str:
-    if isinstance(content, str):
-        return content
-    if isinstance(content, list):
-        parts = []
-        for item in content:
-            if isinstance(item, dict):
-                parts.append(str(item.get("text", "")))
-            else:
-                parts.append(str(getattr(item, "text", "")))
-        return "\n".join(part for part in parts if part)
-    return str(content)
-
-
-def _request_text(request: MessagesRequest) -> str:
-    return "\n".join(_content_text(message.content) for message in request.messages)
-
-
-def _web_tool_name(request: MessagesRequest) -> str | None:
-    for tool in request.tools or []:
-        name = tool.name
-        tool_type = tool.type or ""
-        if name in {"web_search", "web_fetch"} or tool_type.startswith("web_"):
-            return name
-    return None
-
-
-def is_web_server_tool_request(request: MessagesRequest) -> bool:
-    return _web_tool_name(request) in {"web_search", "web_fetch"}
-
-
-def _extract_query(text: str) -> str:
-    match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
-    if match:
-        return match.group(1).strip().strip("\"'")
-    return text.strip()
-
-
-def _extract_url(text: str) -> str:
-    match = re.search(r"https?://\S+", text)
-    return match.group(0).rstrip(").,]") if match else text.strip()
-
-
-async def _run_web_search(query: str) -> list[dict[str, str]]:
-    async with httpx.AsyncClient(
-        timeout=_REQUEST_TIMEOUT_S,
-        follow_redirects=True,
-        headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
-    ) as client:
-        response = await client.get(
-            "https://lite.duckduckgo.com/lite/",
-            params={"q": query},
-        )
-        response.raise_for_status()
-
-    parser = _SearchResultParser()
-    parser.feed(response.text)
-    return parser.results[:_MAX_SEARCH_RESULTS]
-
-
-async def _run_web_fetch(url: str) -> dict[str, str]:
-    async with httpx.AsyncClient(
-        timeout=_REQUEST_TIMEOUT_S,
-        follow_redirects=True,
-        headers={"User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0"},
-    ) as client:
-        response = await client.get(url)
-        response.raise_for_status()
-
-    content_type = response.headers.get("content-type", "text/plain")
-    title = url
-    data = response.text
-    if "html" in content_type.lower():
-        parser = _HTMLTextParser()
-        parser.feed(response.text)
-        title = parser.title or url
-        data = "\n".join(parser.text_parts)
-    return {
-        "url": str(response.url),
-        "title": title,
-        "media_type": "text/plain",
-        "data": data[:_MAX_FETCH_CHARS],
-    }
-
-
-def _search_summary(query: str, results: list[dict[str, str]]) -> str:
-    if not results:
-        return f"No web search results found for: {query}"
-    lines = [f"Search results for: {query}"]
-    for index, result in enumerate(results, start=1):
-        lines.append(f"{index}. {result['title']}\n{result['url']}")
-    return "\n\n".join(lines)
-
-
-async def stream_web_server_tool_response(
-    request: MessagesRequest, input_tokens: int
-) -> AsyncIterator[str]:
-    tool_name = _web_tool_name(request)
-    if tool_name is None:
-        return
-
-    text = _request_text(request)
-    message_id = f"msg_{uuid.uuid4()}"
-    tool_id = f"srvtoolu_{uuid.uuid4().hex}"
-    output_tokens = 1
-    usage_key = (
-        "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
-    )
-    tool_input = (
-        {"query": _extract_query(text)}
-        if tool_name == "web_search"
-        else {"url": _extract_url(text)}
-    )
-
-    yield _format_event(
-        "message_start",
-        {
-            "type": "message_start",
-            "message": {
-                "id": message_id,
-                "type": "message",
-                "role": "assistant",
-                "content": [],
-                "model": request.model,
-                "stop_reason": None,
-                "stop_sequence": None,
-                "usage": {"input_tokens": input_tokens, "output_tokens": 1},
-            },
-        },
-    )
-    yield _format_event(
-        "content_block_start",
-        {
-            "type": "content_block_start",
-            "index": 0,
-            "content_block": {
-                "type": "server_tool_use",
-                "id": tool_id,
-                "name": tool_name,
-                "input": tool_input,
-            },
-        },
-    )
-    yield _format_event(
-        "content_block_stop", {"type": "content_block_stop", "index": 0}
-    )
-
-    try:
-        if tool_name == "web_search":
-            query = str(tool_input["query"])
-            results = await _run_web_search(query)
-            result_content: Any = [
-                {
-                    "type": "web_search_result",
-                    "title": result["title"],
-                    "url": result["url"],
-                }
-                for result in results
-            ]
-            summary = _search_summary(query, results)
-            result_block_type = "web_search_tool_result"
-        else:
-            fetched = await _run_web_fetch(str(tool_input["url"]))
-            result_content = {
-                "type": "web_fetch_result",
-                "url": fetched["url"],
-                "content": {
-                    "type": "document",
-                    "source": {
-                        "type": "text",
-                        "media_type": fetched["media_type"],
-                        "data": fetched["data"],
-                    },
-                    "title": fetched["title"],
-                    "citations": {"enabled": True},
-                },
-                "retrieved_at": datetime.now(UTC).isoformat(),
-            }
-            summary = fetched["data"][:_MAX_FETCH_CHARS]
-            result_block_type = "web_fetch_tool_result"
-    except Exception as error:
-        result_block_type = (
-            "web_search_tool_result"
-            if tool_name == "web_search"
-            else "web_fetch_tool_result"
-        )
-        error_type = (
-            "web_search_tool_result_error"
-            if tool_name == "web_search"
-            else "web_fetch_tool_error"
-        )
-        result_content = {"type": error_type, "error_code": "unavailable"}
-        summary = f"{tool_name} failed: {type(error).__name__}"
-
-    output_tokens = max(1, len(summary) // 4)
-
-    yield _format_event(
-        "content_block_start",
-        {
-            "type": "content_block_start",
-            "index": 1,
-            "content_block": {
-                "type": result_block_type,
-                "tool_use_id": tool_id,
-                "content": result_content,
-            },
-        },
-    )
-    yield _format_event(
-        "content_block_stop", {"type": "content_block_stop", "index": 1}
-    )
-    yield _format_event(
-        "content_block_start",
-        {
-            "type": "content_block_start",
-            "index": 2,
-            "content_block": {"type": "text", "text": summary},
-        },
-    )
-    yield _format_event(
-        "content_block_stop", {"type": "content_block_stop", "index": 2}
-    )
-    yield _format_event(
-        "message_delta",
-        {
-            "type": "message_delta",
-            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
-            "usage": {
-                "input_tokens": input_tokens,
-                "output_tokens": output_tokens,
-                "server_tool_use": {usage_key: 1},
-            },
-        },
-    )
-    yield _format_event("message_stop", {"type": "message_stop"})
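Because the old module now re-exports from `api.web_tools`, existing import sites keep resolving to the same objects:

```python
# Old and new import paths resolve to the same function object.
from api.web_server_tools import stream_web_server_tool_response as legacy
from api.web_tools.streaming import stream_web_server_tool_response as canonical

assert legacy is canonical
```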
api/web_tools/__init__.py (new file, 17 lines)

@@ -0,0 +1,17 @@
+"""Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""
+
+from .egress import (
+    WebFetchEgressPolicy,
+    WebFetchEgressViolation,
+    enforce_web_fetch_egress,
+)
+from .request import is_web_server_tool_request
+from .streaming import stream_web_server_tool_response
+
+__all__ = [
+    "WebFetchEgressPolicy",
+    "WebFetchEgressViolation",
+    "enforce_web_fetch_egress",
+    "is_web_server_tool_request",
+    "stream_web_server_tool_response",
+]
api/web_tools/constants.py (new file, 15 lines)

@@ -0,0 +1,15 @@
+"""Limits and defaults for outbound web server tool HTTP."""
+
+_REQUEST_TIMEOUT_S = 20.0
+_MAX_SEARCH_RESULTS = 10
+_MAX_FETCH_CHARS = 24_000
+# Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
+_MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
+# Drain at most this many bytes from redirect responses before following Location.
+_REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
+_MAX_WEB_FETCH_REDIRECTS = 10
+_WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
+
+_WEB_TOOL_HTTP_HEADERS = {
+    "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
+}
api/web_tools/egress.py (new file, 99 lines)

@@ -0,0 +1,99 @@
+"""Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
+
+from __future__ import annotations
+
+import ipaddress
+import socket
+from dataclasses import dataclass
+from urllib.parse import urlparse
+
+
+@dataclass(frozen=True, slots=True)
+class WebFetchEgressPolicy:
+    """Egress rules for user-influenced web_fetch URLs."""
+
+    allow_private_network_targets: bool
+    allowed_schemes: frozenset[str]
+
+
+class WebFetchEgressViolation(ValueError):
+    """Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""
+
+
+def _port_for_url(parsed) -> int:
+    if parsed.port is not None:
+        return parsed.port
+    return 443 if (parsed.scheme or "").lower() == "https" else 80
+
+
+def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
+    try:
+        return socket.getaddrinfo(
+            host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
+        )
+    except OSError as exc:
+        raise WebFetchEgressViolation(
+            f"Could not resolve host {host!r}: {exc}"
+        ) from exc
+
+
+def get_validated_stream_addrinfos_for_egress(
+    url: str, policy: WebFetchEgressPolicy
+) -> list[tuple]:
+    """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.
+
+    Each HTTP connect pins to only these ``getaddrinfo`` results so a malicious DNS
+    server cannot rebind to a disallowed address between resolution and the TCP
+    connect (used by :func:`api.web_tools.outbound._run_web_fetch`).
+    """
+    parsed = urlparse(url)
+    scheme = (parsed.scheme or "").lower()
+    if scheme not in policy.allowed_schemes:
+        raise WebFetchEgressViolation(
+            f"URL scheme {scheme!r} is not allowed for web_fetch"
+        )
+
+    host = parsed.hostname
+    if host is None or host == "":
+        raise WebFetchEgressViolation("web_fetch URL must include a host")
+
+    port = _port_for_url(parsed)
+
+    if policy.allow_private_network_targets:
+        return _stream_getaddrinfo_or_raise(host, port)
+
+    host_lower = host.lower()
+    if host_lower == "localhost" or host_lower.endswith(".localhost"):
+        raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
+    if host_lower.endswith(".local"):
+        raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")
+
+    try:
+        parsed_ip = ipaddress.ip_address(host)
+    except ValueError:
+        parsed_ip = None
+
+    if parsed_ip is not None:
+        if not parsed_ip.is_global:
+            raise WebFetchEgressViolation(
+                f"Non-public IP host {host!r} is not allowed for web_fetch"
+            )
+        return _stream_getaddrinfo_or_raise(host, port)
+
+    infos = _stream_getaddrinfo_or_raise(host, port)
+    for *_, sockaddr in infos:
+        addr = sockaddr[0]
+        try:
+            resolved = ipaddress.ip_address(addr)
+        except ValueError:
+            continue
+        if not resolved.is_global:
+            raise WebFetchEgressViolation(
+                f"Host {host!r} resolves to a non-public address ({resolved})"
+            )
+    return infos
+
+
+def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
+    """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
+    get_validated_stream_addrinfos_for_egress(url, policy)
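Usage of the egress guard (note that validating a hostname performs real DNS resolution): a public HTTPS URL passes, while a link-local metadata address is rejected because `ipaddress.ip_address("169.254.169.254").is_global` is false:

```python
from api.web_tools.egress import (
    WebFetchEgressPolicy,
    WebFetchEgressViolation,
    enforce_web_fetch_egress,
)

policy = WebFetchEgressPolicy(
    allow_private_network_targets=False,
    allowed_schemes=frozenset({"http", "https"}),
)

# Passes if example.com resolves to public addresses (requires network/DNS).
enforce_web_fetch_egress("https://example.com/docs", policy)

try:
    enforce_web_fetch_egress("http://169.254.169.254/latest/meta-data/", policy)
except WebFetchEgressViolation as exc:
    print(f"blocked: {exc}")
```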
api/web_tools/outbound.py (new file, 278 lines; diff truncated in this extract)

@@ -0,0 +1,278 @@
+"""Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""
+
+from __future__ import annotations
+
+import asyncio
+import socket
+from collections.abc import AsyncIterator
+from urllib.parse import urljoin, urlparse
+
+import aiohttp
+import httpx
+from aiohttp import ClientSession, ClientTimeout, TCPConnector
+from aiohttp.abc import AbstractResolver, ResolveResult
+from loguru import logger
+
+from . import constants
+from .constants import (
+    _MAX_FETCH_CHARS,
+    _MAX_SEARCH_RESULTS,
+    _REDIRECT_RESPONSE_BODY_CAP_BYTES,
+    _REQUEST_TIMEOUT_S,
+    _WEB_FETCH_REDIRECT_STATUSES,
+    _WEB_TOOL_HTTP_HEADERS,
+)
+from .egress import (
+    WebFetchEgressPolicy,
+    WebFetchEgressViolation,
+    get_validated_stream_addrinfos_for_egress,
+)
+from .parsers import HTMLTextParser, SearchResultParser
+
+
+def _safe_public_host_for_logs(url: str) -> str:
+    host = urlparse(url).hostname or ""
+    return host[:253]
+
+
+def _log_web_tool_failure(
+    tool_name: str,
+    error: BaseException,
+    *,
+    fetch_url: str | None = None,
+) -> None:
+    exc_type = type(error).__name__
+    if isinstance(error, WebFetchEgressViolation):
+        host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
+        logger.warning(
+            "web_tool_egress_rejected tool={} exc_type={} host={!r}",
+            tool_name,
+            exc_type,
+            host,
+        )
+        return
+    if tool_name == "web_fetch" and fetch_url:
+        logger.warning(
+            "web_tool_failure tool={} exc_type={} host={!r}",
+            tool_name,
+            exc_type,
+            _safe_public_host_for_logs(fetch_url),
+        )
+    else:
+        logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)
+
+
+def _web_tool_client_error_summary(
+    tool_name: str,
+    error: BaseException,
+    *,
+    verbose: bool,
+) -> str:
+    if verbose:
+        return f"{tool_name} failed: {type(error).__name__}"
+    return "Web tool request failed."
+
+
+async def _iter_response_body_under_cap(
+    response: httpx.Response, max_bytes: int
+) -> AsyncIterator[bytes]:
+    if max_bytes <= 0:
+        return
+    received = 0
+    async for chunk in response.aiter_bytes(chunk_size=65_536):
+        if received >= max_bytes:
+            break
+        remaining = max_bytes - received
+        if len(chunk) <= remaining:
+            received += len(chunk)
+            yield chunk
+            if received >= max_bytes:
+                break
+        else:
+            yield chunk[:remaining]
+            break
+
+
+async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
+    async for _ in _iter_response_body_under_cap(response, max_bytes):
+        pass
+
+
+async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
+    return b"".join(
+        [piece async for piece in _iter_response_body_under_cap(response, max_bytes)]
+    )
+
+
+_NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
+_NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
+
+
+def getaddrinfo_rows_to_resolve_results(
+    host: str, addrinfos: list[tuple]
+) -> list[ResolveResult]:
+    """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
+    out: list[ResolveResult] = []
+    for family, _type, proto, _canon, sockaddr in addrinfos:
+        if family == socket.AF_INET6:
+            if len(sockaddr) < 3:
+                continue
+            if sockaddr[3]:
+                resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
+            else:
+                resolved_host, port = sockaddr[:2]
+        else:
+            assert family == socket.AF_INET, family
+            resolved_host, port = sockaddr[0], sockaddr[1]
+        resolved_host = str(resolved_host)
+        port = int(port)
+        out.append(
+            ResolveResult(
+                hostname=host,
+                host=resolved_host,
+                port=int(port),
+                family=family,
+                proto=proto,
+                flags=_NUMERIC_RESOLVE_FLAGS,
+            )
+        )
+    return out
+
+
+class _PinnedEgressStaticResolver(AbstractResolver):
+    """Return only pre-validated :class:`ResolveResult` for the outbound request."""
+
+    def __init__(self, results: list[ResolveResult]) -> None:
+        self._results = results
+
+    async def resolve(
+        self, host: str, port: int = 0, family: int = socket.AF_INET
+    ) -> list[ResolveResult]:
+        return self._results
+
+    async def close(self) -> None:  # pragma: no cover - aiohttp contract
+        return
+
+
+async def _read_aiohttp_body_capped(
+    response: aiohttp.ClientResponse, max_bytes: int
+) -> bytes:
+    received = 0
+    parts: list[bytes] = []
+    async for chunk in response.content.iter_chunked(65_536):
+        if received >= max_bytes:
+            break
+        remaining = max_bytes - received
+        if len(chunk) <= remaining:
+            received += len(chunk)
+            parts.append(chunk)
+        else:
+            parts.append(chunk[:remaining])
+            break
+    return b"".join(parts)
+
+
+async def _drain_aiohttp_body_capped(
+    response: aiohttp.ClientResponse, max_bytes: int
+) -> None:
+    if max_bytes <= 0:
+        return
+    received = 0
+    async for chunk in response.content.iter_chunked(65_536):
+        received += len(chunk)
+        if received >= max_bytes:
+            break
+
+
+async def _run_web_search(query: str) -> list[dict[str, str]]:
+    async with (
+        httpx.AsyncClient(
+            timeout=_REQUEST_TIMEOUT_S,
+            follow_redirects=True,
+            headers=_WEB_TOOL_HTTP_HEADERS,
+        ) as client,
+        client.stream(
+            "GET",
+            "https://lite.duckduckgo.com/lite/",
+            params={"q": query},
+        ) as response,
+    ):
+        response.raise_for_status()
+        body_bytes = await _read_response_body_capped(
+            response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
+        )
+    text = body_bytes.decode("utf-8", errors="replace")
+    parser = SearchResultParser()
+    parser.feed(text)
+    return parser.results[:_MAX_SEARCH_RESULTS]
+
+
+async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
+    """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses."""
+    current_url = url
+    redirect_hops = 0
+    timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)
+
+    while True:
+        addr_infos = await asyncio.to_thread(
+            get_validated_stream_addrinfos_for_egress, current_url, egress
+        )
+        host = urlparse(current_url).hostname or ""
+        results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
+        resolver = _PinnedEgressStaticResolver(results)
+        connector = TCPConnector(
+            resolver=resolver,
+            force_close=True,
+        )
+        try:
+            async with (
+                ClientSession(
+                    timeout=timeout,
+                    headers=_WEB_TOOL_HTTP_HEADERS,
+                    connector=connector,
+                ) as session,
+                session.get(current_url, allow_redirects=False) as response,
+            ):
+                if response.status in _WEB_FETCH_REDIRECT_STATUSES:
|
||||
await _drain_aiohttp_body_capped(
|
||||
response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
|
||||
)
|
||||
if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
|
||||
raise WebFetchEgressViolation(
|
||||
"web_fetch exceeded maximum redirects "
|
||||
f"({constants._MAX_WEB_FETCH_REDIRECTS})"
|
||||
)
|
||||
location = response.headers.get("location")
|
||||
if not location or not location.strip():
|
||||
raise WebFetchEgressViolation(
|
||||
"web_fetch redirect response missing Location header"
|
||||
)
|
||||
current_url = urljoin(str(response.url), location.strip())
|
||||
redirect_hops += 1
|
||||
continue
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "text/plain")
|
||||
final_url = str(response.url)
|
||||
encoding = response.get_encoding() or "utf-8"
|
||||
body_bytes = await _read_aiohttp_body_capped(
|
||||
response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
|
||||
)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
break
|
||||
|
||||
text = body_bytes.decode(encoding, errors="replace")
|
||||
title = final_url
|
||||
data = text
|
||||
if "html" in content_type.lower():
|
||||
parser = HTMLTextParser()
|
||||
parser.feed(text)
|
||||
title = parser.title or final_url
|
||||
data = "\n".join(parser.text_parts)
|
||||
return {
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"media_type": "text/plain",
|
||||
"data": data[:_MAX_FETCH_CHARS],
|
||||
}
|
||||
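Usage sketch for the fetch path above. This is illustrative only: the `WebFetchEgressPolicy` constructor fields shown are assumptions, since the real definition lives in api/web_tools/egress.py and is not part of this hunk.

import asyncio

from api.web_tools.egress import WebFetchEgressPolicy
from api.web_tools.outbound import _run_web_fetch


async def main() -> None:
    # Hypothetical policy shape: only public http/https destinations.
    policy = WebFetchEgressPolicy(
        allowed_schemes=frozenset({"http", "https"}),
        allow_private_networks=False,
    )
    page = await _run_web_fetch("https://example.com/", policy)
    # Every redirect hop re-resolves through the pinned resolver, so a DNS
    # answer cannot change between the egress check and the TCP connect.
    print(page["title"], len(page["data"]))


asyncio.run(main())
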
104 api/web_tools/parsers.py Normal file

@ -0,0 +1,104 @@
"""HTML parsing for web_search / web_fetch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import re
|
||||
from html.parser import HTMLParser
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
|
||||
class SearchResultParser(HTMLParser):
|
||||
"""DuckDuckGo lite HTML: extract result links and titles."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.results: list[dict[str, str]] = []
|
||||
self._href: str | None = None
|
||||
self._title_parts: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag != "a":
|
||||
return
|
||||
href = dict(attrs).get("href")
|
||||
if not href or "uddg=" not in href:
|
||||
return
|
||||
parsed = urlparse(href)
|
||||
query = parse_qs(parsed.query)
|
||||
uddg = query.get("uddg", [""])[0]
|
||||
if not uddg:
|
||||
return
|
||||
self._href = unquote(uddg)
|
||||
self._title_parts = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._href is not None:
|
||||
self._title_parts.append(data)
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag != "a" or self._href is None:
|
||||
return
|
||||
title = " ".join("".join(self._title_parts).split())
|
||||
if title and not any(result["url"] == self._href for result in self.results):
|
||||
self.results.append({"title": html.unescape(title), "url": self._href})
|
||||
self._href = None
|
||||
self._title_parts = []
|
||||
|
||||
|
||||
class HTMLTextParser(HTMLParser):
|
||||
"""Strip scripts/styles and collect visible text + title for fetch previews."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.title = ""
|
||||
self.text_parts: list[str] = []
|
||||
self._in_title = False
|
||||
self._skip_depth = 0
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
if tag in {"script", "style", "noscript"}:
|
||||
self._skip_depth += 1
|
||||
elif tag == "title":
|
||||
self._in_title = True
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if tag in {"script", "style", "noscript"} and self._skip_depth:
|
||||
self._skip_depth -= 1
|
||||
elif tag == "title":
|
||||
self._in_title = False
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
text = " ".join(data.split())
|
||||
if not text:
|
||||
return
|
||||
if self._in_title:
|
||||
self.title = f"{self.title} {text}".strip()
|
||||
elif not self._skip_depth:
|
||||
self.text_parts.append(text)
|
||||
|
||||
|
||||
def content_text(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(str(item.get("text", "")))
|
||||
else:
|
||||
parts.append(str(getattr(item, "text", "")))
|
||||
return "\n".join(part for part in parts if part)
|
||||
return str(content)
|
||||
|
||||
|
||||
def extract_query(text: str) -> str:
|
||||
match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip().strip("\"'")
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_url(text: str) -> str:
|
||||
match = re.search(r"https?://\S+", text)
|
||||
return match.group(0).rstrip(").,]") if match else text.strip()
|
||||
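A quick check of the `uddg=` unwrapping above (the HTML snippet mimics DuckDuckGo lite's redirect anchors; note `parse_qs` already URL-decodes the value, so the extra `unquote` is idempotent):

from api.web_tools.parsers import SearchResultParser

snippet = (
    '<a href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fdocs">'
    "Example Docs</a>"
)
parser = SearchResultParser()
parser.feed(snippet)
assert parser.results == [
    {"title": "Example Docs", "url": "https://example.com/docs"}
]
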
87 api/web_tools/request.py Normal file

@ -0,0 +1,87 @@
"""Detect forced Anthropic web server tool requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from api.models.anthropic import MessagesRequest, Tool
|
||||
|
||||
|
||||
def request_text(request: MessagesRequest) -> str:
|
||||
"""Join all user/assistant message content into one string for tool input parsing."""
|
||||
from .parsers import content_text
|
||||
|
||||
return "\n".join(content_text(message.content) for message in request.messages)
|
||||
|
||||
|
||||
def forced_tool_turn_text(request: MessagesRequest) -> str:
|
||||
"""Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
|
||||
if not request.messages:
|
||||
return ""
|
||||
|
||||
from .parsers import content_text
|
||||
|
||||
for message in reversed(request.messages):
|
||||
if message.role == "user":
|
||||
return content_text(message.content)
|
||||
return ""
|
||||
|
||||
|
||||
def forced_server_tool_name(request: MessagesRequest) -> str | None:
|
||||
"""Return web_search or web_fetch only when tool_choice forces that server tool."""
|
||||
tc = request.tool_choice
|
||||
if not isinstance(tc, dict):
|
||||
return None
|
||||
if tc.get("type") != "tool":
|
||||
return None
|
||||
name = tc.get("name")
|
||||
if name in {"web_search", "web_fetch"}:
|
||||
return str(name)
|
||||
return None
|
||||
|
||||
|
||||
def has_tool_named(request: MessagesRequest, name: str) -> bool:
|
||||
return any(tool.name == name for tool in request.tools or [])
|
||||
|
||||
|
||||
def is_web_server_tool_request(request: MessagesRequest) -> bool:
|
||||
"""True when the client forces a web server tool via tool_choice (not merely listed)."""
|
||||
forced = forced_server_tool_name(request)
|
||||
if forced is None:
|
||||
return False
|
||||
return has_tool_named(request, forced)
|
||||
|
||||
|
||||
def is_anthropic_server_tool_definition(tool: Tool) -> bool:
|
||||
"""Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
|
||||
name = (tool.name or "").strip()
|
||||
if name in ("web_search", "web_fetch"):
|
||||
return True
|
||||
typ = tool.type
|
||||
if isinstance(typ, str):
|
||||
return typ.startswith("web_search") or typ.startswith("web_fetch")
|
||||
return False
|
||||
|
||||
|
||||
def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
|
||||
"""True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
|
||||
return any(is_anthropic_server_tool_definition(t) for t in (request.tools or []))
|
||||
|
||||
|
||||
def openai_chat_upstream_server_tool_error(
|
||||
request: MessagesRequest, *, web_tools_enabled: bool
|
||||
) -> str | None:
|
||||
"""Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
|
||||
forced = forced_server_tool_name(request)
|
||||
if forced and not web_tools_enabled:
|
||||
return (
|
||||
f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
|
||||
"disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them or use a native Anthropic "
|
||||
"Messages transport (e.g. open_router, ollama, lmstudio)."
|
||||
)
|
||||
if not forced and has_listed_anthropic_server_tools(request):
|
||||
return (
|
||||
"OpenAI Chat upstreams (NVIDIA NIM, DeepSeek) cannot use listed Anthropic server tools "
|
||||
"(web_search / web_fetch) without the local web server tool handler. Use a native "
|
||||
"Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
|
||||
"tool_choice, or remove these tools from the request."
|
||||
)
|
||||
return None
|
||||
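Detection only fires when tool_choice forces the tool and the tool is also listed; a sketch, assuming `MessagesRequest` accepts this minimal payload:

from api.models.anthropic import MessagesRequest
from api.web_tools.request import forced_server_tool_name, is_web_server_tool_request

request = MessagesRequest.model_validate(
    {
        "model": "claude-sonnet-4-5",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "query: latest aiohttp release"}],
        "tools": [{"name": "web_search", "input_schema": {"type": "object"}}],
        "tool_choice": {"type": "tool", "name": "web_search"},
    }
)
assert forced_server_tool_name(request) == "web_search"
assert is_web_server_tool_request(request)
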
206 api/web_tools/streaming.py Normal file

@ -0,0 +1,206 @@
"""SSE streaming for local web_search / web_fetch server tool results."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from api.models.anthropic import MessagesRequest
|
||||
from core.anthropic.server_tool_sse import (
|
||||
SERVER_TOOL_USE,
|
||||
WEB_FETCH_TOOL_ERROR,
|
||||
WEB_FETCH_TOOL_RESULT,
|
||||
WEB_SEARCH_TOOL_RESULT,
|
||||
WEB_SEARCH_TOOL_RESULT_ERROR,
|
||||
)
|
||||
from core.anthropic.sse import format_sse_event
|
||||
|
||||
from . import outbound
|
||||
from .constants import _MAX_FETCH_CHARS
|
||||
from .egress import WebFetchEgressPolicy
|
||||
from .parsers import extract_query, extract_url
|
||||
from .request import (
|
||||
forced_server_tool_name,
|
||||
forced_tool_turn_text,
|
||||
has_tool_named,
|
||||
)
|
||||
|
||||
|
||||
def _search_summary(query: str, results: list[dict[str, str]]) -> str:
|
||||
if not results:
|
||||
return f"No web search results found for: {query}"
|
||||
lines = [f"Search results for: {query}"]
|
||||
for index, result in enumerate(results, start=1):
|
||||
lines.append(f"{index}. {result['title']}\n{result['url']}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def stream_web_server_tool_response(
|
||||
request: MessagesRequest,
|
||||
input_tokens: int,
|
||||
*,
|
||||
web_fetch_egress: WebFetchEgressPolicy,
|
||||
verbose_client_errors: bool = False,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).
|
||||
|
||||
When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path — not a full
|
||||
hosted Anthropic citation or encrypted-content pipeline.
|
||||
"""
|
||||
tool_name = forced_server_tool_name(request)
|
||||
if tool_name is None or not has_tool_named(request, tool_name):
|
||||
return
|
||||
|
||||
text = forced_tool_turn_text(request)
|
||||
message_id = f"msg_{uuid.uuid4()}"
|
||||
tool_id = f"srvtoolu_{uuid.uuid4().hex}"
|
||||
usage_key = (
|
||||
"web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
|
||||
)
|
||||
tool_input = (
|
||||
{"query": extract_query(text)}
|
||||
if tool_name == "web_search"
|
||||
else {"url": extract_url(text)}
|
||||
)
|
||||
_result_block_for_tool = {
|
||||
"web_search": WEB_SEARCH_TOOL_RESULT,
|
||||
"web_fetch": WEB_FETCH_TOOL_RESULT,
|
||||
}
|
||||
_error_payload_type_for_tool = {
|
||||
"web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
|
||||
"web_fetch": WEB_FETCH_TOOL_ERROR,
|
||||
}
|
||||
|
||||
yield format_sse_event(
|
||||
"message_start",
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": request.model,
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": input_tokens, "output_tokens": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": SERVER_TOOL_USE,
|
||||
"id": tool_id,
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 0}
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "web_search":
|
||||
query = str(tool_input["query"])
|
||||
results = await outbound._run_web_search(query)
|
||||
result_content: Any = [
|
||||
{
|
||||
"type": "web_search_result",
|
||||
"title": result["title"],
|
||||
"url": result["url"],
|
||||
}
|
||||
for result in results
|
||||
]
|
||||
summary = _search_summary(query, results)
|
||||
result_block_type = WEB_SEARCH_TOOL_RESULT
|
||||
else:
|
||||
fetched = await outbound._run_web_fetch(
|
||||
str(tool_input["url"]), web_fetch_egress
|
||||
)
|
||||
result_content = {
|
||||
"type": "web_fetch_result",
|
||||
"url": fetched["url"],
|
||||
"content": {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": "text",
|
||||
"media_type": fetched["media_type"],
|
||||
"data": fetched["data"],
|
||||
},
|
||||
"title": fetched["title"],
|
||||
"citations": {"enabled": True},
|
||||
},
|
||||
"retrieved_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
summary = fetched["data"][:_MAX_FETCH_CHARS]
|
||||
result_block_type = WEB_FETCH_TOOL_RESULT
|
||||
except Exception as error:
|
||||
fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
|
||||
outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
|
||||
result_block_type = _result_block_for_tool[tool_name]
|
||||
result_content = {
|
||||
"type": _error_payload_type_for_tool[tool_name],
|
||||
"error_code": "unavailable",
|
||||
}
|
||||
summary = outbound._web_tool_client_error_summary(
|
||||
tool_name, error, verbose=verbose_client_errors
|
||||
)
|
||||
|
||||
output_tokens = max(1, len(summary) // 4)
|
||||
|
||||
yield format_sse_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": {
|
||||
"type": result_block_type,
|
||||
"tool_use_id": tool_id,
|
||||
"content": result_content,
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 1}
|
||||
)
|
||||
# Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
|
||||
# not eager `text` on `content_block_start`).
|
||||
yield format_sse_event(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 2,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_delta",
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 2,
|
||||
"delta": {"type": "text_delta", "text": summary},
|
||||
},
|
||||
)
|
||||
yield format_sse_event(
|
||||
"content_block_stop", {"type": "content_block_stop", "index": 2}
|
||||
)
|
||||
yield format_sse_event(
|
||||
"message_delta",
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
|
||||
"usage": {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"server_tool_use": {usage_key: 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
yield format_sse_event("message_stop", {"type": "message_stop"})
|
||||
|
|
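For a forced web_search turn, consuming the generator yields a fixed frame order; a sketch (the `WebFetchEgressPolicy` fields are again illustrative assumptions):

import asyncio

from api.web_tools.egress import WebFetchEgressPolicy
from api.web_tools.streaming import stream_web_server_tool_response


async def dump_event_names(request) -> None:
    # `request` is a validated MessagesRequest with a forced web_search tool.
    policy = WebFetchEgressPolicy(  # hypothetical field names
        allowed_schemes=frozenset({"http", "https"}),
        allow_private_networks=False,
    )
    async for frame in stream_web_server_tool_response(
        request, 42, web_fetch_egress=policy
    ):
        # Each frame is a complete "event: ...\ndata: {...}\n\n" string.
        print(frame.split("\n", 1)[0])

# Expected order: message_start, content_block_start (server_tool_use),
# content_block_stop, content_block_start (result), content_block_stop,
# content_block_start (text), content_block_delta, content_block_stop,
# message_delta, message_stop.
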
@ -30,7 +30,8 @@ def serve() -> None:
     settings = get_settings()
     try:
         uvicorn.run(
-            "api.app:app",
+            "api.app:create_app",
+            factory=True,
             host=settings.host,
             port=settings.port,
             log_level="debug",

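The `factory=True` switch makes uvicorn call `create_app()` per process instead of importing a module-level singleton. A minimal sketch of the assumed factory shape (the real one in api/app.py presumably wires settings, providers, and routes; FastAPI is assumed from the ASGI app path):

from fastapi import FastAPI


def create_app() -> FastAPI:
    # Building the app per call avoids import-time global state, which also
    # keeps tests able to construct isolated app instances.
    app = FastAPI()

    @app.get("/health")
    def health() -> dict[str, str]:
        return {"status": "ok"}

    return app
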
@ -29,6 +29,9 @@ class CLISessionManager:
         allowed_dirs: list[str] | None = None,
         plans_directory: str | None = None,
         claude_bin: str = "claude",
+        *,
+        log_raw_cli_diagnostics: bool = False,
+        log_messaging_error_details: bool = False,
     ):
         """
         Initialize the session manager.

@ -44,6 +47,8 @@ class CLISessionManager:
         self.allowed_dirs = allowed_dirs or []
         self.plans_directory = plans_directory
         self.claude_bin = claude_bin
+        self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
+        self._log_messaging_error_details = log_messaging_error_details

         self._sessions: dict[str, CLISession] = {}
         self._pending_sessions: dict[str, CLISession] = {}

@ -79,6 +84,7 @@ class CLISessionManager:
             allowed_dirs=self.allowed_dirs,
             plans_directory=self.plans_directory,
             claude_bin=self.claude_bin,
+            log_raw_cli_diagnostics=self._log_raw_cli_diagnostics,
         )
         self._pending_sessions[temp_id] = new_session
         logger.info(f"Created new session: {temp_id}")

@ -130,7 +136,17 @@ class CLISessionManager:
             try:
                 await session.stop()
             except Exception as e:
-                logger.error(f"Error stopping session: {e}")
+                if self._log_messaging_error_details:
+                    logger.error(
+                        "Error stopping session: {}: {}",
+                        type(e).__name__,
+                        e,
+                    )
+                else:
+                    logger.error(
+                        "Error stopping session: exc_type={}",
+                        type(e).__name__,
+                    )

         self._sessions.clear()
         self._pending_sessions.clear()

@ -11,6 +11,9 @@ from loguru import logger

 from .process_registry import register_pid, unregister_pid

+# Cap stderr capture so a runaway child cannot exhaust memory; pipe is still drained.
+_MAX_STDERR_CAPTURE_BYTES = 256 * 1024
+

 @dataclass(frozen=True, slots=True)
 class ClaudeCliConfig:

@ -33,6 +36,8 @@ class CLISession:
         allowed_dirs: list[str] | None = None,
         plans_directory: str | None = None,
         claude_bin: str = "claude",
+        *,
+        log_raw_cli_diagnostics: bool = False,
     ):
         self.config = ClaudeCliConfig(
             workspace_path=os.path.normpath(os.path.abspath(workspace_path)),

@ -46,11 +51,40 @@ class CLISession:
         self.allowed_dirs = self.config.allowed_dirs
         self.plans_directory = self.config.plans_directory
         self.claude_bin = self.config.claude_bin
+        self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
         self.process: asyncio.subprocess.Process | None = None
         self.current_session_id: str | None = None
         self._is_busy = False
         self._cli_lock = asyncio.Lock()

+    @staticmethod
+    async def _drain_stderr_bounded(
+        process: asyncio.subprocess.Process,
+        *,
+        max_bytes: int = _MAX_STDERR_CAPTURE_BYTES,
+    ) -> bytes:
+        """Read stderr concurrently with stdout to avoid subprocess pipe deadlocks.
+
+        Retains at most ``max_bytes`` for logging; any excess is discarded, but
+        the pipe is read until EOF so a noisy child cannot fill the buffer and
+        block forever.
+        """
+        if not process.stderr:
+            return b""
+        parts: list[bytes] = []
+        received = 0
+        while True:
+            chunk = await process.stderr.read(65_536)
+            if not chunk:
+                break
+            if received < max_bytes:
+                take = min(len(chunk), max_bytes - received)
+                if take:
+                    parts.append(chunk[:take])
+                    received += take
+            # If already at cap, keep reading and discarding until EOF.
+        return b"".join(parts)
+
     @property
     def is_busy(self) -> bool:
         """Check if a task is currently running."""

@ -140,6 +174,11 @@ class CLISession:

         session_id_extracted = False
         buffer = bytearray()
+        stderr_task: asyncio.Task[bytes] | None = None
+        if self.process.stderr:
+            stderr_task = asyncio.create_task(
+                self._drain_stderr_bounded(self.process)
+            )

         try:
             while True:

@ -179,21 +218,25 @@ class CLISession:
         except asyncio.CancelledError:
             # Cancelling the handler task should not leave a Claude CLI
             # subprocess running in the background.
             try:
                 await asyncio.shield(self.stop())
             finally:
                 raise
         finally:
+            stderr_bytes = b""
+            if stderr_task is not None:
+                stderr_bytes = await stderr_task

-            stderr_text = None
-            if self.process.stderr:
-                stderr_output = await self.process.stderr.read()
-                if stderr_output:
-                    stderr_text = stderr_output.decode(
-                        "utf-8", errors="replace"
-                    ).strip()
-                    logger.error(f"Claude CLI Stderr: {stderr_text}")
-                    # Yield stderr as error event so it shows in UI
+            if stderr_bytes:
+                stderr_text = stderr_bytes.decode("utf-8", errors="replace").strip()
+                if stderr_text:
+                    if self._log_raw_cli_diagnostics:
+                        logger.error("Claude CLI stderr: {}", stderr_text)
+                    else:
+                        logger.error(
+                            "Claude CLI stderr: bytes={} text_chars={}",
+                            len(stderr_bytes),
+                            len(stderr_text),
+                        )
                     logger.info("CLI_SESSION: Yielding error event from stderr")
                     yield {"type": "error", "error": {"message": stderr_text}}

@ -230,7 +273,10 @@ class CLISession:

                 yield event
             except json.JSONDecodeError:
-                logger.debug(f"Non-JSON output: {line_str}")
+                if self._log_raw_cli_diagnostics:
+                    logger.debug("Non-JSON output: {}", line_str)
+                else:
+                    logger.debug("Non-JSON CLI line: char_len={}", len(line_str))
                 yield {"type": "raw", "content": line_str}

     def _extract_session_id(self, event: Any) -> str | None:

@ -273,6 +319,16 @@ class CLISession:
                 unregister_pid(self.process.pid)
                 return True
             except Exception as e:
-                logger.error(f"Error stopping process: {e}")
+                if self._log_raw_cli_diagnostics:
+                    logger.error(
+                        "Error stopping process: {}: {}",
+                        type(e).__name__,
+                        e,
+                    )
+                else:
+                    logger.error(
+                        "Error stopping process: exc_type={}",
+                        type(e).__name__,
+                    )
                 return False
         return False

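The bounded drain can be exercised directly since it is a staticmethod; a toy check (the `CLISession` import path is assumed, as the file name is not shown in this diff):

import asyncio


async def demo() -> None:
    # Toy child that floods stderr far past the 256 KiB cap.
    proc = await asyncio.create_subprocess_exec(
        "python3",
        "-c",
        "import sys; sys.stderr.write('x' * 1_000_000)",
        stderr=asyncio.subprocess.PIPE,
    )
    captured = await CLISession._drain_stderr_bounded(proc, max_bytes=1024)
    await proc.wait()
    assert len(captured) == 1024
    # Retained bytes are capped, but the pipe was read to EOF, so the noisy
    # child could not block on a full stderr buffer.


asyncio.run(demo())
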
10 config/constants.py Normal file

@ -0,0 +1,10 @@
"""Shared defaults used by config models and provider adapters."""

# HTTP client connect timeout (seconds). Keep aligned with README.md and .env.example.
HTTP_CONNECT_TIMEOUT_DEFAULT = 10.0

# Anthropic Messages API default when the client omits max_tokens.
ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS = 81920

# Max bytes read from a non-200 native messages response when verbose error logging is on.
NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES = 4096

@ -8,6 +8,7 @@ included at top level for easy grep/filter.

 import json
 import logging
+import re
 from pathlib import Path

 from loguru import logger

@ -17,6 +18,22 @@ _configured = False
 # Context keys we promote to top-level JSON for traceability
 _CONTEXT_KEYS = ("request_id", "node_id", "chat_id")

+_TELEGRAM_BOT_RE = re.compile(
+    r"(https?://api\.telegram\.org/)bot([0-9]+:[A-Za-z0-9_-]+)(/?)",
+    re.IGNORECASE,
+)
+# Authorization: Bearer <token> (HTTP client / proxy debug lines)
+_AUTH_BEARER_RE = re.compile(
+    r"(\bAuthorization\s*:\s*Bearer\s+)([^\s'\"]+)",
+    re.IGNORECASE,
+)
+
+
+def _redact_sensitive_substrings(message: str) -> str:
+    """Remove obvious API tokens and secrets before JSON log line emission."""
+    text = _TELEGRAM_BOT_RE.sub(r"\1bot<redacted>\3", message)
+    return _AUTH_BEARER_RE.sub(r"\1<redacted>", text)
+
+
 def _serialize_with_context(record) -> str:
     """Format record as JSON with context vars at top level.

@ -26,7 +43,7 @@ def _serialize_with_context(record) -> str:
     out = {
         "time": str(record["time"]),
         "level": record["level"].name,
-        "message": record["message"],
+        "message": _redact_sensitive_substrings(str(record["message"])),
         "module": record["name"],
         "function": record["function"],
         "line": record["line"],

@ -57,11 +74,16 @@ class InterceptHandler(logging.Handler):
         )


-def configure_logging(log_file: str, *, force: bool = False) -> None:
+def configure_logging(
+    log_file: str, *, force: bool = False, verbose_third_party: bool = False
+) -> None:
     """Configure loguru with JSON output to log_file and intercept stdlib logging.

     Idempotent: skips if already configured (e.g. hot reload).
     Use force=True to reconfigure (e.g. in tests with a different log path).
+
+    When ``verbose_third_party`` is false, noisy HTTP and Telegram loggers are capped
+    at WARNING unless explicitly configured otherwise.
     """
     global _configured
     if _configured and not force:

@ -88,3 +110,16 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
     intercept = InterceptHandler()
     logging.root.handlers = [intercept]
     logging.root.setLevel(logging.DEBUG)
+
+    third_party = (
+        "httpx",
+        "httpcore",
+        "httpcore.http11",
+        "httpcore.connection",
+        "telegram",
+        "telegram.ext",
+    )
+    for name in third_party:
+        logging.getLogger(name).setLevel(
+            logging.WARNING if not verbose_third_party else logging.NOTSET
+        )

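What the two redaction regexes catch, assuming the module is importable as config.logging:

from config.logging import _redact_sensitive_substrings

line = (
    "POST https://api.telegram.org/bot123456:AAEexampletoken/sendMessage "
    "Authorization: Bearer sk-example123"
)
print(_redact_sensitive_substrings(line))
# POST https://api.telegram.org/bot<redacted>/sendMessage Authorization: Bearer <redacted>
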
@ -2,6 +2,8 @@

 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+

 class NimSettings(BaseModel):
     """Fixed NVIDIA NIM settings (not configurable via env)."""

@ -14,7 +16,9 @@ class NimSettings(BaseModel):
     )
     top_k: int = -1
     max_tokens: int = Field(
-        81920, ge=1, description="Maximum number of tokens in output."
+        ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+        ge=1,
+        description="Maximum number of tokens in output.",
     )
     presence_penalty: float = Field(0.0, ge=-2.0, le=2.0)
     frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)

@ -68,7 +72,7 @@ class NimSettings(BaseModel):
             return field_defaults.get(key, 1.0)
         try:
             val = float(v)
-        except Exception as err:
+        except (TypeError, ValueError) as err:
             raise ValueError(
                 f"{info.field_name} must be a float. Got {type(v).__name__}."
             ) from err

@ -78,15 +82,15 @@ class NimSettings(BaseModel):
     @classmethod
     def validate_int_fields(cls, v, info: ValidationInfo):
         field_defaults = {
-            "max_tokens": 81920,
+            "max_tokens": ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
             "min_tokens": 0,
         }
         if v is None or v == "":
             key = info.field_name or "max_tokens"
-            return field_defaults.get(key, 81920)
+            return field_defaults.get(key, ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS)
         try:
             val = int(v)
-        except Exception as err:
+        except (TypeError, ValueError) as err:
             raise ValueError(
                 f"{info.field_name} must be an int. Got {type(v).__name__}."
             ) from err

@ -99,7 +103,7 @@ class NimSettings(BaseModel):
             return None
         try:
             return int(v)
-        except Exception as err:
+        except (TypeError, ValueError) as err:
             raise ValueError(
                 f"{info.field_name} must be an int or empty/None."
             ) from err

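Behavior sketch for the validators, assuming the model imports as config.nim.NimSettings and the int validator runs in before mode:

from pydantic import ValidationError

from config.nim import NimSettings

# Empty .env values fall back to the shared default instead of erroring.
assert NimSettings.model_validate({"max_tokens": ""}).max_tokens == 81920

# Non-numeric input raises a targeted message instead of a broad Exception catch.
try:
    NimSettings.model_validate({"max_tokens": object()})
except ValidationError as err:
    print(err)  # mentions: max_tokens must be an int. Got object.
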
108 config/provider_catalog.py Normal file

@ -0,0 +1,108 @@
"""Neutral provider catalog: IDs, credentials, defaults, proxy and capability metadata.

Adapter factories live in :mod:`providers.registry`; this module stays free of
provider implementation imports (see contract tests).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

TransportType = Literal["openai_chat", "anthropic_messages"]

# Default upstream base URLs (also re-exported via :mod:`providers.defaults`)
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
OLLAMA_DEFAULT_BASE = "http://localhost:11434"


@dataclass(frozen=True, slots=True)
class ProviderDescriptor:
    """Metadata for building :class:`~providers.base.ProviderConfig` and factory wiring."""

    provider_id: str
    transport_type: TransportType
    capabilities: tuple[str, ...]
    credential_env: str | None = None
    credential_url: str | None = None
    credential_attr: str | None = None
    static_credential: str | None = None
    default_base_url: str | None = None
    base_url_attr: str | None = None
    proxy_attr: str | None = None


PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
    "nvidia_nim": ProviderDescriptor(
        provider_id="nvidia_nim",
        transport_type="openai_chat",
        credential_env="NVIDIA_NIM_API_KEY",
        credential_url="https://build.nvidia.com/settings/api-keys",
        credential_attr="nvidia_nim_api_key",
        default_base_url=NVIDIA_NIM_DEFAULT_BASE,
        proxy_attr="nvidia_nim_proxy",
        capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
    ),
    "open_router": ProviderDescriptor(
        provider_id="open_router",
        transport_type="anthropic_messages",
        credential_env="OPENROUTER_API_KEY",
        credential_url="https://openrouter.ai/keys",
        credential_attr="open_router_api_key",
        default_base_url=OPENROUTER_DEFAULT_BASE,
        proxy_attr="open_router_proxy",
        capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
    ),
    "deepseek": ProviderDescriptor(
        provider_id="deepseek",
        transport_type="openai_chat",
        credential_env="DEEPSEEK_API_KEY",
        credential_url="https://platform.deepseek.com/api_keys",
        credential_attr="deepseek_api_key",
        default_base_url=DEEPSEEK_DEFAULT_BASE,
        capabilities=("chat", "streaming", "thinking"),
    ),
    "lmstudio": ProviderDescriptor(
        provider_id="lmstudio",
        transport_type="anthropic_messages",
        static_credential="lm-studio",
        default_base_url=LMSTUDIO_DEFAULT_BASE,
        base_url_attr="lm_studio_base_url",
        proxy_attr="lmstudio_proxy",
        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
    ),
    "llamacpp": ProviderDescriptor(
        provider_id="llamacpp",
        transport_type="anthropic_messages",
        static_credential="llamacpp",
        default_base_url=LLAMACPP_DEFAULT_BASE,
        base_url_attr="llamacpp_base_url",
        proxy_attr="llamacpp_proxy",
        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
    ),
    "ollama": ProviderDescriptor(
        provider_id="ollama",
        transport_type="anthropic_messages",
        static_credential="ollama",
        default_base_url=OLLAMA_DEFAULT_BASE,
        base_url_attr="ollama_base_url",
        capabilities=(
            "chat",
            "streaming",
            "tools",
            "thinking",
            "native_anthropic",
            "local",
        ),
    ),
}

# Order matches docs / historical error text; must match PROVIDER_CATALOG keys.
SUPPORTED_PROVIDER_IDS: tuple[str, ...] = tuple(PROVIDER_CATALOG.keys())

if len(set(SUPPORTED_PROVIDER_IDS)) != len(SUPPORTED_PROVIDER_IDS):
    raise AssertionError("Duplicate provider ids in PROVIDER_CATALOG key order")

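The catalog is plain data, so contract checks stay trivial:

from config.provider_catalog import PROVIDER_CATALOG, SUPPORTED_PROVIDER_IDS

descriptor = PROVIDER_CATALOG["open_router"]
assert descriptor.transport_type == "anthropic_messages"
assert "native_anthropic" in descriptor.capabilities

# Every supported id maps back to a descriptor with a matching provider_id.
for provider_id in SUPPORTED_PROVIDER_IDS:
    assert PROVIDER_CATALOG[provider_id].provider_id == provider_id
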
@ -1,18 +1,7 @@
-"""Canonical list of model provider type prefixes (provider_id values).
-
-`providers.registry.PROVIDER_DESCRIPTORS` is the full metadata source; this
-module holds the id set for config validation and must stay in sync
-(registries assert in `providers.registry`).
-"""
+"""Canonical provider id tuple (re-exported from the provider catalog)."""

 from __future__ import annotations

-# Order matches docs / historical error text; must match PROVIDER_DESCRIPTORS keys.
-SUPPORTED_PROVIDER_IDS: tuple[str, ...] = (
-    "nvidia_nim",
-    "open_router",
-    "deepseek",
-    "lmstudio",
-    "llamacpp",
-    "ollama",
-)
+from .provider_catalog import SUPPORTED_PROVIDER_IDS
+
+__all__ = ("SUPPORTED_PROVIDER_IDS",)

@ -10,6 +10,7 @@ from dotenv import dotenv_values
 from pydantic import Field, field_validator, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict

+from .constants import HTTP_CONNECT_TIMEOUT_DEFAULT
 from .nim import NimSettings
 from .provider_ids import SUPPORTED_PROVIDER_IDS

@ -105,6 +106,12 @@ class Settings(BaseSettings):
     messaging_platform: str = Field(
         default="discord", validation_alias="MESSAGING_PLATFORM"
     )
+    messaging_rate_limit: int = Field(
+        default=1, validation_alias="MESSAGING_RATE_LIMIT"
+    )
+    messaging_rate_window: float = Field(
+        default=1.0, validation_alias="MESSAGING_RATE_WINDOW"
+    )

     # ==================== NVIDIA NIM Config ====================
     nvidia_nim_api_key: str = ""

@ -173,7 +180,8 @@ class Settings(BaseSettings):
         default=10.0, validation_alias="HTTP_WRITE_TIMEOUT"
     )
     http_connect_timeout: float = Field(
-        default=2.0, validation_alias="HTTP_CONNECT_TIMEOUT"
+        default=HTTP_CONNECT_TIMEOUT_DEFAULT,
+        validation_alias="HTTP_CONNECT_TIMEOUT",
     )

     # ==================== Fast Prefix Detection ====================

@ -185,6 +193,51 @@ class Settings(BaseSettings):
     enable_suggestion_mode_skip: bool = True
     enable_filepath_extraction_mock: bool = True

+    # ==================== Local web server tools (web_search / web_fetch) ====================
+    # Off by default: these tools perform outbound HTTP from the proxy (SSRF risk).
+    enable_web_server_tools: bool = Field(
+        default=False, validation_alias="ENABLE_WEB_SERVER_TOOLS"
+    )
+    # Comma-separated URL schemes allowed for web_fetch (default: http,https).
+    web_fetch_allowed_schemes: str = Field(
+        default="http,https", validation_alias="WEB_FETCH_ALLOWED_SCHEMES"
+    )
+    # When true, skip private/loopback/link-local IP blocking for web_fetch (lab only).
+    web_fetch_allow_private_networks: bool = Field(
+        default=False, validation_alias="WEB_FETCH_ALLOW_PRIVATE_NETWORKS"
+    )
+
+    # ==================== Debug / diagnostic logging (avoid sensitive content) ====================
+    # When false (default), API and SSE helpers log only metadata (counts, lengths, ids).
+    log_raw_api_payloads: bool = Field(
+        default=False, validation_alias="LOG_RAW_API_PAYLOADS"
+    )
+    log_raw_sse_events: bool = Field(
+        default=False, validation_alias="LOG_RAW_SSE_EVENTS"
+    )
+    # When false (default), unhandled exceptions log only type + route metadata (no message/traceback).
+    log_api_error_tracebacks: bool = Field(
+        default=False, validation_alias="LOG_API_ERROR_TRACEBACKS"
+    )
+    # When false (default), messaging logs omit text/transcription previews (metadata only).
+    log_raw_messaging_content: bool = Field(
+        default=False, validation_alias="LOG_RAW_MESSAGING_CONTENT"
+    )
+    # When true, log full Claude CLI stderr, non-JSON lines, and parser error text.
+    log_raw_cli_diagnostics: bool = Field(
+        default=False, validation_alias="LOG_RAW_CLI_DIAGNOSTICS"
+    )
+    # When true, log exception text / CLI error strings in messaging (may leak user content).
+    log_messaging_error_details: bool = Field(
+        default=False, validation_alias="LOG_MESSAGING_ERROR_DETAILS"
+    )
+    debug_platform_edits: bool = Field(
+        default=False, validation_alias="DEBUG_PLATFORM_EDITS"
+    )
+    debug_subagent_stack: bool = Field(
+        default=False, validation_alias="DEBUG_SUBAGENT_STACK"
+    )
+
     # ==================== NIM Settings ====================
     nim: NimSettings = Field(default_factory=NimSettings)

@ -215,6 +268,9 @@ class Settings(BaseSettings):
     claude_workspace: str = "./agent_workspace"
     allowed_dir: str = ""
     claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN")
+    max_message_log_entries_per_chat: int | None = Field(
+        default=None, validation_alias="MAX_MESSAGE_LOG_ENTRIES_PER_CHAT"
+    )

     # ==================== Server ====================
     host: str = "0.0.0.0"

@ -254,6 +310,13 @@ class Settings(BaseSettings):
             return None
         return v

+    @field_validator("max_message_log_entries_per_chat", mode="before")
+    @classmethod
+    def parse_optional_log_cap(cls, v: Any) -> Any:
+        if v == "" or v is None:
+            return None
+        return v
+
     @field_validator("whisper_device")
     @classmethod
     def validate_whisper_device(cls, v: str) -> str:

@ -272,6 +335,33 @@ class Settings(BaseSettings):
         )
         return v

+    @field_validator("messaging_rate_limit")
+    @classmethod
+    def validate_messaging_rate_limit(cls, v: int) -> int:
+        if v <= 0:
+            raise ValueError("messaging_rate_limit must be > 0")
+        return v
+
+    @field_validator("messaging_rate_window")
+    @classmethod
+    def validate_messaging_rate_window(cls, v: float) -> float:
+        if v <= 0:
+            raise ValueError("messaging_rate_window must be > 0")
+        return float(v)
+
+    @field_validator("web_fetch_allowed_schemes")
+    @classmethod
+    def validate_web_fetch_allowed_schemes(cls, v: str) -> str:
+        schemes = [part.strip().lower() for part in v.split(",") if part.strip()]
+        if not schemes:
+            raise ValueError("web_fetch_allowed_schemes must list at least one scheme")
+        for scheme in schemes:
+            if not scheme.isascii() or not scheme.isalpha():
+                raise ValueError(
+                    f"Invalid URL scheme in web_fetch_allowed_schemes: {scheme!r}"
+                )
+        return ",".join(schemes)
+
     @field_validator("ollama_base_url")
     @classmethod
     def validate_ollama_base_url(cls, v: str) -> str:

@ -329,12 +419,12 @@ class Settings(BaseSettings):
     @property
     def provider_type(self) -> str:
         """Extract provider type from the default model string."""
-        return self.model.split("/", 1)[0]
+        return Settings.parse_provider_type(self.model)

     @property
     def model_name(self) -> str:
         """Extract the actual model name from the default model string."""
-        return self.model.split("/", 1)[1]
+        return Settings.parse_model_name(self.model)

     def resolve_model(self, claude_model_name: str) -> str:
         """Resolve a Claude model name to the configured provider/model string.

@ -362,6 +452,14 @@ class Settings(BaseSettings):
             return self.enable_sonnet_thinking
         return self.enable_model_thinking

+    def web_fetch_allowed_scheme_set(self) -> frozenset[str]:
+        """Return normalized schemes allowed for web_fetch."""
+        return frozenset(
+            part.strip().lower()
+            for part in self.web_fetch_allowed_schemes.split(",")
+            if part.strip()
+        )
+
     @staticmethod
     def parse_provider_type(model_string: str) -> str:
         """Extract provider type from any 'provider/model' string."""

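Scheme normalization round-trips through the validator and the helper; a sketch, assuming the remaining Settings fields resolve from defaults or the environment:

from config.settings import Settings  # import path assumed

settings = Settings(web_fetch_allowed_schemes="HTTPS, http")
assert settings.web_fetch_allowed_schemes == "https,http"
assert settings.web_fetch_allowed_scheme_set() == frozenset({"http", "https"})
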
@ -1,9 +1,19 @@
 """Anthropic protocol helpers shared across API, providers, and integrations."""

 from .content import extract_text_from_content, get_block_attr, get_block_type
-from .conversion import AnthropicToOpenAIConverter, build_base_request_body
-from .errors import append_request_id, get_user_facing_error_message
-from .sse import ContentBlockManager, SSEBuilder, map_stop_reason
+from .conversion import (
+    AnthropicToOpenAIConverter,
+    OpenAIConversionError,
+    build_base_request_body,
+)
+from .errors import (
+    append_request_id,
+    format_user_error_preview,
+    get_user_facing_error_message,
+)
+from .native_messages_request import sanitize_native_messages_thinking_policy
+from .provider_stream_error import iter_provider_stream_error_sse_events
+from .sse import ContentBlockManager, SSEBuilder, format_sse_event, map_stop_reason
 from .thinking import ContentChunk, ContentType, ThinkTagParser
 from .tokens import get_token_count
 from .tools import HeuristicToolParser

@ -15,15 +25,20 @@ __all__ = [
     "ContentChunk",
     "ContentType",
     "HeuristicToolParser",
+    "OpenAIConversionError",
     "SSEBuilder",
     "ThinkTagParser",
     "append_request_id",
     "build_base_request_body",
     "extract_text_from_content",
+    "format_sse_event",
+    "format_user_error_preview",
     "get_block_attr",
     "get_block_type",
     "get_token_count",
     "get_user_facing_error_message",
+    "iter_provider_stream_error_sse_events",
     "map_stop_reason",
+    "sanitize_native_messages_thinking_policy",
     "set_if_not_none",
 ]

@ -3,10 +3,34 @@
 import json
 from typing import Any

+from pydantic import BaseModel
+
 from .content import get_block_attr, get_block_type
 from .utils import set_if_not_none


+class OpenAIConversionError(Exception):
+    """Raised when Anthropic content cannot be converted to OpenAI chat without data loss."""
+
+
+def _openai_reject_native_only_top_level_fields(request_data: Any) -> None:
+    """OpenAI chat providers may only convert known top-level request fields.
+
+    First-class model fields (e.g. ``context_management``) are not forwarded to
+    the OpenAI API but are allowed so clients do not hit spurious 400s.
+    Unknown extra keys (``__pydantic_extra__``) are still rejected.
+    """
+    if not isinstance(request_data, BaseModel):
+        return
+    extra = getattr(request_data, "__pydantic_extra__", None)
+    if not extra:
+        return
+    raise OpenAIConversionError(
+        "OpenAI chat conversion does not support these top-level request fields: "
+        f"{sorted(str(k) for k in extra)}. Use a native Anthropic transport provider."
+    )
+
+
 def _tool_name(tool: Any) -> str:
     return str(getattr(tool, "name", "") or "")

@ -18,6 +42,27 @@ def _tool_input_schema(tool: Any) -> dict[str, Any]:
     return {"type": "object", "properties": {}}


+def _serialize_tool_result_content(tool_content: Any) -> str:
+    """Serialize tool_result content for OpenAI ``role: tool`` messages (stable JSON for structured values)."""
+    if tool_content is None:
+        return ""
+    if isinstance(tool_content, str):
+        return tool_content
+    if isinstance(tool_content, dict):
+        return json.dumps(tool_content, ensure_ascii=False)
+    if isinstance(tool_content, list):
+        parts: list[str] = []
+        for item in tool_content:
+            if isinstance(item, dict) and item.get("type") == "text":
+                parts.append(str(item.get("text", "")))
+            elif isinstance(item, dict):
+                parts.append(json.dumps(item, ensure_ascii=False))
+            else:
+                parts.append(str(item))
+        return "\n".join(parts)
+    return str(tool_content)
+
+
 class AnthropicToOpenAIConverter:
     """Convert Anthropic message format to OpenAI-compatible format."""

@ -64,20 +109,38 @@ class AnthropicToOpenAIConverter:
         content_parts: list[str] = []
         thinking_parts: list[str] = []
         tool_calls: list[dict[str, Any]] = []
+        seen_tool_use = False

         for block in content:
             block_type = get_block_type(block)

             if block_type == "text":
+                if seen_tool_use:
+                    raise OpenAIConversionError(
+                        "OpenAI chat conversion does not support assistant text after "
+                        "tool_use in the same message; split the transcript or use a "
+                        "native Anthropic provider."
+                    )
                 content_parts.append(get_block_attr(block, "text", ""))
             elif block_type == "thinking":
+                if not include_thinking:
+                    continue
+                if seen_tool_use:
+                    raise OpenAIConversionError(
+                        "OpenAI chat conversion does not support assistant thinking after "
+                        "tool_use in the same message; split the transcript or use a "
+                        "native Anthropic provider."
+                    )
                 thinking = get_block_attr(block, "thinking", "")
                 content_parts.append(f"<think>\n{thinking}\n</think>")
                 if include_reasoning_content:
                     thinking_parts.append(thinking)
+            elif block_type == "redacted_thinking":
+                # Opaque provider continuation data; do not materialize as model-visible text
+                # or reasoning_content for OpenAI chat upstreams.
+                continue
             elif block_type == "tool_use":
+                seen_tool_use = True
                 tool_input = get_block_attr(block, "input", {})
                 tool_calls.append(
                     {

@ -91,6 +154,19 @@ class AnthropicToOpenAIConverter:
                         },
                     }
                 )
+            elif block_type == "image":
+                raise OpenAIConversionError(
+                    "Assistant image blocks are not supported for OpenAI chat conversion."
+                )
+            elif block_type in (
+                "server_tool_use",
+                "web_search_tool_result",
+                "web_fetch_tool_result",
+            ):
+                raise OpenAIConversionError(
+                    "OpenAI chat conversion does not support Anthropic server tool blocks "
+                    f"({block_type!r} in an assistant message). Use a native Anthropic transport provider."
+                )

         content_str = "\n\n".join(content_parts)
         if not content_str and not tool_calls:

@ -122,21 +198,21 @@ class AnthropicToOpenAIConverter:

             if block_type == "text":
                 text_parts.append(get_block_attr(block, "text", ""))
+            elif block_type == "image":
+                raise OpenAIConversionError(
+                    "User message image blocks are not supported for OpenAI chat "
+                    "conversion; use a vision-capable native Anthropic provider or "
+                    "extend the converter."
+                )
             elif block_type == "tool_result":
                 flush_text()
                 tool_content = get_block_attr(block, "content", "")
-                if isinstance(tool_content, list):
-                    tool_content = "\n".join(
-                        item.get("text", str(item))
-                        if isinstance(item, dict)
-                        else str(item)
-                        for item in tool_content
-                    )
+                serialized = _serialize_tool_result_content(tool_content)
                 result.append(
                     {
                         "role": "tool",
                         "tool_call_id": get_block_attr(block, "tool_use_id"),
-                        "content": str(tool_content) if tool_content else "",
+                        "content": serialized if serialized else "",
                     }
                 )

@ -199,6 +275,7 @@ def build_base_request_body(
     include_reasoning_content: bool = False,
 ) -> dict[str, Any]:
     """Build the common parts of an OpenAI-format request body."""
+    _openai_reject_native_only_top_level_fields(request_data)
     messages = AnthropicToOpenAIConverter.convert_messages(
         request_data.messages,
         include_thinking=include_thinking,

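The serializer keeps strings as-is and emits stable JSON for structured values, so tool output survives the role-tool conversion without lossy str() coercion:

from core.anthropic.conversion import _serialize_tool_result_content

assert _serialize_tool_result_content(None) == ""
assert _serialize_tool_result_content("done") == "done"
assert (
    _serialize_tool_result_content({"status": "ok", "rows": 3})
    == '{"status": "ok", "rows": 3}'
)
assert _serialize_tool_result_content(
    [{"type": "text", "text": "a"}, {"type": "json", "value": 1}, 7]
) == 'a\n{"type": "json", "value": 1}\n7'
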
97 core/anthropic/emitted_sse_tracker.py Normal file

@ -0,0 +1,97 @@
"""Track content-block state for native Anthropic SSE strings we emit to clients."""

from __future__ import annotations

import uuid
from collections.abc import Iterator
from contextlib import suppress
from typing import Any

from core.anthropic.sse import SSEBuilder, format_sse_event
from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines


class EmittedNativeSseTracker:
    """Parse emitted SSE frames so mid-stream errors can close blocks and pick a fresh index."""

    def __init__(self) -> None:
        self._buf = ""
        self._open_stack: list[int] = []
        self._max_index = -1
        self.message_id: str | None = None
        self.model: str = ""

    def feed(self, chunk: str) -> None:
        """Record SSE frames completed by ``chunk`` (handles splitting across reads)."""
        self._buf += chunk
        while True:
            sep = self._buf.find("\n\n")
            if sep < 0:
                break
            frame = self._buf[:sep]
            self._buf = self._buf[sep + 2 :]
            if not frame.strip():
                continue
            for event in parse_sse_lines(frame.splitlines()):
                self._observe(event)

    def _observe(self, event: SSEEvent) -> None:
        if event.event == "message_start":
            message = event.data.get("message")
            if isinstance(message, dict):
                mid = message.get("id")
                if isinstance(mid, str) and mid:
                    self.message_id = mid
                model = message.get("model")
                if isinstance(model, str) and model:
                    self.model = model
            return

        if event.event == "content_block_start":
            idx = event_index(event)
            self._max_index = max(self._max_index, idx)
            self._open_stack.append(idx)
            return

        if event.event == "content_block_stop":
            idx = event_index(event)
            if self._open_stack and self._open_stack[-1] == idx:
                self._open_stack.pop()
            else:
                with suppress(ValueError):
                    self._open_stack.remove(idx)

    def next_content_index(self) -> int:
        """Next unused content block index based on emitted starts."""
        return self._max_index + 1

    def iter_close_unclosed_blocks(self) -> Iterator[str]:
        """Yield ``content_block_stop`` events for blocks that were started but not stopped."""
        while self._open_stack:
            idx = self._open_stack.pop()
            yield format_sse_event(
                "content_block_stop",
                {"type": "content_block_stop", "index": idx},
            )

    def iter_midstream_error_tail(
        self,
        error_message: str,
        *,
        request: Any,
        input_tokens: int,
        log_raw_sse_events: bool,
    ) -> Iterator[str]:
        """Close dangling blocks, emit a text error block at a fresh index, then message tail."""
        mid = self.message_id or f"msg_{uuid.uuid4()}"
        model = self.model or (getattr(request, "model", "") or "")
        sse = SSEBuilder(
            mid,
            model,
            input_tokens,
            log_raw_events=log_raw_sse_events,
        )
        sse.blocks.next_index = self.next_content_index()
        yield from sse.emit_error(error_message)
        yield sse.message_delta("end_turn", 1)
        yield sse.message_stop()

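A feed/close sketch, assuming `parse_sse_lines` accepts standard `event:`/`data:` lines as exported by core.anthropic.stream_contracts:

from core.anthropic.emitted_sse_tracker import EmittedNativeSseTracker

tracker = EmittedNativeSseTracker()
tracker.feed(
    'event: message_start\n'
    'data: {"type": "message_start", "message": {"id": "msg_1", "model": "m"}}\n\n'
    'event: content_block_start\n'
    'data: {"type": "content_block_start", "index": 0}\n\n'
)
assert tracker.message_id == "msg_1"
assert tracker.next_content_index() == 1
# Block 0 was started but never stopped; an error tail can close it cleanly.
closers = list(tracker.iter_close_unclosed_blocks())
assert '"index": 0' in closers[0]
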
@@ -9,11 +9,12 @@ def get_user_facing_error_message(
    *,
    read_timeout_s: float | None = None,
) -> str:
    """Return a readable, non-empty error message for users."""
    message = str(e).strip()
    if message:
        return message
    """Return a readable, non-empty error message for users.

    Known transport and OpenAI SDK exception types are mapped to stable wording
    before falling back to ``str(e)``, so empty or noisy SDK messages do not skip
    the mapped path.
    """
    if isinstance(e, httpx.ReadTimeout):
        if read_timeout_s is not None:
            return f"Provider request timed out after {read_timeout_s:g}s."

@@ -25,13 +26,20 @@ def get_user_facing_error_message(
            return f"Provider request timed out after {read_timeout_s:g}s."
        return "Request timed out."

    if isinstance(e, openai.RateLimitError):
        return "Provider rate limit reached. Please retry shortly."
    if isinstance(e, openai.AuthenticationError):
        return "Provider authentication failed. Check API key."
    if isinstance(e, openai.BadRequestError):
        return "Invalid request sent to provider."

    name = type(e).__name__
    status_code = getattr(e, "status_code", None)
    if isinstance(e, openai.RateLimitError) or name == "RateLimitError":
        return "Provider rate limit reached. Please retry shortly."
    if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError":
        return "Provider authentication failed. Check API key."
    if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError":
        return "Invalid request sent to provider."
    if name == "OverloadedError":
        return "Provider is currently overloaded. Please retry."

@@ -42,9 +50,18 @@ def get_user_facing_error_message(
    if name.endswith("ProviderError") or name == "ProviderError":
        return "Provider request failed."

    message = str(e).strip()
    if message:
        return message

    return "Provider request failed unexpectedly."


def format_user_error_preview(exc: Exception, *, max_len: int = 200) -> str:
    """Truncate a user-facing error string for short chat replies."""
    return get_user_facing_error_message(exc)[:max_len]


def append_request_id(message: str, request_id: str | None) -> str:
    """Append request_id suffix when available."""
    base = message.strip() or "Provider request failed unexpectedly."
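A quick sketch of the name-based fallback above: an SDK-shaped exception whose class merely looks like `RateLimitError` still gets the stable wording, and an empty `str(e)` no longer bypasses the mapped path. The import is the package re-export shown in the handler diff further down; the stand-in exception class is an assumption for illustration:

    from core.anthropic import get_user_facing_error_message


    class RateLimitError(Exception):  # stand-in for a provider SDK type
        pass


    # Empty message, no openai base class: matched purely by class name.
    assert get_user_facing_error_message(RateLimitError("")) == (
        "Provider rate limit reached. Please retry shortly."
    )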
core/anthropic/native_messages_request.py (new file, 260 lines)

@@ -0,0 +1,260 @@
"""Native Anthropic Messages request body construction (JSON-ready dicts).
|
||||
|
||||
Provider adapters supply policy via parameters (defaults, OpenRouter post-steps).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
_REQUEST_FIELDS = (
|
||||
"model",
|
||||
"messages",
|
||||
"system",
|
||||
"max_tokens",
|
||||
"stop_sequences",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"metadata",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"thinking",
|
||||
"context_management",
|
||||
"output_config",
|
||||
"mcp_servers",
|
||||
"extra_body",
|
||||
)
|
||||
|
||||
# Keys that would override routed canonical request fields if merged from ``extra_body``.
|
||||
_OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS = frozenset(
|
||||
{
|
||||
"model",
|
||||
"messages",
|
||||
"system",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"metadata",
|
||||
"stop_sequences",
|
||||
"context_management",
|
||||
"output_config",
|
||||
"mcp_servers",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OpenRouterExtraBodyError(ValueError):
|
||||
"""``extra_body`` contained reserved keys that would override canonical fields."""
|
||||
|
||||
|
||||
def validate_openrouter_extra_body(extra: Any) -> None:
|
||||
"""Reject ``extra_body`` keys that must not override routed request fields."""
|
||||
if not isinstance(extra, dict) or not extra:
|
||||
return
|
||||
bad = _OPENROUTER_EXTRA_BODY_FORBIDDEN_KEYS & extra.keys()
|
||||
if bad:
|
||||
raise OpenRouterExtraBodyError(
|
||||
f"extra_body must not override canonical request fields: {sorted(bad)}"
|
||||
)
|
||||
|
||||
|
||||
_INTERNAL_FIELDS = {
|
||||
"thinking",
|
||||
"extra_body",
|
||||
}
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> Any:
|
||||
"""Convert Pydantic models and lightweight objects into JSON-ready values."""
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(exclude_none=True)
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
key: _serialize_value(item)
|
||||
for key, item in value.items()
|
||||
if item is not None
|
||||
}
|
||||
if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
|
||||
return [_serialize_value(item) for item in value]
|
||||
if value is None or isinstance(value, str | int | float | bool):
|
||||
return value
|
||||
if hasattr(value, "__dict__"):
|
||||
return {
|
||||
key: _serialize_value(item)
|
||||
for key, item in vars(value).items()
|
||||
if not key.startswith("_") and item is not None
|
||||
}
|
||||
return value
|
||||
|
||||
|
||||
def _dump_request_fields(request_data: Any) -> dict[str, Any]:
|
||||
"""Extract the public request fields (OpenRouter-style explicit field list)."""
|
||||
if isinstance(request_data, BaseModel):
|
||||
return request_data.model_dump(exclude_none=True)
|
||||
|
||||
dumped: dict[str, Any] = {}
|
||||
for field in _REQUEST_FIELDS:
|
||||
value = getattr(request_data, field, None)
|
||||
if value is not None:
|
||||
dumped[field] = _serialize_value(value)
|
||||
return dumped
|
||||
|
||||
|
||||
def sanitize_native_messages_thinking_policy(
|
||||
messages: Any, *, thinking_enabled: bool
|
||||
) -> Any:
|
||||
"""Filter assistant message thinking blocks for upstream native Anthropic JSON.
|
||||
|
||||
When ``thinking_enabled`` is false, remove ``thinking`` and ``redacted_thinking``
|
||||
history so disabled policy is not undermined by prior turns.
|
||||
|
||||
When true, keep ``redacted_thinking`` and signed ``thinking``; remove only
|
||||
unsigned plain ``thinking`` blocks (not replayable).
|
||||
"""
|
||||
if not isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
sanitized_messages: list[Any] = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
sanitized_messages.append(message)
|
||||
continue
|
||||
|
||||
if message.get("role") != "assistant":
|
||||
sanitized_messages.append(message)
|
||||
continue
|
||||
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
sanitized_messages.append(message)
|
||||
continue
|
||||
|
||||
if not thinking_enabled:
|
||||
sanitized_content = [
|
||||
block
|
||||
for block in content
|
||||
if not (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") in ("thinking", "redacted_thinking")
|
||||
)
|
||||
]
|
||||
else:
|
||||
sanitized_content = [
|
||||
block
|
||||
for block in content
|
||||
if not (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "thinking"
|
||||
and not isinstance(block.get("signature"), str)
|
||||
)
|
||||
]
|
||||
|
||||
sanitized_message = dict(message)
|
||||
sanitized_message["content"] = sanitized_content or ""
|
||||
sanitized_messages.append(sanitized_message)
|
||||
|
||||
return sanitized_messages
|
||||
|
||||
|
||||
def _normalize_system_prompt_for_openrouter(system: Any) -> Any:
|
||||
"""Flatten Claude SDK system blocks for OpenRouter's native endpoint."""
|
||||
if not isinstance(system, list):
|
||||
return system
|
||||
|
||||
text_parts: list[str] = []
|
||||
for block in system:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text_parts.append(block["text"])
|
||||
return "\n\n".join(text_parts).strip() if text_parts else system
|
||||
|
||||
|
||||
def _apply_openrouter_reasoning_policy(body: dict[str, Any], thinking_cfg: Any) -> None:
|
||||
"""Map Anthropic thinking controls onto OpenRouter reasoning controls."""
|
||||
reasoning = body.setdefault("reasoning", {"enabled": True})
|
||||
if not isinstance(reasoning, dict):
|
||||
return
|
||||
reasoning.setdefault("enabled", True)
|
||||
if not isinstance(thinking_cfg, dict):
|
||||
return
|
||||
budget_tokens = thinking_cfg.get("budget_tokens")
|
||||
if isinstance(budget_tokens, int):
|
||||
reasoning.setdefault("max_tokens", budget_tokens)
|
||||
|
||||
|
||||
def build_base_native_anthropic_request_body(
|
||||
request: Any,
|
||||
*,
|
||||
default_max_tokens: int,
|
||||
thinking_enabled: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Serialize a Pydantic messages request to a generic native Anthropic body."""
|
||||
body = request.model_dump(exclude_none=True)
|
||||
|
||||
body.pop("extra_body", None)
|
||||
|
||||
if "thinking" in body:
|
||||
thinking_cfg = body.pop("thinking")
|
||||
if thinking_enabled and isinstance(thinking_cfg, dict):
|
||||
thinking_payload: dict[str, Any] = {"type": "enabled"}
|
||||
budget_tokens = thinking_cfg.get("budget_tokens")
|
||||
if isinstance(budget_tokens, int):
|
||||
thinking_payload["budget_tokens"] = budget_tokens
|
||||
body["thinking"] = thinking_payload
|
||||
|
||||
if "max_tokens" not in body:
|
||||
body["max_tokens"] = default_max_tokens
|
||||
|
||||
if "messages" in body:
|
||||
body["messages"] = sanitize_native_messages_thinking_policy(
|
||||
body["messages"],
|
||||
thinking_enabled=thinking_enabled,
|
||||
)
|
||||
|
||||
return body
|
||||
|
||||
|
||||
def build_openrouter_native_request_body(
|
||||
request_data: Any,
|
||||
*,
|
||||
thinking_enabled: bool,
|
||||
default_max_tokens: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Build an Anthropic-format request body for OpenRouter (policy hooks built-in)."""
|
||||
dumped_request = _dump_request_fields(request_data)
|
||||
request_extra = dumped_request.pop("extra_body", None)
|
||||
thinking_cfg = dumped_request.get("thinking")
|
||||
body: dict[str, Any] = {
|
||||
key: value
|
||||
for key, value in dumped_request.items()
|
||||
if key not in _INTERNAL_FIELDS
|
||||
}
|
||||
|
||||
if isinstance(request_extra, dict):
|
||||
validate_openrouter_extra_body(request_extra)
|
||||
body.update(request_extra)
|
||||
|
||||
body["messages"] = sanitize_native_messages_thinking_policy(
|
||||
body.get("messages"),
|
||||
thinking_enabled=thinking_enabled,
|
||||
)
|
||||
if "system" in body:
|
||||
body["system"] = _normalize_system_prompt_for_openrouter(body["system"])
|
||||
body["stream"] = True
|
||||
if body.get("max_tokens") is None:
|
||||
body["max_tokens"] = default_max_tokens
|
||||
|
||||
if thinking_enabled:
|
||||
_apply_openrouter_reasoning_policy(body, thinking_cfg)
|
||||
|
||||
return body
|
||||
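A minimal sketch of the thinking-history policy above, run against the module added in this file (only the sample history is invented):

    from core.anthropic.native_messages_request import (
        sanitize_native_messages_thinking_policy,
    )

    history = [
        {"role": "user", "content": [{"type": "text", "text": "hi"}]},
        {
            "role": "assistant",
            "content": [
                {"type": "thinking", "thinking": "unsigned"},  # not replayable
                {"type": "thinking", "thinking": "signed", "signature": "sig"},
                {"type": "redacted_thinking", "data": "..."},
                {"type": "text", "text": "hello"},
            ],
        },
    ]

    # Enabled: only the unsigned plain thinking block is removed.
    enabled = sanitize_native_messages_thinking_policy(history, thinking_enabled=True)
    assert [b["type"] for b in enabled[1]["content"]] == [
        "thinking", "redacted_thinking", "text",
    ]

    # Disabled: all thinking history is removed from prior turns.
    disabled = sanitize_native_messages_thinking_policy(history, thinking_enabled=False)
    assert [b["type"] for b in disabled[1]["content"]] == ["text"]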
core/anthropic/native_sse_block_policy.py (new file, 313 lines)

@@ -0,0 +1,313 @@
"""Shared native Anthropic SSE thinking policy, block remapping, and overlap repair.
|
||||
|
||||
Used by :class:`OpenRouterProvider` and line-mode
|
||||
:class:`providers.anthropic_messages.AnthropicMessagesTransport` providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"NativeSseBlockPolicyState",
|
||||
"format_native_sse_event",
|
||||
"is_terminal_openrouter_done_event",
|
||||
"parse_native_sse_event",
|
||||
"transform_native_sse_block_event",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _UpstreamBlockState:
|
||||
"""Per-upstream content block: segment index and liveness in the model stream."""
|
||||
|
||||
block_type: str
|
||||
down_index: int
|
||||
open: bool
|
||||
last_start_block: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NativeSseBlockPolicyState:
|
||||
"""Track per-upstream content blocks and remapped Anthropic ``index`` field."""
|
||||
|
||||
next_index: int = 0
|
||||
by_upstream: dict[int, _UpstreamBlockState] = field(default_factory=dict)
|
||||
dropped_indexes: set[int] = field(default_factory=set)
|
||||
pending_suppressed_stops: set[int] = field(default_factory=set)
|
||||
message_stopped: bool = False
|
||||
|
||||
|
||||
def format_native_sse_event(event_name: str | None, data_text: str) -> str:
|
||||
"""Format an SSE event from its event name and data payload."""
|
||||
lines: list[str] = []
|
||||
if event_name:
|
||||
lines.append(f"event: {event_name}")
|
||||
lines.extend(f"data: {line}" for line in data_text.splitlines())
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
|
||||
def parse_native_sse_event(event: str) -> tuple[str | None, str]:
|
||||
"""Extract the event name and raw data payload from an SSE event."""
|
||||
event_name = None
|
||||
data_lines: list[str] = []
|
||||
for line in event.strip().splitlines():
|
||||
if line.startswith("event:"):
|
||||
event_name = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_lines.append(line[5:].lstrip())
|
||||
return event_name, "\n".join(data_lines)
|
||||
|
||||
|
||||
def is_terminal_openrouter_done_event(event_name: str | None, data_text: str) -> bool:
|
||||
"""Return whether an event is OpenAI-style terminal noise (``[DONE]``)."""
|
||||
return (event_name is None or event_name in {"data", "done"}) and (
|
||||
data_text.strip().upper() == "[DONE]"
|
||||
)
|
||||
|
||||
|
||||
def _delta_type_to_block_kind(delta_type: Any) -> str | None:
|
||||
"""Map a content_block_delta type to a content block kind (text/thinking/tool_use)."""
|
||||
if not isinstance(delta_type, str):
|
||||
return None
|
||||
if delta_type in {"thinking_delta", "signature_delta"}:
|
||||
return "thinking"
|
||||
if delta_type == "text_delta":
|
||||
return "text"
|
||||
if delta_type == "input_json_delta":
|
||||
return "tool_use"
|
||||
return None
|
||||
|
||||
|
||||
def _synthetic_start_content_block(
|
||||
block_kind: str,
|
||||
*,
|
||||
upstream_index: int,
|
||||
stored_tool_block: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a `content_block` for a `content_block_start` with empty streaming fields."""
|
||||
if block_kind == "tool_use":
|
||||
if (
|
||||
isinstance(stored_tool_block, dict)
|
||||
and stored_tool_block.get("type") == "tool_use"
|
||||
):
|
||||
tool_id = stored_tool_block.get("id")
|
||||
name = stored_tool_block.get("name")
|
||||
inp = stored_tool_block.get("input")
|
||||
return {
|
||||
"type": "tool_use",
|
||||
"id": tool_id
|
||||
if isinstance(tool_id, str) and tool_id
|
||||
else f"toolu_or_{upstream_index}",
|
||||
"name": name if isinstance(name, str) else "",
|
||||
"input": inp if isinstance(inp, dict) else {},
|
||||
}
|
||||
return {
|
||||
"type": "tool_use",
|
||||
"id": f"toolu_or_{upstream_index}",
|
||||
"name": "",
|
||||
"input": {},
|
||||
}
|
||||
if block_kind == "thinking":
|
||||
return {"type": "thinking", "thinking": ""}
|
||||
if block_kind == "text":
|
||||
return {"type": "text", "text": ""}
|
||||
return {"type": "text", "text": ""}
|
||||
|
||||
|
||||
def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
|
||||
if not isinstance(block_type, str):
|
||||
return False
|
||||
if block_type.startswith("redacted_thinking"):
|
||||
return not thinking_enabled
|
||||
return not thinking_enabled and "thinking" in block_type
|
||||
|
||||
|
||||
def _synthetic_close_other_open_blocks(
|
||||
state: NativeSseBlockPolicyState, current_upstream: int
|
||||
) -> str:
|
||||
"""Close every open block except `current_upstream` and track duplicate upstream stops."""
|
||||
out: list[str] = []
|
||||
for upstream, seg in list(state.by_upstream.items()):
|
||||
if upstream == current_upstream or not seg.open:
|
||||
continue
|
||||
out.append(
|
||||
format_native_sse_event(
|
||||
"content_block_stop",
|
||||
json.dumps(
|
||||
{
|
||||
"type": "content_block_stop",
|
||||
"index": seg.down_index,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
seg.open = False
|
||||
state.pending_suppressed_stops.add(upstream)
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _allocate_new_segment(
|
||||
state: NativeSseBlockPolicyState,
|
||||
upstream_index: int,
|
||||
block_type: str,
|
||||
*,
|
||||
last_start_block: dict[str, Any] | None = None,
|
||||
) -> int:
|
||||
"""Assign a new downstream `index` for a segment and record upstream state."""
|
||||
new_idx = state.next_index
|
||||
state.next_index += 1
|
||||
state.by_upstream[upstream_index] = _UpstreamBlockState(
|
||||
block_type=block_type,
|
||||
down_index=new_idx,
|
||||
open=True,
|
||||
last_start_block=last_start_block,
|
||||
)
|
||||
return new_idx
|
||||
|
||||
|
||||
def transform_native_sse_block_event(
|
||||
event: str,
|
||||
state: NativeSseBlockPolicyState,
|
||||
*,
|
||||
thinking_enabled: bool,
|
||||
) -> str | None:
|
||||
"""Normalize native Anthropic SSE events and enforce local thinking policy."""
|
||||
event_name, data_text = parse_native_sse_event(event)
|
||||
if not event_name or not data_text:
|
||||
return event
|
||||
|
||||
try:
|
||||
payload = json.loads(data_text)
|
||||
except json.JSONDecodeError:
|
||||
return event
|
||||
|
||||
if event_name == "content_block_start":
|
||||
block = payload.get("content_block")
|
||||
if not isinstance(block, dict):
|
||||
return event
|
||||
block_type = block.get("type")
|
||||
upstream_index = payload.get("index")
|
||||
if not isinstance(upstream_index, int):
|
||||
return event
|
||||
if _should_drop_block_type(block_type, thinking_enabled=thinking_enabled):
|
||||
state.dropped_indexes.add(upstream_index)
|
||||
return None
|
||||
|
||||
if not isinstance(block_type, str):
|
||||
return event
|
||||
prefix = _synthetic_close_other_open_blocks(state, upstream_index)
|
||||
stored = copy.deepcopy(block)
|
||||
new_idx = _allocate_new_segment(
|
||||
state,
|
||||
upstream_index,
|
||||
block_type=block_type,
|
||||
last_start_block=stored,
|
||||
)
|
||||
payload["index"] = new_idx
|
||||
return prefix + format_native_sse_event(event_name, json.dumps(payload))
|
||||
|
||||
if event_name == "content_block_delta":
|
||||
delta = payload.get("delta")
|
||||
if not isinstance(delta, dict):
|
||||
return event
|
||||
delta_type = delta.get("type")
|
||||
upstream_index = payload.get("index")
|
||||
if not isinstance(upstream_index, int):
|
||||
return event
|
||||
if upstream_index in state.dropped_indexes:
|
||||
return None
|
||||
if _should_drop_block_type(delta_type, thinking_enabled=thinking_enabled):
|
||||
return None
|
||||
|
||||
block_kind = _delta_type_to_block_kind(delta_type)
|
||||
if block_kind is None:
|
||||
return event
|
||||
|
||||
seg = state.by_upstream.get(upstream_index)
|
||||
if seg and seg.open:
|
||||
payload["index"] = seg.down_index
|
||||
return format_native_sse_event(event_name, json.dumps(payload))
|
||||
|
||||
if seg is not None and not seg.open:
|
||||
# More deltas for an upstream block after a synthetic (or other) close:
|
||||
# reopen with a new downstream `index` and emit a synthetic `content_block_start` first.
|
||||
state.pending_suppressed_stops.discard(upstream_index)
|
||||
carry = seg.last_start_block
|
||||
new_idx = _allocate_new_segment(
|
||||
state,
|
||||
upstream_index,
|
||||
block_type=block_kind,
|
||||
last_start_block=carry,
|
||||
)
|
||||
stored_tool = (
|
||||
carry
|
||||
if isinstance(carry, dict) and carry.get("type") == "tool_use"
|
||||
else None
|
||||
)
|
||||
start_payload = {
|
||||
"type": "content_block_start",
|
||||
"index": new_idx,
|
||||
"content_block": _synthetic_start_content_block(
|
||||
block_kind,
|
||||
upstream_index=upstream_index,
|
||||
stored_tool_block=stored_tool,
|
||||
),
|
||||
}
|
||||
prefix = format_native_sse_event(
|
||||
"content_block_start", json.dumps(start_payload)
|
||||
)
|
||||
payload["index"] = new_idx
|
||||
return prefix + format_native_sse_event(event_name, json.dumps(payload))
|
||||
|
||||
# Delta with no prior `content_block_start` in this stream
|
||||
if block_kind in ("text", "tool_use"):
|
||||
synthetic_block = _synthetic_start_content_block(
|
||||
block_kind,
|
||||
upstream_index=upstream_index,
|
||||
)
|
||||
new_idx = _allocate_new_segment(
|
||||
state,
|
||||
upstream_index,
|
||||
block_type=block_kind,
|
||||
last_start_block=copy.deepcopy(synthetic_block),
|
||||
)
|
||||
start_payload = {
|
||||
"type": "content_block_start",
|
||||
"index": new_idx,
|
||||
"content_block": synthetic_block,
|
||||
}
|
||||
prefix = format_native_sse_event(
|
||||
"content_block_start", json.dumps(start_payload)
|
||||
)
|
||||
payload["index"] = new_idx
|
||||
return prefix + format_native_sse_event(event_name, json.dumps(payload))
|
||||
# thinking: pass through raw (unusual upstream shape)
|
||||
return event
|
||||
|
||||
if event_name == "content_block_stop":
|
||||
upstream_index = payload.get("index")
|
||||
if not isinstance(upstream_index, int):
|
||||
return event
|
||||
if upstream_index in state.dropped_indexes:
|
||||
return None
|
||||
if upstream_index in state.pending_suppressed_stops:
|
||||
state.pending_suppressed_stops.discard(upstream_index)
|
||||
return None
|
||||
|
||||
seg = state.by_upstream.get(upstream_index)
|
||||
if seg is not None and seg.open:
|
||||
payload["index"] = seg.down_index
|
||||
seg.open = False
|
||||
return format_native_sse_event(event_name, json.dumps(payload))
|
||||
if seg is not None:
|
||||
# Spurious or duplicate `content_block_stop` for a closed block.
|
||||
return None
|
||||
if not thinking_enabled:
|
||||
return None
|
||||
return event
|
||||
|
||||
return event
|
||||
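A minimal sketch of the policy state machine above: with thinking disabled, a thinking block is swallowed outright, and the surviving text block is renumbered from upstream index 1 to downstream index 0 (only the sample payloads are invented):

    import json

    from core.anthropic.native_sse_block_policy import (
        NativeSseBlockPolicyState,
        format_native_sse_event,
        transform_native_sse_block_event,
    )

    state = NativeSseBlockPolicyState()

    thinking_start = format_native_sse_event(
        "content_block_start",
        json.dumps({"type": "content_block_start", "index": 0,
                    "content_block": {"type": "thinking", "thinking": ""}}),
    )
    # Dropped entirely; later deltas for upstream index 0 are dropped too.
    assert transform_native_sse_block_event(
        thinking_start, state, thinking_enabled=False
    ) is None

    text_start = format_native_sse_event(
        "content_block_start",
        json.dumps({"type": "content_block_start", "index": 1,
                    "content_block": {"type": "text", "text": ""}}),
    )
    out = transform_native_sse_block_event(text_start, state, thinking_enabled=False)
    assert out is not None
    assert json.loads(out.split("data: ", 1)[1])["index"] == 0  # remapped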
core/anthropic/provider_stream_error.py (new file, 34 lines)

@@ -0,0 +1,34 @@
"""Canonical Anthropic-style SSE sequence for provider-side streaming errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from core.anthropic.sse import SSEBuilder
|
||||
|
||||
|
||||
def iter_provider_stream_error_sse_events(
|
||||
*,
|
||||
request: Any,
|
||||
input_tokens: int,
|
||||
error_message: str,
|
||||
sent_any_event: bool,
|
||||
log_raw_sse_events: bool,
|
||||
message_id: str | None = None,
|
||||
) -> Iterator[str]:
|
||||
"""Yield message_start (if needed), a text block with the error, then message_delta/stop."""
|
||||
mid = message_id or f"msg_{uuid.uuid4()}"
|
||||
model = getattr(request, "model", "") or ""
|
||||
sse = SSEBuilder(
|
||||
mid,
|
||||
model,
|
||||
input_tokens,
|
||||
log_raw_events=log_raw_sse_events,
|
||||
)
|
||||
if not sent_any_event:
|
||||
yield sse.message_start()
|
||||
yield from sse.emit_error(error_message)
|
||||
yield sse.message_delta("end_turn", 1)
|
||||
yield sse.message_stop()
|
||||
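A sketch of joining the canonical error tail for a stream that failed before sending any event. `SimpleNamespace` stands in for the real request model; the rest comes from this module:

    from types import SimpleNamespace

    from core.anthropic.provider_stream_error import (
        iter_provider_stream_error_sse_events,
    )

    tail = "".join(
        iter_provider_stream_error_sse_events(
            request=SimpleNamespace(model="some-model"),
            input_tokens=0,
            error_message="Provider request failed.",
            sent_any_event=False,
            log_raw_sse_events=False,
        )
    )
    # With sent_any_event=False the tail begins with message_start, then the
    # error text block, then message_delta / message_stop.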
core/anthropic/server_tool_sse.py (new file, 14 lines)

@@ -0,0 +1,14 @@
"""SSE content_block ``type`` values for Anthropic web server tools (local handlers).
|
||||
|
||||
Shared by :mod:`api.web_tools` and stream contract tests to avoid drift.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final
|
||||
|
||||
SERVER_TOOL_USE: Final = "server_tool_use"
|
||||
WEB_SEARCH_TOOL_RESULT: Final = "web_search_tool_result"
|
||||
WEB_FETCH_TOOL_RESULT: Final = "web_fetch_tool_result"
|
||||
WEB_SEARCH_TOOL_RESULT_ERROR: Final = "web_search_tool_result_error"
|
||||
WEB_FETCH_TOOL_ERROR: Final = "web_fetch_tool_error"
|
||||
|
|
@@ -1,5 +1,6 @@
"""SSE event builder for Anthropic-format streaming responses."""

import hashlib
import json
from collections.abc import Iterator
from dataclasses import dataclass, field

@@ -14,6 +15,13 @@ except Exception:
    ENCODER = None


# Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy.
ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = {
    "X-Accel-Buffering": "no",
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
}

STOP_REASON_MAP = {
    "stop": "end_turn",
    "length": "max_tokens",

@@ -29,6 +37,11 @@ def map_stop_reason(openai_reason: str | None) -> str:
    )


def format_sse_event(event_type: str, data: dict) -> str:
    """Format one Anthropic-style SSE event (no logging)."""
    return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"


@dataclass
class ToolCallState:
    """State for a single streaming tool call."""

@@ -40,6 +53,7 @@ class ToolCallState:
    started: bool = False
    task_arg_buffer: str = ""
    task_args_emitted: bool = False
    pre_start_args: str = ""


@dataclass

@@ -58,7 +72,25 @@ class ContentBlockManager:
        self.next_index += 1
        return idx

    def ensure_tool_state(self, index: int) -> ToolCallState:
        """Create tool stream state for ``index`` when the first tool delta arrives."""
        if index not in self.tool_states:
            self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="")
        return self.tool_states[index]

    def set_stream_tool_id(self, index: int, tool_id: str | None) -> None:
        """Record OpenAI tool call id before ``content_block_start`` (split-stream providers)."""
        if not tool_id:
            return
        state = self.ensure_tool_state(index)
        state.tool_id = str(tool_id)

    def register_tool_name(self, index: int, name: str) -> None:
        """Record tool name fragments as they arrive from chunked OpenAI streams.

        Names may be split across deltas; later chunks can extend (``ab`` + ``c``)
        or repeat prefixes, so we merge conservatively.
        """
        if index not in self.tool_states:
            self.tool_states[index] = ToolCallState(
                block_index=-1, tool_id="", name=name

@@ -82,8 +114,7 @@ class ContentBlockManager:
        except Exception:
            return None

        if args_json.get("run_in_background") is not False:
            args_json["run_in_background"] = False
        _normalize_task_run_in_background(args_json)

        state.task_args_emitted = True
        state.task_arg_buffer = ""

@@ -98,16 +129,17 @@ class ContentBlockManager:
        out = "{}"
        try:
            args_json = json.loads(state.task_arg_buffer)
            if args_json.get("run_in_background") is not False:
                args_json["run_in_background"] = False
            _normalize_task_run_in_background(args_json)
            out = json.dumps(args_json)
        except Exception as e:
            prefix = state.task_arg_buffer[:120]
        except (json.JSONDecodeError, TypeError, ValueError) as e:
            digest = hashlib.sha256(
                state.task_arg_buffer.encode("utf-8", errors="replace")
            ).hexdigest()[:16]
            logger.warning(
                "Task args invalid JSON (id={} len={} prefix={!r}): {}",
                "Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}",
                state.tool_id or "unknown",
                len(state.task_arg_buffer),
                prefix,
                digest,
                e,
            )

@@ -117,20 +149,41 @@ class ContentBlockManager:
        return results


def _normalize_task_run_in_background(args_json: dict) -> None:
    """Force Claude Code Task subagents to run in foreground (single shared rule)."""
    if args_json.get("run_in_background") is not False:
        args_json["run_in_background"] = False


class SSEBuilder:
    """Builder for Anthropic SSE streaming events."""

    def __init__(self, message_id: str, model: str, input_tokens: int = 0):
    def __init__(
        self,
        message_id: str,
        model: str,
        input_tokens: int = 0,
        *,
        log_raw_events: bool = False,
    ):
        self.message_id = message_id
        self.model = model
        self.input_tokens = input_tokens
        self._log_raw_events = log_raw_events
        self.blocks = ContentBlockManager()
        self._accumulated_text_parts: list[str] = []
        self._accumulated_reasoning_parts: list[str] = []

    def _format_event(self, event_type: str, data: dict) -> str:
        event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
        event_str = format_sse_event(event_type, data)
        if self._log_raw_events:
            logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
        else:
            logger.debug(
                "SSE_EVENT: event_type={} serialized_bytes={}",
                event_type,
                len(event_str.encode("utf-8")),
            )
        return event_str

    def message_start(self) -> str:

@@ -289,10 +342,7 @@ class SSEBuilder:
            yield self.stop_text_block()

    def close_all_blocks(self) -> Iterator[str]:
        if self.blocks.thinking_started:
            yield self.stop_thinking_block()
        if self.blocks.text_started:
            yield self.stop_text_block()
        yield from self.close_content_blocks()
        for tool_index, state in list(self.blocks.tool_states.items()):
            if state.started:
                yield self.stop_tool_block(tool_index)
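For context, a sketch of how the new shared header constant is meant to be attached to a streaming response. FastAPI is an assumption here; only `ANTHROPIC_SSE_RESPONSE_HEADERS` and the media type come from this module:

    from collections.abc import AsyncIterator

    from fastapi.responses import StreamingResponse

    from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS


    def sse_response(events: AsyncIterator[str]) -> StreamingResponse:
        # X-Accel-Buffering / Cache-Control disable proxy buffering and caching
        # so individual SSE events flush to the client immediately.
        return StreamingResponse(
            events,
            media_type="text/event-stream",
            headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
        )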
@@ -1,7 +1,6 @@
"""Neutral SSE parsing and Anthropic stream shape assertions.

Used by default CI contract tests and re-exported from ``smoke.lib.sse`` for
opt-in smoke scenarios.
Used by default CI contract tests and by opt-in live smoke scenarios.
"""

from __future__ import annotations

@@ -11,6 +10,36 @@ from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any

from .server_tool_sse import (
    SERVER_TOOL_USE,
    WEB_FETCH_TOOL_RESULT,
    WEB_SEARCH_TOOL_RESULT,
)

# Content blocks that only use content_block_start/stop (no deltas), including
# Anthropic server tools and eager text emitted in a single start event.
_NO_DELTA_BLOCK_KINDS = frozenset(
    {
        SERVER_TOOL_USE,
        WEB_SEARCH_TOOL_RESULT,
        WEB_FETCH_TOOL_RESULT,
        "text_eager",
        "redacted_thinking",
    }
)

_ALLOWED_BLOCK_START_TYPES = frozenset(
    {
        "text",
        "thinking",
        "tool_use",
        "redacted_thinking",
        SERVER_TOOL_USE,
        WEB_SEARCH_TOOL_RESULT,
        WEB_FETCH_TOOL_RESULT,
    }
)


@dataclass(frozen=True, slots=True)
class SSEEvent:

@@ -86,35 +115,46 @@ def assert_anthropic_stream_contract(
            raise AssertionError(f"unexpected SSE error event: {event.data}")

        if event.event == "content_block_start":
            index = _event_index(event)
            index = event_index(event)
            block = event.data.get("content_block", {})
            assert isinstance(block, dict), event.data
            block_type = str(block.get("type", ""))
            assert block_type in {"text", "thinking", "tool_use"}, event.data
            assert block_type in _ALLOWED_BLOCK_START_TYPES, event.data
            assert index not in open_blocks, f"block {index} started twice"
            assert index not in seen_blocks, f"block {index} reused after stop"
            open_blocks[index] = block_type
            if block_type == "text" and str(block.get("text", "")).strip():
                storage = "text_eager"
            else:
                storage = block_type
            open_blocks[index] = storage
            seen_blocks.add(index)
            continue

        if event.event == "content_block_delta":
            index = _event_index(event)
            index = event_index(event)
            assert index in open_blocks, f"delta for unopened block {index}"
            kind = open_blocks[index]
            assert kind not in _NO_DELTA_BLOCK_KINDS, (
                f"unexpected delta for start/stop-only block {kind} at index {index}"
            )
            delta = event.data.get("delta", {})
            assert isinstance(delta, dict), event.data
            delta_type = str(delta.get("type", ""))
            if kind == "thinking":
                assert delta_type in (
                    "thinking_delta",
                    "signature_delta",
                ), f"block {index} is {kind}, got {delta_type}"
                continue
            expected = {
                "text": "text_delta",
                "thinking": "thinking_delta",
                "tool_use": "input_json_delta",
            }[open_blocks[index]]
            assert delta_type == expected, (
                f"block {index} is {open_blocks[index]}, got {delta_type}"
            )
            }[kind]
            assert delta_type == expected, f"block {index} is {kind}, got {delta_type}"
            continue

        if event.event == "content_block_stop":
            index = _event_index(event)
            index = event_index(event)
            assert index in open_blocks, f"stop for unopened block {index}"
            open_blocks.pop(index)

@@ -129,6 +169,12 @@ def event_names(events: list[SSEEvent]) -> list[str]:
def text_content(events: list[SSEEvent]) -> str:
    parts: list[str] = []
    for event in events:
        if event.event == "content_block_start":
            block = event.data.get("content_block", {})
            if isinstance(block, dict) and block.get("type") == "text":
                eager = str(block.get("text", ""))
                if eager:
                    parts.append(eager)
        delta = event.data.get("delta", {})
        if isinstance(delta, dict) and delta.get("type") == "text_delta":
            parts.append(str(delta.get("text", "")))

@@ -152,7 +198,8 @@ def has_tool_use(events: list[SSEEvent]) -> bool:
    return False


def _event_index(event: SSEEvent) -> int:
def event_index(event: SSEEvent) -> int:
    """Return the content block ``index`` field from an SSE payload (strict)."""
    value = event.data.get("index")
    assert isinstance(value, int), event.data
    return value
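As a reference point for the contract asserted above, here is a stream shape (event name, payload) that passes: indices are never reused, deltas arrive only for open blocks, and each delta type matches its block kind. The payloads are invented sample data:

    valid_events = [
        ("message_start", {"type": "message_start"}),
        ("content_block_start",
         {"type": "content_block_start", "index": 0,
          "content_block": {"type": "text", "text": ""}}),
        ("content_block_delta",
         {"type": "content_block_delta", "index": 0,
          "delta": {"type": "text_delta", "text": "hi"}}),
        ("content_block_stop", {"type": "content_block_stop", "index": 0}),
        ("message_delta", {"type": "message_delta",
                           "delta": {"stop_reason": "end_turn"}}),
        ("message_stop", {"type": "message_stop"}),
    ]
    # Reusing index 0 for a second content_block_start, or sending a
    # thinking_delta into this text block, would trip the assertions above.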
@@ -68,6 +68,27 @@ def get_token_count(
                total_tokens += len(ENCODER.encode(json.dumps(content)))
                total_tokens += len(ENCODER.encode(str(tool_use_id)))
                total_tokens += 8
            elif b_type in (
                "server_tool_use",
                "web_search_tool_result",
                "web_fetch_tool_result",
            ):
                if hasattr(block, "model_dump"):
                    blob: object = block.model_dump()
                else:
                    blob = block
                try:
                    total_tokens += len(
                        ENCODER.encode(
                            json.dumps(blob, default=str, ensure_ascii=False)
                        )
                    )
                except (TypeError, ValueError, OverflowError) as e:
                    logger.debug(
                        "Block encode fallback b_type={} err={}", b_type, e
                    )
                    total_tokens += len(ENCODER.encode(str(blob)))
                total_tokens += 12
            else:
                logger.debug(
                    "Unexpected block type %r, falling back to json/str encoding",
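The fallback rule above, restated as a standalone sketch: serialize the block to JSON if possible, else `str()`, then add a flat 12-token structural overhead. The module's `ENCODER` looks tiktoken-shaped; `cl100k_base` is an assumption here:

    import json

    import tiktoken

    ENCODER = tiktoken.get_encoding("cl100k_base")  # assumed encoding


    def server_tool_block_tokens(blob: object) -> int:
        try:
            text = json.dumps(blob, default=str, ensure_ascii=False)
        except (TypeError, ValueError, OverflowError):
            text = str(blob)  # last-resort serialization, mirroring the diff
        return len(ENCODER.encode(text)) + 12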
core/rate_limit.py (new file, 60 lines)

@@ -0,0 +1,60 @@
"""Shared strict sliding-window rate limiting primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
|
||||
class StrictSlidingWindowLimiter:
|
||||
"""Strict sliding window limiter.
|
||||
|
||||
Guarantees: at most ``rate_limit`` acquisitions in any interval of length
|
||||
``rate_window`` (seconds).
|
||||
|
||||
Implemented as an async context manager so call sites can do::
|
||||
|
||||
async with limiter:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, rate_limit: int, rate_window: float) -> None:
|
||||
if rate_limit <= 0:
|
||||
raise ValueError("rate_limit must be > 0")
|
||||
if rate_window <= 0:
|
||||
raise ValueError("rate_window must be > 0")
|
||||
|
||||
self._rate_limit = int(rate_limit)
|
||||
self._rate_window = float(rate_window)
|
||||
self._times: deque[float] = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
while True:
|
||||
wait_time = 0.0
|
||||
async with self._lock:
|
||||
now = time.monotonic()
|
||||
cutoff = now - self._rate_window
|
||||
|
||||
while self._times and self._times[0] <= cutoff:
|
||||
self._times.popleft()
|
||||
|
||||
if len(self._times) < self._rate_limit:
|
||||
self._times.append(now)
|
||||
return
|
||||
|
||||
oldest = self._times[0]
|
||||
wait_time = max(0.0, (oldest + self._rate_window) - now)
|
||||
|
||||
if wait_time > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def __aenter__(self) -> StrictSlidingWindowLimiter:
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
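Driving the limiter above: with `rate_limit=2` over a 1-second window, the third acquisition blocks until the oldest timestamp ages out of the window. Only the driver code is new here:

    import asyncio
    import time

    from core.rate_limit import StrictSlidingWindowLimiter


    async def main() -> None:
        limiter = StrictSlidingWindowLimiter(rate_limit=2, rate_window=1.0)
        start = time.monotonic()
        for _ in range(3):
            async with limiter:
                pass
        # The third acquire had to wait roughly a full window.
        assert time.monotonic() - start >= 1.0


    asyncio.run(main())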
messaging/cli_event_constants.py (new file, 67 lines)

@@ -0,0 +1,67 @@
"""CLI event types and status-line mapping for transcript / UI updates."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
# Status message prefixes used to filter our own messages (ignore echo)
|
||||
STATUS_MESSAGE_PREFIXES = (
|
||||
"⏳",
|
||||
"💭",
|
||||
"🔧",
|
||||
"✅",
|
||||
"❌",
|
||||
"🚀",
|
||||
"🤖",
|
||||
"📋",
|
||||
"📊",
|
||||
"🔄",
|
||||
)
|
||||
|
||||
# Event types that update the transcript (frozenset for O(1) membership)
|
||||
TRANSCRIPT_EVENT_TYPES = frozenset(
|
||||
{
|
||||
"thinking_start",
|
||||
"thinking_delta",
|
||||
"thinking_chunk",
|
||||
"thinking_stop",
|
||||
"text_start",
|
||||
"text_delta",
|
||||
"text_chunk",
|
||||
"text_stop",
|
||||
"tool_use_start",
|
||||
"tool_use_delta",
|
||||
"tool_use_stop",
|
||||
"tool_use",
|
||||
"tool_result",
|
||||
"block_stop",
|
||||
"error",
|
||||
}
|
||||
)
|
||||
|
||||
# Event type -> (emoji, label) for status updates (O(1) lookup)
|
||||
_EVENT_STATUS_MAP: dict[str, tuple[str, str]] = {
|
||||
"thinking_start": ("🧠", "Claude is thinking..."),
|
||||
"thinking_delta": ("🧠", "Claude is thinking..."),
|
||||
"thinking_chunk": ("🧠", "Claude is thinking..."),
|
||||
"text_start": ("🧠", "Claude is working..."),
|
||||
"text_delta": ("🧠", "Claude is working..."),
|
||||
"text_chunk": ("🧠", "Claude is working..."),
|
||||
"tool_result": ("⏳", "Executing tools..."),
|
||||
}
|
||||
|
||||
|
||||
def get_status_for_event(
|
||||
ptype: str,
|
||||
parsed: dict[str, Any],
|
||||
format_status_fn: Callable[..., str],
|
||||
) -> str | None:
|
||||
"""Return status string for event type, or None if no status update needed."""
|
||||
entry = _EVENT_STATUS_MAP.get(ptype)
|
||||
if entry is not None:
|
||||
emoji, label = entry
|
||||
return format_status_fn(emoji, label)
|
||||
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
|
||||
if parsed.get("name") == "Task":
|
||||
return format_status_fn("🤖", "Subagent working...")
|
||||
return format_status_fn("⏳", "Executing tools...")
|
||||
return None
|
||||
|
|
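The status routing above, exercised end to end: Task tool events get the subagent label, other tool events the generic one, and unknown events yield no status update. Only the formatter stub is invented:

    from messaging.cli_event_constants import get_status_for_event


    def fmt(emoji: str, label: str) -> str:
        return f"{emoji} {label}"


    assert get_status_for_event("thinking_delta", {}, fmt) == "🧠 Claude is thinking..."
    assert get_status_for_event("tool_use", {"name": "Task"}, fmt) == "🤖 Subagent working..."
    assert get_status_for_event("tool_use", {"name": "Bash"}, fmt) == "⏳ Executing tools..."
    assert get_status_for_event("session_info", {}, fmt) is None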
@@ -9,12 +9,14 @@ from typing import Any
from loguru import logger


def parse_cli_event(event: Any) -> list[dict]:
def parse_cli_event(event: Any, *, log_raw_cli: bool = False) -> list[dict]:
    """
    Parse a CLI event and return a structured result.

    Args:
        event: Raw event dictionary from CLI
        log_raw_cli: When True, log full error text from the CLI. Default is
            metadata-only (lengths / exit codes) to avoid leaking user content.

    Returns:
        List of parsed event dicts. Empty list if not recognized.

@@ -140,7 +142,11 @@ def parse_cli_event(event: Any) -> list[dict]:
    if etype == "error":
        err = event.get("error")
        msg = err.get("message") if isinstance(err, dict) else str(err)
        logger.info(f"CLI_PARSER: Parsed error event: {msg}")
        if log_raw_cli:
            logger.info("CLI_PARSER: Parsed error event: {}", msg)
        else:
            mlen = len(msg) if isinstance(msg, str) else 0
            logger.info("CLI_PARSER: Parsed error event: message_chars={}", mlen)
        return [{"type": "error", "message": msg}]
    elif etype == "exit":
        code = event.get("code", 0)

@@ -151,7 +157,19 @@ def parse_cli_event(event: Any) -> list[dict]:
        else:
            # Non-zero exit is an error
            error_msg = stderr if stderr else f"Process exited with code {code}"
            logger.warning(f"CLI_PARSER: Error exit (code={code}): {error_msg}")
            if log_raw_cli:
                logger.warning(
                    "CLI_PARSER: Error exit (code={}): {}",
                    code,
                    error_msg,
                )
            else:
                em = error_msg if isinstance(error_msg, str) else str(error_msg)
                logger.warning(
                    "CLI_PARSER: Error exit (code={}): message_chars={}",
                    code,
                    len(em),
                )
            return [
                {"type": "error", "message": error_msg},
                {"type": "complete", "status": "failed"},
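A small sketch of the metadata-only default above: with `log_raw_cli` left False, only `message_chars` reaches the log, while the parsed event still carries the full text to the caller. The import path follows the handler's relative import of `event_parser` within the messaging package, and the sample event assumes the dict carries its type under "type":

    from messaging.event_parser import parse_cli_event

    events = parse_cli_event({"type": "error", "error": {"message": "boom"}})
    assert events == [{"type": "error", "message": "boom"}]
    # The log line for this call records message_chars=4, never "boom" itself.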
@@ -7,13 +7,12 @@ Uses tree-based queuing for message ordering.
"""

import asyncio
import os
import time

from loguru import logger

from core.anthropic import get_user_facing_error_message
from core.anthropic import format_user_error_preview, get_user_facing_error_message

from .cli_event_constants import STATUS_MESSAGE_PREFIXES
from .command_dispatcher import (
    dispatch_command,
    message_kind_for_command,

@@ -21,8 +20,10 @@ from .command_dispatcher import (
)
from .event_parser import parse_cli_event
from .models import IncomingMessage
from .node_event_pipeline import handle_session_info_event, process_parsed_cli_event
from .platforms.base import MessagingPlatform, SessionManagerInterface
from .rendering.profiles import build_rendering_profile
from .safe_diagnostics import format_exception_for_log
from .session import SessionStore
from .transcript import RenderCtx, TranscriptBuffer
from .trees.queue_manager import (

@@ -31,54 +32,7 @@ from .trees.queue_manager import (
    MessageTree,
    TreeQueueManager,
)

# Status message prefixes used to filter our own messages (ignore echo)
STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄")

# Event types that update the transcript (frozenset for O(1) membership)
TRANSCRIPT_EVENT_TYPES = frozenset(
    {
        "thinking_start",
        "thinking_delta",
        "thinking_chunk",
        "thinking_stop",
        "text_start",
        "text_delta",
        "text_chunk",
        "text_stop",
        "tool_use_start",
        "tool_use_delta",
        "tool_use_stop",
        "tool_use",
        "tool_result",
        "block_stop",
        "error",
    }
)

# Event type -> (emoji, label) for status updates (O(1) lookup)
_EVENT_STATUS_MAP = {
    "thinking_start": ("🧠", "Claude is thinking..."),
    "thinking_delta": ("🧠", "Claude is thinking..."),
    "thinking_chunk": ("🧠", "Claude is thinking..."),
    "text_start": ("🧠", "Claude is working..."),
    "text_delta": ("🧠", "Claude is working..."),
    "text_chunk": ("🧠", "Claude is working..."),
    "tool_result": ("⏳", "Executing tools..."),
}


def _get_status_for_event(ptype: str, parsed: dict, format_status_fn) -> str | None:
    """Return status string for event type, or None if no status update needed."""
    entry = _EVENT_STATUS_MAP.get(ptype)
    if entry is not None:
        emoji, label = entry
        return format_status_fn(emoji, label)
    if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
        if parsed.get("name") == "Task":
            return format_status_fn("🤖", "Subagent working...")
        return format_status_fn("⏳", "Executing tools...")
    return None
from .ui_updates import ThrottledTranscriptEditor


class ClaudeMessageHandler:

@@ -97,10 +51,21 @@ class ClaudeMessageHandler:
        platform: MessagingPlatform,
        cli_manager: SessionManagerInterface,
        session_store: SessionStore,
        *,
        debug_platform_edits: bool = False,
        debug_subagent_stack: bool = False,
        log_raw_messaging_content: bool = False,
        log_raw_cli_diagnostics: bool = False,
        log_messaging_error_details: bool = False,
    ):
        self.platform = platform
        self.cli_manager = cli_manager
        self.session_store = session_store
        self._debug_platform_edits = debug_platform_edits
        self._debug_subagent_stack = debug_subagent_stack
        self._log_raw_messaging_content = log_raw_messaging_content
        self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
        self._log_messaging_error_details = log_messaging_error_details
        self._tree_queue = TreeQueueManager(
            queue_update_callback=self.update_queue_positions,
            node_started_callback=self.mark_node_processing,

@@ -137,8 +102,10 @@ class ClaudeMessageHandler:
        Determines if this is a new conversation or reply,
        creates/extends the message tree, and queues for processing.
        """
        text_preview = (incoming.text or "")[:80]
        if len(incoming.text or "") > 80:
        raw = incoming.text or ""
        if self._log_raw_messaging_content:
            text_preview = raw[:80]
            if len(raw) > 80:
                text_preview += "..."
            logger.info(
                "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}",

@@ -147,6 +114,14 @@ class ClaudeMessageHandler:
                incoming.reply_to_message_id,
                text_preview,
            )
        else:
            logger.info(
                "HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_len={}",
                incoming.chat_id,
                incoming.message_id,
                incoming.reply_to_message_id,
                len(raw),
            )

        with logger.contextualize(
            chat_id=incoming.chat_id, node_id=incoming.message_id

@@ -169,7 +144,12 @@ class ClaudeMessageHandler:
                    kind=message_kind_for_command(cmd_base),
                )
            except Exception as e:
                logger.debug(f"Failed to record incoming message_id: {e}")
                logger.debug(
                    "Failed to record incoming message_id: {}",
                    format_exception_for_log(
                        e, log_full_message=self._log_messaging_error_details
                    ),
                )

            if await dispatch_command(self, incoming, cmd_base):
                return

@@ -276,7 +256,12 @@ class ClaudeMessageHandler:
        try:
            queued_ids = await tree.get_queue_snapshot()
        except Exception as e:
            logger.warning(f"Failed to read queue snapshot: {e}")
            logger.warning(
                "Failed to read queue snapshot: {}",
                format_exception_for_log(
                    e, log_full_message=self._log_messaging_error_details
                ),
            )
            return

        if not queued_ids:

@@ -317,86 +302,12 @@ class ClaudeMessageHandler:
        self,
    ) -> tuple[TranscriptBuffer, RenderCtx]:
        """Create transcript buffer and render context for node processing."""
        transcript = TranscriptBuffer(show_tool_results=False)
        transcript = TranscriptBuffer(
            show_tool_results=False,
            debug_subagent_stack=self._debug_subagent_stack,
        )
        return transcript, self.get_render_ctx()

    async def _handle_session_info_event(
        self,
        event_data: dict,
        tree: MessageTree | None,
        node_id: str,
        captured_session_id: str | None,
        temp_session_id: str | None,
    ) -> tuple[str | None, str | None]:
        """Handle session_info event; return updated (captured_session_id, temp_session_id)."""
        if event_data.get("type") != "session_info":
            return captured_session_id, temp_session_id

        real_session_id = event_data.get("session_id")
        if not real_session_id or not temp_session_id:
            return captured_session_id, temp_session_id

        await self.cli_manager.register_real_session_id(
            temp_session_id, real_session_id
        )
        if tree and real_session_id:
            await tree.update_state(
                node_id,
                MessageState.IN_PROGRESS,
                session_id=real_session_id,
            )
            self.session_store.save_tree(tree.root_id, tree.to_dict())

        return real_session_id, None

    async def _process_parsed_event(
        self,
        parsed: dict,
        transcript: TranscriptBuffer,
        update_ui,
        last_status: str | None,
        had_transcript_events: bool,
        tree: MessageTree | None,
        node_id: str,
        captured_session_id: str | None,
    ) -> tuple[str | None, bool]:
        """Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
        ptype = parsed.get("type") or ""

        if ptype in TRANSCRIPT_EVENT_TYPES:
            transcript.apply(parsed)
            had_transcript_events = True

            status = _get_status_for_event(ptype, parsed, self.format_status)
            if status is not None:
                await update_ui(status)
                last_status = status
        elif ptype == "block_stop":
            await update_ui(last_status, force=True)
        elif ptype == "complete":
            if not had_transcript_events:
                transcript.apply({"type": "text_chunk", "text": "Done."})
            logger.info("HANDLER: Task complete, updating UI")
            await update_ui(self.format_status("✅", "Complete"), force=True)
            if tree and captured_session_id:
                await tree.update_state(
                    node_id,
                    MessageState.COMPLETED,
                    session_id=captured_session_id,
                )
                self.session_store.save_tree(tree.root_id, tree.to_dict())
        elif ptype == "error":
            error_msg = parsed.get("message", "Unknown error")
            logger.error(f"HANDLER: Error event received: {error_msg}")
            logger.info("HANDLER: Updating UI with error status")
            await update_ui(self.format_status("❌", "Error"), force=True)
            if tree:
                await self._propagate_error_to_children(
                    node_id, error_msg, "Parent task failed"
                )

        return last_status, had_transcript_events

    async def _process_node(
        self,
        node_id: str,

@@ -426,8 +337,6 @@ class ClaudeMessageHandler:

        transcript, render_ctx = self._create_transcript_and_render_ctx()

        last_ui_update = 0.0
        last_displayed_text = None
        had_transcript_events = False
        captured_session_id = None
        temp_session_id = None

@@ -439,52 +348,21 @@ class ClaudeMessageHandler:
        if parent_session_id:
            logger.info(f"Will fork from parent session: {parent_session_id}")

        async def update_ui(status: str | None = None, force: bool = False) -> None:
            nonlocal last_ui_update, last_displayed_text, last_status
            now = time.time()
            if not force and now - last_ui_update < 1.0:
                return

            last_ui_update = now
            if status is not None:
                last_status = status
            try:
                display = transcript.render(
                    render_ctx,
                    limit_chars=self._get_limit_chars(),
                    status=status,
                )
            except Exception as e:
                logger.warning(f"Transcript render failed for node {node_id}: {e}")
                return
            if display and display != last_displayed_text:
                logger.debug(
                    "PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
                    node_id,
                    chat_id,
                    status_msg_id,
                    bool(force),
                    status,
                    len(display),
                )
                if os.getenv("DEBUG_PLATFORM_EDITS") == "1":
                    logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
                else:
                    head = display[:500]
                    tail = display[-500:] if len(display) > 500 else ""
                    logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head)
                    if tail:
                        logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail)
                last_displayed_text = display
                try:
                    await self.platform.queue_edit_message(
                        chat_id,
                        status_msg_id,
                        display,
        editor = ThrottledTranscriptEditor(
            platform=self.platform,
            parse_mode=self._parse_mode(),
            get_limit_chars=self._get_limit_chars,
            transcript=transcript,
            render_ctx=render_ctx,
            node_id=node_id,
            chat_id=chat_id,
            status_msg_id=status_msg_id,
            debug_platform_edits=self._debug_platform_edits,
            log_messaging_error_details=self._log_messaging_error_details,
        )
                except Exception as e:
                    logger.warning(f"Failed to update platform for node {node_id}: {e}")

        async def update_ui(status: str | None = None, force: bool = False) -> None:
            await editor.update(status, force=force)

        try:
            try:

@@ -531,20 +409,28 @@ class ClaudeMessageHandler:
                        (
                            captured_session_id,
                            temp_session_id,
                        ) = await self._handle_session_info_event(
                            event_data, tree, node_id, captured_session_id, temp_session_id
                        ) = await handle_session_info_event(
                            event_data,
                            tree,
                            node_id,
                            captured_session_id,
                            temp_session_id,
                            cli_manager=self.cli_manager,
                            session_store=self.session_store,
                        )
                        if event_data.get("type") == "session_info":
                            continue

                        parsed_list = parse_cli_event(event_data)
                        parsed_list = parse_cli_event(
                            event_data, log_raw_cli=self._log_raw_cli_diagnostics
                        )
                        logger.debug(f"HANDLER: Parsed {len(parsed_list)} events from CLI")

                        for parsed in parsed_list:
                            (
                                last_status,
                                had_transcript_events,
                            ) = await self._process_parsed_event(
                            ) = await process_parsed_cli_event(
                                parsed,
                                transcript,
                                update_ui,

@@ -553,6 +439,10 @@ class ClaudeMessageHandler:
                                tree,
                                node_id,
                                captured_session_id,
                                session_store=self.session_store,
                                format_status=self.format_status,
                                propagate_error_to_children=self._propagate_error_to_children,
                                log_messaging_error_details=self._log_messaging_error_details,
                            )

        except asyncio.CancelledError:

@@ -575,9 +465,12 @@ class ClaudeMessageHandler:
            )
        except Exception as e:
            logger.error(
                f"HANDLER: Task failed with exception: {type(e).__name__}: {e}"
                "HANDLER: Task failed with exception: {}",
                format_exception_for_log(
                    e, log_full_message=self._log_messaging_error_details
                ),
            )
            error_msg = get_user_facing_error_message(e)[:200]
            error_msg = format_user_error_preview(e)
            transcript.apply({"type": "error", "message": error_msg})
            await update_ui(self.format_status("💥", "Task Failed"), force=True)
            if tree:

@@ -595,7 +488,13 @@ class ClaudeMessageHandler:
                elif temp_session_id:
                    await self.cli_manager.remove_session(temp_session_id)
            except Exception as e:
                logger.debug(f"Failed to remove session for node {node_id}: {e}")
                logger.debug(
                    "Failed to remove session for node {}: {}",
                    node_id,
                    format_exception_for_log(
                        e, log_full_message=self._log_messaging_error_details
                    ),
                )

    async def _propagate_error_to_children(
        self,

@@ -691,7 +590,12 @@ class ClaudeMessageHandler:
                platform, chat_id, str(msg_id), direction="out", kind=kind
            )
        except Exception as e:
            logger.debug(f"Failed to record message_id: {e}")
            logger.debug(
                "Failed to record message_id: {}",
                format_exception_for_log(
                    e, log_full_message=self._log_messaging_error_details
                ),
            )

    def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
        """Update status messages and persist tree state for cancelled nodes."""
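The handler above routes every caught exception through `format_exception_for_log`. The real `messaging/safe_diagnostics.py` is not part of this hunk, so the following is only a plausible minimal shape of that contract, inferred from the call sites:

    # Hypothetical sketch of the safe-diagnostics contract; the actual
    # implementation in messaging/safe_diagnostics.py is not shown here.
    def format_exception_for_log(e: Exception, *, log_full_message: bool) -> str:
        if log_full_message:
            return f"{type(e).__name__}: {e}"
        # Metadata only: the message string may contain user content.
        return type(e).__name__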
@ -6,65 +6,16 @@ using a strict sliding window algorithm and a task queue.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from config.settings import get_settings
|
||||
from core.rate_limit import StrictSlidingWindowLimiter as SlidingWindowLimiter
|
||||
|
||||
class SlidingWindowLimiter:
|
||||
"""Strict sliding window limiter.
|
||||
|
||||
Guarantees: at most `rate_limit` acquisitions in any interval of length
|
||||
`rate_window` (seconds).
|
||||
|
||||
Implemented as an async context manager so call sites can do:
|
||||
async with limiter:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, rate_limit: int, rate_window: float) -> None:
|
||||
if rate_limit <= 0:
|
||||
raise ValueError("rate_limit must be > 0")
|
||||
if rate_window <= 0:
|
||||
raise ValueError("rate_window must be > 0")
|
||||
|
||||
self._rate_limit = int(rate_limit)
|
||||
self._rate_window = float(rate_window)
|
||||
self._times: deque[float] = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
while True:
|
||||
wait_time = 0.0
|
||||
async with self._lock:
|
||||
now = time.monotonic()
|
||||
cutoff = now - self._rate_window
|
||||
|
||||
while self._times and self._times[0] <= cutoff:
|
||||
self._times.popleft()
|
||||
|
||||
if len(self._times) < self._rate_limit:
|
||||
self._times.append(now)
|
||||
return
|
||||
|
||||
oldest = self._times[0]
|
||||
wait_time = max(0.0, (oldest + self._rate_window) - now)
|
||||
|
||||
if wait_time > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def __aenter__(self) -> SlidingWindowLimiter:
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
from .safe_diagnostics import format_exception_for_log
|
||||
|
||||
|
||||
class MessagingRateLimiter:
|
||||
|
|
@@ -82,23 +33,29 @@ class MessagingRateLimiter:
        return super().__new__(cls)

    @classmethod
    async def get_instance(cls) -> MessagingRateLimiter:
        """Get the singleton instance of the limiter."""
    async def get_instance(
        cls,
        *,
        rate_limit: int = 1,
        rate_window: float = 1.0,
    ) -> MessagingRateLimiter:
        """Get the singleton instance of the limiter.

        ``rate_limit`` and ``rate_window`` apply only when the singleton is first
        created. Call :meth:`shutdown_instance` before changing parameters.
        """
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
                cls._instance = cls(rate_limit=rate_limit, rate_window=rate_window)
                # Start the background worker (tracked for graceful shutdown).
                cls._instance._start_worker()
            return cls._instance

    def __init__(self):
    def __init__(self, *, rate_limit: int, rate_window: float) -> None:
        # Prevent double initialization in singleton
        if hasattr(self, "_initialized"):
            return

        rate_limit = int(os.getenv("MESSAGING_RATE_LIMIT", "1"))
        rate_window = float(os.getenv("MESSAGING_RATE_WINDOW", "2.0"))

        self.limiter = SlidingWindowLimiter(rate_limit, rate_window)
        # Custom queue state - using deque for O(1) popleft
        self._queue_list: deque[str] = deque()  # Deque of dedup_keys in order
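Because the parameters bind only on first creation, two call sites passing different values silently share one limiter. A sketch of the intended semantics (`shutdown_instance` is referenced in the docstring above; its exact signature is an assumption here):

```python
# First caller wins: rate_limit=1 / rate_window=1.0 are bound here.
limiter = await MessagingRateLimiter.get_instance(rate_limit=1, rate_window=1.0)

# A later call with different values returns the SAME instance, unchanged.
same = await MessagingRateLimiter.get_instance(rate_limit=5, rate_window=2.0)
assert same is limiter

# To rebind, tear the singleton down first (assumed classmethod per the docstring).
await MessagingRateLimiter.shutdown_instance()
limiter = await MessagingRateLimiter.get_instance(rate_limit=5, rate_window=2.0)
```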
@@ -189,14 +146,26 @@ class MessagingRateLimiter:
                            asyncio.get_event_loop().time() + wait_secs
                        )
                    else:
                        d = get_settings().log_messaging_error_details
                        logger.error(
                            f"Error in limiter worker for key {dedup_key}: {type(e).__name__}: {e}"
                            "Error in limiter worker for key {}: {}",
                            dedup_key,
                            format_exception_for_log(e, log_full_message=d),
                        )
            except asyncio.CancelledError:
                break
            except Exception as e:
                d = get_settings().log_messaging_error_details
                if d:
                    logger.error(
                        f"MessagingRateLimiter worker critical error: {e}", exc_info=True
                        "MessagingRateLimiter worker critical error: {}",
                        e,
                        exc_info=True,
                    )
                else:
                    logger.error(
                        "MessagingRateLimiter worker critical error: exc_type={}",
                        type(e).__name__,
                    )
                await asyncio.sleep(1)
@@ -223,7 +192,11 @@ class MessagingRateLimiter:
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.debug(f"MessagingRateLimiter worker shutdown error: {e}")
            d = get_settings().log_messaging_error_details
            logger.debug(
                "MessagingRateLimiter worker shutdown error: {}",
                format_exception_for_log(e, log_full_message=d),
            )
        finally:
            self._worker_task = None
@@ -296,14 +269,29 @@ class MessagingRateLimiter:
                    x in error_msg for x in ["connect", "timeout", "broken"]
                ):
                    wait = 2**attempt
                    d = get_settings().log_messaging_error_details
                    if d:
                        logger.warning(
                            f"Limiter fire_and_forget transient error (attempt {attempt + 1}): {e}. Retrying in {wait}s..."
                            "Limiter fire_and_forget transient error (attempt {}): {}. Retrying in {}s...",
                            attempt + 1,
                            e,
                            wait,
                        )
                    else:
                        logger.warning(
                            "Limiter fire_and_forget transient error (attempt {}): exc_type={}. Retrying in {}s...",
                            attempt + 1,
                            type(e).__name__,
                            wait,
                        )
                    await asyncio.sleep(wait)
                    continue

                d = get_settings().log_messaging_error_details
                logger.error(
                    f"Final error in fire_and_forget for key {dedup_key}: {type(e).__name__}: {e}"
                    "Final error in fire_and_forget for key {}: {}",
                    dedup_key,
                    format_exception_for_log(e, log_full_message=d),
                )
                if not future.done():
                    future.set_exception(e)
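The retry policy above is plain exponential backoff: a transient connect/timeout error waits `2**attempt` seconds before the next try, so attempts 0, 1, 2 wait 1 s, 2 s, 4 s. A standalone sketch of the same schedule (illustrative helper, not from the codebase):

```python
import asyncio


async def retry_transient(op, attempts: int = 3):
    """Retry `op` on transient errors with 1s, 2s, 4s, ... backoff."""
    for attempt in range(attempts):
        try:
            return await op()
        except ConnectionError:
            if attempt == attempts - 1:
                raise  # out of retries: surface the final error
            await asyncio.sleep(2 ** attempt)
```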
messaging/node_event_pipeline.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""CLI event handling for a single queued node (transcript + session + errors)."""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Any

from loguru import logger

from .cli_event_constants import TRANSCRIPT_EVENT_TYPES, get_status_for_event
from .platforms.base import SessionManagerInterface
from .safe_diagnostics import text_len_hint
from .session import SessionStore
from .transcript import TranscriptBuffer
from .trees.queue_manager import MessageState, MessageTree


async def handle_session_info_event(
    event_data: dict[str, Any],
    tree: MessageTree | None,
    node_id: str,
    captured_session_id: str | None,
    temp_session_id: str | None,
    *,
    cli_manager: SessionManagerInterface,
    session_store: SessionStore,
) -> tuple[str | None, str | None]:
    """Handle session_info event; return updated (captured_session_id, temp_session_id)."""
    if event_data.get("type") != "session_info":
        return captured_session_id, temp_session_id

    real_session_id = event_data.get("session_id")
    if not real_session_id or not temp_session_id:
        return captured_session_id, temp_session_id

    await cli_manager.register_real_session_id(temp_session_id, real_session_id)
    if tree and real_session_id:
        await tree.update_state(
            node_id,
            MessageState.IN_PROGRESS,
            session_id=real_session_id,
        )
        session_store.save_tree(tree.root_id, tree.to_dict())

    return real_session_id, None


async def process_parsed_cli_event(
    parsed: dict[str, Any],
    transcript: TranscriptBuffer,
    update_ui: Callable[..., Awaitable[None]],
    last_status: str | None,
    had_transcript_events: bool,
    tree: MessageTree | None,
    node_id: str,
    captured_session_id: str | None,
    *,
    session_store: SessionStore,
    format_status: Callable[..., str],
    propagate_error_to_children: Callable[[str, str, str], Awaitable[None]],
    log_messaging_error_details: bool = False,
) -> tuple[str | None, bool]:
    """Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
    ptype = parsed.get("type") or ""

    if ptype in TRANSCRIPT_EVENT_TYPES:
        transcript.apply(parsed)
        had_transcript_events = True

    status = get_status_for_event(ptype, parsed, format_status)
    if status is not None:
        await update_ui(status)
        last_status = status
    elif ptype == "block_stop":
        await update_ui(last_status, force=True)
    elif ptype == "complete":
        if not had_transcript_events:
            transcript.apply({"type": "text_chunk", "text": "Done."})
        logger.info("HANDLER: Task complete, updating UI")
        await update_ui(format_status("✅", "Complete"), force=True)
        if tree and captured_session_id:
            await tree.update_state(
                node_id,
                MessageState.COMPLETED,
                session_id=captured_session_id,
            )
            session_store.save_tree(tree.root_id, tree.to_dict())
    elif ptype == "error":
        error_msg = parsed.get("message", "Unknown error")
        if log_messaging_error_details:
            logger.error("HANDLER: Error event received: {}", error_msg)
        else:
            em = error_msg if isinstance(error_msg, str) else str(error_msg)
            logger.error(
                "HANDLER: Error event received: message_chars={}",
                text_len_hint(em),
            )
        logger.info("HANDLER: Updating UI with error status")
        await update_ui(format_status("❌", "Error"), force=True)
        if tree:
            await propagate_error_to_children(node_id, error_msg, "Parent task failed")

    return last_status, had_transcript_events
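A minimal sketch of how a handler loop might drive `process_parsed_cli_event`, threading `(last_status, had_transcript_events)` between calls. The stub `transcript`, `update_ui`, `session_store`, and `propagate_error` objects plus the event list are hypothetical; the keyword arguments match the signature added above:

```python
last_status: str | None = None
had_events = False
for parsed in [{"type": "text_chunk", "text": "hi"}, {"type": "complete"}]:
    last_status, had_events = await process_parsed_cli_event(
        parsed,
        transcript,                 # TranscriptBuffer
        update_ui,                  # async callable(status, *, force=False)
        last_status,
        had_events,
        tree=None,                  # no MessageTree -> session persistence skipped
        node_id="node-1",
        captured_session_id=None,
        session_store=session_store,
        format_status=lambda emoji, text: f"{emoji} {text}",
        propagate_error_to_children=propagate_error,
    )
```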
@@ -6,7 +6,6 @@ Implements MessagingPlatform for Discord using discord.py.

import asyncio
import contextlib
import os
import tempfile
from collections.abc import Awaitable, Callable
from pathlib import Path

@@ -14,7 +13,7 @@ from typing import Any, cast

from loguru import logger

from core.anthropic import get_user_facing_error_message
from core.anthropic import format_user_error_preview

from ..models import IncomingMessage
from ..rendering.discord_markdown import format_status_discord

@@ -97,15 +96,18 @@ class DiscordPlatform(MessagingPlatform):
        whisper_device: str = "cpu",
        hf_token: str = "",
        nvidia_nim_api_key: str = "",
        messaging_rate_limit: int = 1,
        messaging_rate_window: float = 1.0,
        log_raw_messaging_content: bool = False,
        log_api_error_tracebacks: bool = False,
    ):
        if not DISCORD_AVAILABLE:
            raise ImportError(
                "discord.py is required. Install with: pip install discord.py"
            )

        self.bot_token = bot_token or os.getenv("DISCORD_BOT_TOKEN")
        raw_channels = allowed_channel_ids or os.getenv("ALLOWED_DISCORD_CHANNELS")
        self.allowed_channel_ids = _parse_allowed_channels(raw_channels)
        self.bot_token = bot_token
        self.allowed_channel_ids = _parse_allowed_channels(allowed_channel_ids)

        if not self.bot_token:
            logger.warning("DISCORD_BOT_TOKEN not set")

@@ -130,6 +132,10 @@ class DiscordPlatform(MessagingPlatform):
        self._voice_note_enabled = voice_note_enabled
        self._whisper_model = whisper_model
        self._whisper_device = whisper_device
        self._messaging_rate_limit = messaging_rate_limit
        self._messaging_rate_window = messaging_rate_window
        self._log_raw_messaging_content = log_raw_messaging_content
        self._log_api_error_tracebacks = log_api_error_tracebacks

    async def _handle_client_message(self, message: Any) -> None:
        """Adapter entry point used by the internal discord client."""

@@ -234,23 +240,40 @@ class DiscordPlatform(MessagingPlatform):
                status_message_id=status_msg_id,
            )

            if self._log_raw_messaging_content:
                logger.info(
                    "DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
                    channel_id,
                    message_id,
                    (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
                    (
                        transcribed[:80] + "..."
                        if len(transcribed) > 80
                        else transcribed
                    ),
                )
            else:
                logger.info(
                    "DISCORD_VOICE: chat_id={} message_id={} transcribed_len={}",
                    channel_id,
                    message_id,
                    len(transcribed),
                )

            await self._message_handler(incoming)
            return True
        except ValueError as e:
            await message.reply(get_user_facing_error_message(e)[:200])
            await message.reply(format_user_error_preview(e))
            return True
        except ImportError as e:
            await message.reply(get_user_facing_error_message(e)[:200])
            await message.reply(format_user_error_preview(e))
            return True
        except Exception as e:
            logger.error(f"Voice transcription failed: {e}")
            if self._log_api_error_tracebacks:
                logger.error("Voice transcription failed: {}", e)
            else:
                logger.error(
                    "Voice transcription failed: exc_type={}", type(e).__name__
                )
            await message.reply(
                "Could not transcribe voice note. Please try again or send text."
            )

@@ -285,8 +308,10 @@ class DiscordPlatform(MessagingPlatform):
            else None
        )

        text_preview = (message.content or "")[:80]
        if len(message.content or "") > 80:
        raw_content = message.content or ""
        if self._log_raw_messaging_content:
            text_preview = raw_content[:80]
            if len(raw_content) > 80:
                text_preview += "..."
            logger.info(
                "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",

@@ -295,6 +320,14 @@ class DiscordPlatform(MessagingPlatform):
                reply_to,
                text_preview,
            )
        else:
            logger.info(
                "DISCORD_MSG: chat_id={} message_id={} reply_to={} text_len={}",
                channel_id,
                message_id,
                reply_to,
                len(raw_content),
            )

        if not self._message_handler:
            return

@@ -313,13 +346,14 @@ class DiscordPlatform(MessagingPlatform):
        try:
            await self._message_handler(incoming)
        except Exception as e:
            logger.error(f"Error handling message: {e}")
            if self._log_api_error_tracebacks:
                logger.error("Error handling message: {}", e)
            else:
                logger.error("Error handling message: exc_type={}", type(e).__name__)
            with contextlib.suppress(Exception):
                await self.send_message(
                    channel_id,
                    format_status_discord(
                        "Error:", get_user_facing_error_message(e)[:200]
                    ),
                    format_status_discord("Error:", format_user_error_preview(e)),
                    reply_to=message_id,
                )

@@ -336,7 +370,10 @@ class DiscordPlatform(MessagingPlatform):

        from ..limiter import MessagingRateLimiter

        self._limiter = await MessagingRateLimiter.get_instance()
        self._limiter = await MessagingRateLimiter.get_instance(
            rate_limit=self._messaging_rate_limit,
            rate_window=self._messaging_rate_window,
        )

        self._start_task = asyncio.create_task(
            self._client.start(self.bot_token),
@@ -28,6 +28,10 @@ class MessagingPlatformOptions:
    whisper_device: str = "cpu"
    hf_token: str = ""
    nvidia_nim_api_key: str = ""
    messaging_rate_limit: int = 1
    messaging_rate_window: float = 1.0
    log_raw_messaging_content: bool = False
    log_api_error_tracebacks: bool = False


def create_messaging_platform(

@@ -64,6 +68,10 @@ def create_messaging_platform(
            whisper_device=opts.whisper_device,
            hf_token=opts.hf_token,
            nvidia_nim_api_key=opts.nvidia_nim_api_key,
            messaging_rate_limit=opts.messaging_rate_limit,
            messaging_rate_window=opts.messaging_rate_window,
            log_raw_messaging_content=opts.log_raw_messaging_content,
            log_api_error_tracebacks=opts.log_api_error_tracebacks,
        )

    if platform_type == "discord":

@@ -82,6 +90,10 @@ def create_messaging_platform(
            whisper_device=opts.whisper_device,
            hf_token=opts.hf_token,
            nvidia_nim_api_key=opts.nvidia_nim_api_key,
            messaging_rate_limit=opts.messaging_rate_limit,
            messaging_rate_window=opts.messaging_rate_window,
            log_raw_messaging_content=opts.log_raw_messaging_content,
            log_api_error_tracebacks=opts.log_api_error_tracebacks,
        )

    logger.warning(
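A sketch of wiring the new options through the factory. The `MessagingPlatformOptions` fields match the dataclass above; the exact positional shape of `create_messaging_platform` beyond the options object is assumed here:

```python
opts = MessagingPlatformOptions(
    messaging_rate_limit=1,             # at most one send...
    messaging_rate_window=1.0,          # ...per 1-second sliding window
    log_raw_messaging_content=False,    # metadata-only message logging
    log_api_error_tracebacks=False,     # exception types only, no str(e)
)
platform = create_messaging_platform("telegram", opts)
```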
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any

from loguru import logger

from core.anthropic import get_user_facing_error_message
from core.anthropic import format_user_error_preview

if TYPE_CHECKING:
    from telegram import Update

@@ -68,14 +68,18 @@ class TelegramPlatform(MessagingPlatform):
        whisper_device: str = "cpu",
        hf_token: str = "",
        nvidia_nim_api_key: str = "",
        messaging_rate_limit: int = 1,
        messaging_rate_window: float = 1.0,
        log_raw_messaging_content: bool = False,
        log_api_error_tracebacks: bool = False,
    ):
        if not TELEGRAM_AVAILABLE:
            raise ImportError(
                "python-telegram-bot is required. Install with: pip install python-telegram-bot"
            )

        self.bot_token = bot_token or os.getenv("TELEGRAM_BOT_TOKEN")
        self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
        self.bot_token = bot_token
        self.allowed_user_id = allowed_user_id

        if not self.bot_token:
            # We don't raise here to allow instantiation for testing/conditional logic,

@@ -97,6 +101,10 @@ class TelegramPlatform(MessagingPlatform):
        self._voice_note_enabled = voice_note_enabled
        self._whisper_model = whisper_model
        self._whisper_device = whisper_device
        self._messaging_rate_limit = messaging_rate_limit
        self._messaging_rate_window = messaging_rate_window
        self._log_raw_messaging_content = log_raw_messaging_content
        self._log_api_error_tracebacks = log_api_error_tracebacks

    async def _register_pending_voice(
        self, chat_id: str, voice_msg_id: str, status_msg_id: str

@@ -172,7 +180,10 @@ class TelegramPlatform(MessagingPlatform):
        # Initialize rate limiter
        from ..limiter import MessagingRateLimiter

        self._limiter = await MessagingRateLimiter.get_instance()
        self._limiter = await MessagingRateLimiter.get_instance(
            rate_limit=self._messaging_rate_limit,
            rate_window=self._messaging_rate_window,
        )

        # Send startup notification
        try:

@@ -187,7 +198,13 @@ class TelegramPlatform(MessagingPlatform):
                startup_text,
            )
        except Exception as e:
            logger.warning(f"Could not send startup message: {e}")
            if self._log_api_error_tracebacks:
                logger.warning("Could not send startup message: {}", e)
            else:
                logger.warning(
                    "Could not send startup message: exc_type={}",
                    type(e).__name__,
                )

        logger.info("Telegram platform started (Bot API)")

@@ -506,8 +523,10 @@ class TelegramPlatform(MessagingPlatform):
            if getattr(update.message, "message_thread_id", None) is not None
            else None
        )
        text_preview = (update.message.text or "")[:80]
        if len(update.message.text or "") > 80:
        raw_text = update.message.text or ""
        if self._log_raw_messaging_content:
            text_preview = raw_text[:80]
            if len(raw_text) > 80:
                text_preview += "..."
            logger.info(
                "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",

@@ -516,6 +535,14 @@ class TelegramPlatform(MessagingPlatform):
                reply_to,
                text_preview,
            )
        else:
            logger.info(
                "TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_len={}",
                chat_id,
                message_id,
                reply_to,
                len(raw_text),
            )

        if not self._message_handler:
            return

@@ -534,11 +561,14 @@ class TelegramPlatform(MessagingPlatform):
        try:
            await self._message_handler(incoming)
        except Exception as e:
            logger.error(f"Error handling message: {e}")
            if self._log_api_error_tracebacks:
                logger.error("Error handling message: {}", e)
            else:
                logger.error("Error handling message: exc_type={}", type(e).__name__)
            with contextlib.suppress(Exception):
                await self.send_message(
                    chat_id,
                    f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}",
                    f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(format_user_error_preview(e))}",
                    reply_to=incoming.message_id,
                    message_thread_id=thread_id,
                    parse_mode="MarkdownV2",

@@ -631,20 +661,37 @@ class TelegramPlatform(MessagingPlatform):
                status_message_id=status_msg_id,
            )

            if self._log_raw_messaging_content:
                logger.info(
                    "TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
                    chat_id,
                    message_id,
                    (transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
                    (
                        transcribed[:80] + "..."
                        if len(transcribed) > 80
                        else transcribed
                    ),
                )
            else:
                logger.info(
                    "TELEGRAM_VOICE: chat_id={} message_id={} transcribed_len={}",
                    chat_id,
                    message_id,
                    len(transcribed),
                )

            await self._message_handler(incoming)
        except ValueError as e:
            await update.message.reply_text(get_user_facing_error_message(e)[:200])
            await update.message.reply_text(format_user_error_preview(e))
        except ImportError as e:
            await update.message.reply_text(get_user_facing_error_message(e)[:200])
            await update.message.reply_text(format_user_error_preview(e))
        except Exception as e:
            logger.error(f"Voice transcription failed: {e}")
            if self._log_api_error_tracebacks:
                logger.error("Voice transcription failed: {}", e)
            else:
                logger.error(
                    "Voice transcription failed: exc_type={}", type(e).__name__
                )
            await update.message.reply_text(
                "Could not transcribe voice note. Please try again or send text."
            )
messaging/safe_diagnostics.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Helpers for redacting user-derived content from log lines."""

from __future__ import annotations


def format_exception_for_log(exc: BaseException, *, log_full_message: bool) -> str:
    """Return exception type and optionally ``str(exc)`` for operator diagnostics."""
    if log_full_message:
        return f"{type(exc).__name__}: {exc}"
    return type(exc).__name__


def text_len_hint(text: str | None) -> int:
    """Length of text for metadata-only logging (0 when missing)."""
    if not text:
        return 0
    return len(text)
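Taken together, these two helpers give every log call site a uniform redaction switch; a quick usage sketch (assuming the repo is importable as a package):

```python
from messaging.safe_diagnostics import format_exception_for_log, text_len_hint

err = ValueError("token=abc123 leaked into a message")
format_exception_for_log(err, log_full_message=False)  # -> "ValueError"
format_exception_for_log(err, log_full_message=True)   # -> "ValueError: token=abc123 ..."

text_len_hint(None)     # -> 0
text_len_hint("hello")  # -> 5: log the length, never the text itself
```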
@@ -5,8 +5,10 @@ Provides persistent storage for mapping platform messages to Claude CLI session
and message trees for conversation continuation.
"""

import contextlib
import json
import os
import tempfile
import threading
from datetime import UTC, datetime
from typing import Any

@@ -22,7 +24,12 @@ class SessionStore:
    Platform-agnostic: works with any messaging platform.
    """

    def __init__(self, storage_path: str = "sessions.json"):
    def __init__(
        self,
        storage_path: str = "sessions.json",
        *,
        message_log_cap: int | None = None,
    ):
        self.storage_path = storage_path
        self._lock = threading.Lock()
        self._trees: dict[str, dict] = {}  # root_id -> tree data

@@ -34,11 +41,7 @@ class SessionStore:
        self._dirty = False
        self._save_timer: threading.Timer | None = None
        self._save_debounce_secs = 0.5
        cap_raw = os.getenv("MAX_MESSAGE_LOG_ENTRIES_PER_CHAT", "").strip()
        try:
            self._message_log_cap: int | None = int(cap_raw) if cap_raw else None
        except ValueError:
            self._message_log_cap = None
        self._message_log_cap: int | None = message_log_cap
        self._load()

    def _make_chat_key(self, platform: str, chat_id: str) -> str:

@@ -104,9 +107,22 @@ class SessionStore:
        }

    def _write_data(self, data: dict) -> None:
        """Write data dict to disk. Must be called WITHOUT holding self._lock."""
        with open(self.storage_path, "w", encoding="utf-8") as f:
        """Atomically write data dict to disk. Must be called WITHOUT holding self._lock."""
        abs_target = os.path.abspath(self.storage_path)
        dir_name = os.path.dirname(abs_target) or "."
        fd, tmp_path = tempfile.mkstemp(
            dir=dir_name, prefix=".sessions.", suffix=".tmp.json"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, abs_target)
        except BaseException:
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)
            raise

    def _schedule_save(self) -> None:
        """Schedule a debounced save. Caller must hold self._lock."""
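The rewritten `_write_data` is the standard write-temp-then-rename pattern: `os.replace` swaps the file atomically on both POSIX and Windows (within one filesystem), so readers always see either the old or the new `sessions.json`, never a torn half-write. A minimal standalone sketch of the same pattern (names here are illustrative, not from the codebase):

```python
import contextlib
import json
import os
import tempfile


def atomic_json_dump(data: dict, target: str) -> None:
    """Write JSON so `target` is always old-or-new, never partial."""
    target = os.path.abspath(target)
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(target) or ".")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f)
            f.flush()
            os.fsync(f.fileno())  # data reaches disk before the rename
        os.replace(tmp, target)   # atomic swap into place
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp)        # clean up the orphaned temp file
        raise
```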
@@ -9,7 +9,6 @@ the transcript grows over time and older content must be truncated.
from __future__ import annotations

import json
import os
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Callable, Iterable

@@ -210,7 +209,12 @@ class RenderCtx:
class TranscriptBuffer:
    """Maintains an ordered, truncatable transcript of events."""

    def __init__(self, *, show_tool_results: bool = True) -> None:
    def __init__(
        self,
        *,
        show_tool_results: bool = True,
        debug_subagent_stack: bool = False,
    ) -> None:
        self._segments: list[Segment] = []
        self._open_thinking_by_index: dict[int, ThinkingSegment] = {}
        self._open_text_by_index: dict[int, TextSegment] = {}

@@ -227,7 +231,7 @@ class TranscriptBuffer:
        self._subagent_stack: list[str] = []
        # Parallel stack of segments for rendering nested subagents.
        self._subagent_segments: list[SubagentSegment] = []
        self._debug_subagent_stack = os.getenv("DEBUG_SUBAGENT_STACK") == "1"
        self._debug_subagent_stack = debug_subagent_stack

    def _in_subagent(self) -> bool:
        return bool(self._subagent_stack)
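With the env read removed, the debug flag now flows in from configuration at construction time. A sketch of the expected call site; the exact settings attribute name is an assumption mirroring the `DEBUG_SUBAGENT_STACK` key in `.env.example`:

```python
from config.settings import get_settings
from messaging.transcript import TranscriptBuffer

settings = get_settings()
transcript = TranscriptBuffer(
    show_tool_results=True,
    # Assumed settings attribute corresponding to DEBUG_SUBAGENT_STACK.
    debug_subagent_stack=settings.debug_subagent_stack,
)
```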
@@ -5,32 +5,18 @@ Supports:
- NVIDIA NIM: NVIDIA NIM Whisper/Parakeet
"""

import os
from pathlib import Path
from typing import Any

from loguru import logger

from config.settings import get_settings
from providers.nvidia_nim.voice import (
    transcribe_audio_file as transcribe_nvidia_nim_audio,
)

# Max file size in bytes (25 MB)
MAX_AUDIO_SIZE_BYTES = 25 * 1024 * 1024

# NVIDIA NIM Whisper model mapping: (function_id, language_code)
_NIM_MODEL_MAP: dict[str, tuple[str, str]] = {
    "nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"),
    "nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"),
    "nvidia/parakeet-ctc-0.6b-es": ("None", "es-US"),
    "nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"),
    "nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"),
    "nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"),
    "nvidia/parakeet-1.1b-rnnt-multilingual-asr": (
        "71203149-d3b7-4460-8231-1be2543a1fca",
        "",
    ),
    "openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"),
}

# Short model names -> full Hugging Face model IDs (for local Whisper)
_MODEL_MAP: dict[str, str] = {
    "tiny": "openai/whisper-tiny",

@@ -51,22 +37,19 @@ def _resolve_model_id(whisper_model: str) -> str:
    return _MODEL_MAP.get(whisper_model, whisper_model)


def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
def _get_pipeline(model_id: str, device: str, hf_token: str = "") -> Any:
    """Lazy-load transformers Whisper pipeline. Raises ImportError if not installed."""
    global _pipeline_cache
    if device not in ("cpu", "cuda"):
        raise ValueError(f"whisper_device must be 'cpu' or 'cuda', got {device!r}")
    resolved_token = (
        hf_token if hf_token is not None else get_settings().hf_token
    ) or ""
    resolved_token = hf_token or ""
    cache_key = (model_id, device, resolved_token)
    if cache_key not in _pipeline_cache:
        try:
            import torch
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

            if resolved_token:
                os.environ["HF_TOKEN"] = resolved_token
            hf_auth_token = resolved_token or None

            use_cuda = device == "cuda" and torch.cuda.is_available()
            pipe_device = "cuda:0" if use_cuda else "cpu"

@@ -77,9 +60,10 @@ def _get_pipeline(model_id: str, device: str, hf_token: str | None = None) -> Any:
                dtype=model_dtype,
                low_cpu_mem_usage=True,
                attn_implementation="sdpa",
                token=hf_auth_token,
            )
            model = model.to(pipe_device)
            processor = AutoProcessor.from_pretrained(model_id)
            processor = AutoProcessor.from_pretrained(model_id, token=hf_auth_token)

            pipe = pipeline(
                "automatic-speech-recognition",

@@ -119,7 +103,7 @@ def transcribe_audio(
        file_path: Path to audio file (OGG, MP3, MP4, WAV, M4A supported)
        mime_type: MIME type of the audio (e.g. "audio/ogg")
        whisper_model: Model ID or short name (local) or NVIDIA NIM model
        whisper_device: "cpu" | "cuda" | "nvidia_nim" (defaults to WHISPER_DEVICE env var)
        whisper_device: "cpu" | "cuda" | "nvidia_nim"

    Returns:
        Transcribed text

@@ -140,8 +124,8 @@ def transcribe_audio(
        )

    if whisper_device == "nvidia_nim":
        return _transcribe_nim(
            file_path, whisper_model, nvidia_nim_api_key=nvidia_nim_api_key
        return transcribe_nvidia_nim_audio(
            file_path, whisper_model, api_key=nvidia_nim_api_key
        )
    return _transcribe_local(
        file_path, whisper_model, whisper_device, hf_token=hf_token

@@ -169,8 +153,7 @@ def _transcribe_local(
) -> str:
    """Transcribe using transformers Whisper pipeline."""
    model_id = _resolve_model_id(whisper_model)
    token: str | None = hf_token if hf_token else None
    pipe = _get_pipeline(model_id, whisper_device, hf_token=token)
    pipe = _get_pipeline(model_id, whisper_device, hf_token=hf_token)
    audio = _load_audio(file_path)
    result = pipe(audio, generate_kwargs={"language": "en", "task": "transcribe"})
    text = result.get("text", "") or ""

@@ -179,65 +162,3 @@ def _transcribe_local(
    result_text = text.strip()
    logger.debug(f"Local transcription: {len(result_text)} chars")
    return result_text or "(no speech detected)"


def _transcribe_nim(
    file_path: Path, model: str, *, nvidia_nim_api_key: str = ""
) -> str:
    """Transcribe using NVIDIA NIM Whisper API via Riva gRPC client."""
    try:
        import riva.client
    except ImportError as e:
        raise ImportError(
            "NVIDIA NIM transcription requires the voice extra. "
            "Install with: uv sync --extra voice"
        ) from e

    api_key = nvidia_nim_api_key or get_settings().nvidia_nim_api_key

    # Look up function ID and language code from model mapping
    model_config = _NIM_MODEL_MAP.get(model)
    if not model_config:
        raise ValueError(
            f"No NVIDIA NIM config found for model: {model}. "
            f"Supported models: {', '.join(_NIM_MODEL_MAP.keys())}"
        )
    function_id, language_code = model_config

    # Riva server configuration
    server = "grpc.nvcf.nvidia.com:443"

    # Auth with SSL and metadata
    auth = riva.client.Auth(
        use_ssl=True,
        uri=server,
        metadata_args=[
            ["function-id", function_id],
            ["authorization", f"Bearer {api_key}"],
        ],
    )

    asr_service = riva.client.ASRService(auth)

    # Configure recognition - language_code from model config
    config = riva.client.RecognitionConfig(
        language_code=language_code,
        max_alternatives=1,
        verbatim_transcripts=True,
    )

    # Read audio file
    with open(file_path, "rb") as f:
        data = f.read()

    # Perform offline recognition
    response = asr_service.offline_recognize(data, config)

    # Extract text from response - use getattr for safe attribute access
    transcript = ""
    results = getattr(response, "results", None)
    if results and results[0].alternatives:
        transcript = results[0].alternatives[0].transcript

    logger.debug(f"NIM transcription: {len(transcript)} chars")
    return transcript or "(no speech detected)"
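After this change the NIM path lives in `providers.nvidia_nim.voice` and this module only routes. A sketch of calling the routed entry point; the keyword names beyond those visible in the hunks and docstring are assumptions:

```python
from pathlib import Path

text = transcribe_audio(
    Path("voice.ogg"),
    mime_type="audio/ogg",
    whisper_model="openai/whisper-large-v3",
    whisper_device="nvidia_nim",     # delegates to providers.nvidia_nim.voice
    nvidia_nim_api_key="nvapi-...",  # placeholder credential, not a real key
)
```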
@@ -1,165 +0,0 @@
"""Async queue processor for message trees.

Handles the async processing lifecycle of tree nodes.
"""

import asyncio
from collections.abc import Awaitable, Callable

from loguru import logger

from core.anthropic import get_user_facing_error_message

from .data import MessageNode, MessageState, MessageTree


class TreeQueueProcessor:
    """
    Handles async queue processing for a single tree.

    Separates the async processing logic from the data management.
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ):
        self._queue_update_callback = queue_update_callback
        self._node_started_callback = node_started_callback

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used to refresh queue positions."""
        self._queue_update_callback = queue_update_callback

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used when a queued node starts processing."""
        self._node_started_callback = node_started_callback

    async def _notify_queue_updated(self, tree: MessageTree) -> None:
        """Invoke queue update callback if set."""
        if not self._queue_update_callback:
            return
        try:
            await self._queue_update_callback(tree)
        except Exception as e:
            logger.warning(f"Queue update callback failed: {e}")

    async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
        """Invoke node started callback if set."""
        if not self._node_started_callback:
            return
        try:
            await self._node_started_callback(tree, node_id)
        except Exception as e:
            logger.warning(f"Node started callback failed: {e}")

    async def process_node(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process a single node and then check the queue."""
        # Skip if already in terminal state (e.g. from error propagation)
        if node.state == MessageState.ERROR:
            logger.info(
                f"Skipping node {node.node_id} as it is already in state {node.state}"
            )
            # Still need to check for next messages
            await self._process_next(tree, processor)
            return

        try:
            await processor(node.node_id, node)
        except asyncio.CancelledError:
            logger.info(f"Task for node {node.node_id} was cancelled")
            raise
        except Exception as e:
            logger.error(f"Error processing node {node.node_id}: {e}")
            await tree.update_state(
                node.node_id,
                MessageState.ERROR,
                error_message=get_user_facing_error_message(e),
            )
        finally:
            async with tree.with_lock():
                tree.clear_current_node()
            # Check if there are more messages in the queue
            await self._process_next(tree, processor)

    async def _process_next(
        self,
        tree: MessageTree,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process the next message in queue, if any."""
        next_node_id = None
        node = None
        async with tree.with_lock():
            next_node_id = await tree.dequeue()

            if not next_node_id:
                tree.set_processing_state(None, False)
                logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
                return

            tree.set_processing_state(next_node_id, True)
            logger.info(f"Processing next queued node {next_node_id}")

        # Process next node (outside lock)
        node = tree.get_node(next_node_id)
        if node:
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )

        # Notify that this node has started processing and refresh queue positions.
        if next_node_id:
            await self._notify_node_started(tree, next_node_id)
            await self._notify_queue_updated(tree)

    async def enqueue_and_start(
        self,
        tree: MessageTree,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node or start processing immediately.

        Args:
            tree: The message tree
            node_id: Node to process
            processor: Async function to process the node

        Returns:
            True if queued, False if processing immediately
        """
        async with tree.with_lock():
            if tree.is_processing:
                tree.put_queue_unlocked(node_id)
                queue_size = tree.get_queue_size()
                logger.info(f"Queued node {node_id}, position {queue_size}")
                return True
            else:
                tree.set_processing_state(node_id, True)

        # Process outside the lock
        node = tree.get_node(node_id)
        if node:
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )
        return False

    def cancel_current(self, tree: MessageTree) -> bool:
        """Cancel the currently running task in a tree."""
        return tree.cancel_current_task()
@@ -1,30 +1,345 @@
"""Tree-Based Message Queue Manager - Refactored.

Coordinates data access, async processing, and error handling.
Uses TreeRepository for data, TreeQueueProcessor for async logic.
"""
"""Tree-based message queue: index, async node processor, and public manager API."""

import asyncio
from collections.abc import Awaitable, Callable

from loguru import logger

from config.settings import get_settings
from core.anthropic import get_user_facing_error_message

from ..models import IncomingMessage
from ..safe_diagnostics import format_exception_for_log
from .data import MessageNode, MessageState, MessageTree
from .processor import TreeQueueProcessor
from .repository import TreeRepository


class TreeRepository:
    """
    In-memory index of trees and node-to-root mappings.

    Used only by :class:`TreeQueueManager`; kept as a named type for tests.
    """

    def __init__(self) -> None:
        self._trees: dict[str, MessageTree] = {}  # root_id -> tree
        self._node_to_tree: dict[str, str] = {}  # node_id -> root_id

    def get_tree(self, root_id: str) -> MessageTree | None:
        """Get a tree by its root ID."""
        return self._trees.get(root_id)

    def get_tree_for_node(self, node_id: str) -> MessageTree | None:
        """Get the tree containing a given node."""
        root_id = self._node_to_tree.get(node_id)
        if not root_id:
            return None
        return self._trees.get(root_id)

    def get_node(self, node_id: str) -> MessageNode | None:
        """Get a node from any tree."""
        tree = self.get_tree_for_node(node_id)
        return tree.get_node(node_id) if tree else None

    def add_tree(self, root_id: str, tree: MessageTree) -> None:
        """Add a new tree to the repository."""
        self._trees[root_id] = tree
        self._node_to_tree[root_id] = root_id
        logger.debug("TREE_REPO: add_tree root_id={}", root_id)

    def register_node(self, node_id: str, root_id: str) -> None:
        """Register a node ID to a tree."""
        self._node_to_tree[node_id] = root_id
        logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)

    def has_node(self, node_id: str) -> bool:
        """Check if a node is registered in any tree."""
        return node_id in self._node_to_tree

    def tree_count(self) -> int:
        """Get the number of trees in the repository."""
        return len(self._trees)

    def is_tree_busy(self, root_id: str) -> bool:
        """Check if a tree is currently processing."""
        tree = self._trees.get(root_id)
        return tree.is_processing if tree else False

    def is_node_tree_busy(self, node_id: str) -> bool:
        """Check if the tree containing a node is busy."""
        tree = self.get_tree_for_node(node_id)
        return tree.is_processing if tree else False

    def get_queue_size(self, node_id: str) -> int:
        """Get queue size for the tree containing a node."""
        tree = self.get_tree_for_node(node_id)
        return tree.get_queue_size() if tree else 0

    def resolve_parent_node_id(self, msg_id: str) -> str | None:
        """
        Resolve a message ID to the actual parent node ID.

        Handles the case where msg_id is a status message ID
        (which maps to the tree but isn't an actual node).

        Returns:
            The node_id to use as parent, or None if not found
        """
        tree = self.get_tree_for_node(msg_id)
        if not tree:
            return None

        if tree.has_node(msg_id):
            return msg_id

        node = tree.find_node_by_status_message(msg_id)
        if node:
            return node.node_id

        return None

    def get_pending_children(self, node_id: str) -> list[MessageNode]:
        """
        Get all pending child nodes (recursively) of a given node.

        Used for error propagation - when a node fails, its pending
        children should also be marked as failed.
        """
        tree = self.get_tree_for_node(node_id)
        if not tree:
            return []

        pending: list[MessageNode] = []
        stack = [node_id]

        while stack:
            current_id = stack.pop()
            node = tree.get_node(current_id)
            if not node:
                continue
            for child_id in node.children_ids:
                child = tree.get_node(child_id)
                if child and child.state == MessageState.PENDING:
                    pending.append(child)
                stack.append(child_id)

        return pending

    def all_trees(self) -> list[MessageTree]:
        """Get all trees in the repository."""
        return list(self._trees.values())

    def tree_ids(self) -> list[str]:
        """Get all tree root IDs."""
        return list(self._trees.keys())

    def unregister_nodes(self, node_ids: list[str]) -> None:
        """Remove node IDs from the node-to-tree mapping."""
        for nid in node_ids:
            self._node_to_tree.pop(nid, None)

    def remove_tree(self, root_id: str) -> MessageTree | None:
        """
        Remove a tree and all its node mappings from the repository.

        Returns:
            The removed tree, or None if not found.
        """
        tree = self._trees.pop(root_id, None)
        if not tree:
            return None
        for node in tree.all_nodes():
            self._node_to_tree.pop(node.node_id, None)
        logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
        return tree

    def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
        """Get all message IDs (incoming + status) for a given platform/chat."""
        msg_ids: set[str] = set()
        for tree in self._trees.values():
            for node in tree.all_nodes():
                if str(node.incoming.platform) == str(platform) and str(
                    node.incoming.chat_id
                ) == str(chat_id):
                    if node.incoming.message_id is not None:
                        msg_ids.add(str(node.incoming.message_id))
                    if node.status_message_id:
                        msg_ids.add(str(node.status_message_id))
        return msg_ids

    def to_dict(self) -> dict:
        """Serialize all trees."""
        return {
            "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
            "node_to_tree": self._node_to_tree.copy(),
        }

    @classmethod
    def from_dict(cls, data: dict) -> TreeRepository:
        """Deserialize from dictionary."""
        repo = cls()
        for root_id, tree_data in data.get("trees", {}).items():
            repo._trees[root_id] = MessageTree.from_dict(tree_data)
        repo._node_to_tree = data.get("node_to_tree", {})
        return repo


class TreeQueueProcessor:
    """
    Per-tree async queue processing (one manager owns one processor instance).
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ) -> None:
        self._queue_update_callback = queue_update_callback
        self._node_started_callback = node_started_callback

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used to refresh queue positions."""
        self._queue_update_callback = queue_update_callback

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used when a queued node starts processing."""
        self._node_started_callback = node_started_callback

    async def _notify_queue_updated(self, tree: MessageTree) -> None:
        """Invoke queue update callback if set."""
        if not self._queue_update_callback:
            return
        try:
            await self._queue_update_callback(tree)
        except Exception as e:
            d = get_settings().log_messaging_error_details
            logger.warning(
                "Queue update callback failed: {}",
                format_exception_for_log(e, log_full_message=d),
            )

    async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
        """Invoke node started callback if set."""
        if not self._node_started_callback:
            return
        try:
            await self._node_started_callback(tree, node_id)
        except Exception as e:
            d = get_settings().log_messaging_error_details
            logger.warning(
                "Node started callback failed: {}",
                format_exception_for_log(e, log_full_message=d),
            )

    async def process_node(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process a single node and then check the queue."""
        if node.state == MessageState.ERROR:
            logger.info(
                f"Skipping node {node.node_id} as it is already in state {node.state}"
            )
            await self._process_next(tree, processor)
            return

        try:
            await processor(node.node_id, node)
        except asyncio.CancelledError:
            logger.info(f"Task for node {node.node_id} was cancelled")
            raise
        except Exception as e:
            d = get_settings().log_messaging_error_details
            logger.error(
                "Error processing node {}: {}",
                node.node_id,
                format_exception_for_log(e, log_full_message=d),
            )
            await tree.update_state(
                node.node_id,
                MessageState.ERROR,
                error_message=get_user_facing_error_message(e),
            )
        finally:
            async with tree.with_lock():
                tree.clear_current_node()
            await self._process_next(tree, processor)

    async def _process_next(
        self,
        tree: MessageTree,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process the next message in queue, if any."""
        next_node_id = None
        async with tree.with_lock():
            next_node_id = await tree.dequeue()

            if not next_node_id:
                tree.set_processing_state(None, False)
                logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
                return

            tree.set_processing_state(next_node_id, True)
            logger.info(f"Processing next queued node {next_node_id}")

        node = tree.get_node(next_node_id)
        if node:
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )

        if next_node_id:
            await self._notify_node_started(tree, next_node_id)
            await self._notify_queue_updated(tree)

    async def enqueue_and_start(
        self,
        tree: MessageTree,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node or start processing immediately.

        Returns:
            True if queued, False if processing immediately
        """
        async with tree.with_lock():
            if tree.is_processing:
                tree.put_queue_unlocked(node_id)
                queue_size = tree.get_queue_size()
                logger.info(f"Queued node {node_id}, position {queue_size}")
                return True
            else:
                tree.set_processing_state(node_id, True)

        node = tree.get_node(node_id)
        if node:
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )
        return False

    def cancel_current(self, tree: MessageTree) -> bool:
        """Cancel the currently running task in a tree."""
        return tree.cancel_current_task()


class TreeQueueManager:
    """
    Manages multiple message trees. Facade that coordinates components.
    Manages multiple message trees: index + async processing.

    Each new conversation creates a new tree.
    Replies to existing messages add nodes to existing trees.

    Components:
    - TreeRepository: Data access layer
    - TreeQueueProcessor: Async queue processing
    """

    def __init__(

@@ -33,7 +348,7 @@ class TreeQueueManager:
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
        _repository: TreeRepository | None = None,
    ):
    ) -> None:
        self._repository = _repository or TreeRepository()
        self._processor = TreeQueueProcessor(
            queue_update_callback=queue_update_callback,

@@ -101,7 +416,6 @@ class TreeQueueManager:
        if not tree:
            raise ValueError(f"Parent node {parent_node_id} not found in any tree")

        # Add node (tree has its own lock) - outside manager lock to avoid deadlock
        node = await tree.add_node(
            node_id=node_id,
            incoming=incoming,

@@ -228,7 +542,6 @@ class TreeQueueManager:

        cleanup_count = 0
        async with tree.with_lock():
            # 1. Cancel running task
            if tree.cancel_current_task():
                current_id = tree.current_node_id
                if current_id:

@@ -240,12 +553,10 @@ class TreeQueueManager:
                    tree.set_node_error_sync(node, "Cancelled by user")
                    cancelled_nodes.append(node)

            # 2. Drain queue and mark nodes as cancelled
            queue_nodes = tree.drain_queue_and_mark_cancelled()
            cancelled_nodes.extend(queue_nodes)
            cancelled_ids = {n.node_id for n in cancelled_nodes}

            # 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
            for node in tree.all_nodes():
                if (
                    node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)

@@ -269,10 +580,6 @@ class TreeQueueManager:
        """
        Cancel a single node (queued or in-progress) without affecting other nodes.

        - If the node is currently running, cancels the current asyncio task.
        - If the node is queued, removes it from the queue.
        - Marks the node as ERROR with "Cancelled by user".

        Returns:
            List containing the cancelled node if it was cancellable, else empty list.
        """

@@ -351,8 +658,6 @@ class TreeQueueManager:
    async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]:
        """
        Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).

        Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
        """
        tree = self._repository.get_tree_for_node(branch_root_id)
        if not tree:

@@ -435,3 +740,10 @@ class TreeQueueManager:
            node_started_callback=node_started_callback,
            _repository=TreeRepository.from_dict(data),
        )


__all__ = [
    "TreeQueueManager",
    "TreeQueueProcessor",
    "TreeRepository",
]
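A sketch of the repository's lookup path after consolidation. The stub `tree` stands in for a real `MessageTree` (its constructor is not shown in this diff); the repository calls match the API above:

```python
# Hypothetical: `tree` is a MessageTree rooted at "root-1".
repo = TreeRepository()
repo.add_tree("root-1", tree)           # the root id maps to itself
repo.register_node("msg-42", "root-1")  # replies index into the same tree

assert repo.has_node("msg-42")
assert repo.get_tree_for_node("msg-42") is tree

# Status-message ids resolve back to their owning node, actual node
# ids resolve to themselves, and unknown ids return None.
parent_id = repo.resolve_parent_node_id("msg-42")
```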
@ -1,186 +0,0 @@
|
|||
"""Repository for message tree data access.
|
||||
|
||||
Provides data access layer for managing trees and node mappings.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .data import MessageNode, MessageState, MessageTree
|
||||
|
||||
|
||||
class TreeRepository:
|
||||
"""
|
||||
Repository for message tree data access.
|
||||
|
||||
Manages the storage and lookup of trees and node-to-tree mappings.
|
||||
"""
|
||||
|
    def __init__(self):
        self._trees: dict[str, MessageTree] = {}  # root_id -> tree
        self._node_to_tree: dict[str, str] = {}  # node_id -> root_id

    def get_tree(self, root_id: str) -> MessageTree | None:
        """Get a tree by its root ID."""
        return self._trees.get(root_id)

    def get_tree_for_node(self, node_id: str) -> MessageTree | None:
        """Get the tree containing a given node."""
        root_id = self._node_to_tree.get(node_id)
        if not root_id:
            return None
        return self._trees.get(root_id)

    def get_node(self, node_id: str) -> MessageNode | None:
        """Get a node from any tree."""
        tree = self.get_tree_for_node(node_id)
        return tree.get_node(node_id) if tree else None

    def add_tree(self, root_id: str, tree: MessageTree) -> None:
        """Add a new tree to the repository."""
        self._trees[root_id] = tree
        self._node_to_tree[root_id] = root_id
        logger.debug("TREE_REPO: add_tree root_id={}", root_id)

    def register_node(self, node_id: str, root_id: str) -> None:
        """Register a node ID to a tree."""
        self._node_to_tree[node_id] = root_id
        logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)

    def has_node(self, node_id: str) -> bool:
        """Check if a node is registered in any tree."""
        return node_id in self._node_to_tree

    def tree_count(self) -> int:
        """Get the number of trees in the repository."""
        return len(self._trees)

    def is_tree_busy(self, root_id: str) -> bool:
        """Check if a tree is currently processing."""
        tree = self._trees.get(root_id)
        return tree.is_processing if tree else False

    def is_node_tree_busy(self, node_id: str) -> bool:
        """Check if the tree containing a node is busy."""
        tree = self.get_tree_for_node(node_id)
        return tree.is_processing if tree else False

    def get_queue_size(self, node_id: str) -> int:
        """Get queue size for the tree containing a node."""
        tree = self.get_tree_for_node(node_id)
        return tree.get_queue_size() if tree else 0

    def resolve_parent_node_id(self, msg_id: str) -> str | None:
        """
        Resolve a message ID to the actual parent node ID.

        Handles the case where msg_id is a status message ID
        (which maps to the tree but isn't an actual node).

        Returns:
            The node_id to use as parent, or None if not found
        """
        tree = self.get_tree_for_node(msg_id)
        if not tree:
            return None

        # Check if msg_id is an actual node
        if tree.has_node(msg_id):
            return msg_id

        # Otherwise, it might be a status message - find the owning node
        node = tree.find_node_by_status_message(msg_id)
        if node:
            return node.node_id

        return None

    def get_pending_children(self, node_id: str) -> list[MessageNode]:
        """
        Get all pending child nodes (recursively) of a given node.

        Used for error propagation - when a node fails, its pending
        children should also be marked as failed.
        """
        tree = self.get_tree_for_node(node_id)
        if not tree:
            return []

        pending: list[MessageNode] = []
        stack = [node_id]

        while stack:
            current_id = stack.pop()
            node = tree.get_node(current_id)
            if not node:
                continue
            for child_id in node.children_ids:
                child = tree.get_node(child_id)
                if child and child.state == MessageState.PENDING:
                    pending.append(child)
                    stack.append(child_id)

        return pending

    def all_trees(self) -> list[MessageTree]:
        """Get all trees in the repository."""
        return list(self._trees.values())

    def tree_ids(self) -> list[str]:
        """Get all tree root IDs."""
        return list(self._trees.keys())

    def unregister_nodes(self, node_ids: list[str]) -> None:
        """Remove node IDs from the node-to-tree mapping."""
        for nid in node_ids:
            self._node_to_tree.pop(nid, None)

    def remove_tree(self, root_id: str) -> MessageTree | None:
        """
        Remove a tree and all its node mappings from the repository.

        Returns:
            The removed tree, or None if not found.
        """
        tree = self._trees.pop(root_id, None)
        if not tree:
            return None
        for node in tree.all_nodes():
            self._node_to_tree.pop(node.node_id, None)
        logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
        return tree

    def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
        """Get all message IDs (incoming + status) for a given platform/chat.

        Note: O(total_nodes) scan. Acceptable because this is only called
        from /clear (user-initiated, infrequent).
        """
        msg_ids: set[str] = set()
        for tree in self._trees.values():
            for node in tree.all_nodes():
                if str(node.incoming.platform) == str(platform) and str(
                    node.incoming.chat_id
                ) == str(chat_id):
                    if node.incoming.message_id is not None:
                        msg_ids.add(str(node.incoming.message_id))
                    if node.status_message_id:
                        msg_ids.add(str(node.status_message_id))
        return msg_ids

    def to_dict(self) -> dict:
        """Serialize all trees."""
        return {
            "trees": {rid: tree.to_dict() for rid, tree in self._trees.items()},
            "node_to_tree": self._node_to_tree.copy(),
        }

    @classmethod
    def from_dict(cls, data: dict) -> TreeRepository:
        """Deserialize from dictionary."""
        from .data import MessageTree

        repo = cls()
        for root_id, tree_data in data.get("trees", {}).items():
            repo._trees[root_id] = MessageTree.from_dict(tree_data)
        repo._node_to_tree = data.get("node_to_tree", {})
        return repo
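
Because to_dict() captures both the trees and the node_id -> root_id index, a restored repository resolves nodes exactly like the original. A minimal round-trip sketch of that persistence contract; the import path is an assumption inferred from this diff, not verified against the repo:

# Hypothetical usage sketch; the module path is assumed for illustration.
from messaging.tree_repository import TreeRepository

def snapshot_and_restore(repo: TreeRepository) -> TreeRepository:
    data = repo.to_dict()               # plain dict, safe to persist as JSON
    restored = TreeRepository.from_dict(data)
    assert restored.tree_count() == repo.tree_count()
    assert restored.tree_ids() == repo.tree_ids()
    return restored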
messaging/ui_updates.py (new file, 101 lines)
@@ -0,0 +1,101 @@
"""Throttled platform UI updates driven by transcript rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .platforms.base import MessagingPlatform
|
||||
from .safe_diagnostics import format_exception_for_log
|
||||
from .transcript import RenderCtx, TranscriptBuffer
|
||||
|
||||
|
||||
class ThrottledTranscriptEditor:
|
||||
"""Rate-limited status message edits from a growing transcript."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
platform: MessagingPlatform,
|
||||
parse_mode: str | None,
|
||||
get_limit_chars: Callable[[], int],
|
||||
transcript: TranscriptBuffer,
|
||||
render_ctx: RenderCtx,
|
||||
node_id: str,
|
||||
chat_id: str,
|
||||
status_msg_id: str,
|
||||
debug_platform_edits: bool,
|
||||
log_messaging_error_details: bool = False,
|
||||
) -> None:
|
||||
self._platform = platform
|
||||
self._parse_mode = parse_mode
|
||||
self._get_limit_chars = get_limit_chars
|
||||
self._transcript = transcript
|
||||
self._render_ctx = render_ctx
|
||||
self._node_id = node_id
|
||||
self._chat_id = chat_id
|
||||
self._status_msg_id = status_msg_id
|
||||
self._debug_platform_edits = debug_platform_edits
|
||||
self._log_messaging_error_details = log_messaging_error_details
|
||||
self._last_ui_update = 0.0
|
||||
self._last_displayed_text: str | None = None
|
||||
self._last_status: str | None = None
|
||||
|
||||
@property
|
||||
def last_status(self) -> str | None:
|
||||
return self._last_status
|
||||
|
||||
async def update(self, status: str | None = None, *, force: bool = False) -> None:
|
||||
"""Render transcript + optional status line and edit the platform message."""
|
||||
now = time.time()
|
||||
if not force and now - self._last_ui_update < 1.0:
|
||||
return
|
||||
|
||||
self._last_ui_update = now
|
||||
if status is not None:
|
||||
self._last_status = status
|
||||
try:
|
||||
display = self._transcript.render(
|
||||
self._render_ctx,
|
||||
limit_chars=self._get_limit_chars(),
|
||||
status=status,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Transcript render failed for node {}: {}",
|
||||
self._node_id,
|
||||
format_exception_for_log(
|
||||
e, log_full_message=self._log_messaging_error_details
|
||||
),
|
||||
)
|
||||
return
|
||||
if display and display != self._last_displayed_text:
|
||||
logger.debug(
|
||||
"PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
|
||||
self._node_id,
|
||||
self._chat_id,
|
||||
self._status_msg_id,
|
||||
bool(force),
|
||||
status,
|
||||
len(display),
|
||||
)
|
||||
if self._debug_platform_edits:
|
||||
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
|
||||
self._last_displayed_text = display
|
||||
try:
|
||||
await self._platform.queue_edit_message(
|
||||
self._chat_id,
|
||||
self._status_msg_id,
|
||||
display,
|
||||
parse_mode=self._parse_mode,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to update platform for node {}: {}",
|
||||
self._node_id,
|
||||
format_exception_for_log(
|
||||
e, log_full_message=self._log_messaging_error_details
|
||||
),
|
||||
)
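
The 1-second throttle means rapid transcript growth coalesces into at most one platform edit per second unless force=True. A minimal driving loop, as a sketch only: the event stream of (status, done) pairs is an invented stand-in for real transcript growth notifications, not an API from this repo:

async def drive_status_edits(editor: ThrottledTranscriptEditor, events) -> None:
    async for status, done in events:
        # Intermediate updates are throttled to one edit per second;
        # the final update is forced so the last state always lands.
        await editor.update(status, force=done)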

providers/anthropic_messages.py
@@ -2,19 +2,32 @@
from __future__ import annotations

import json
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal

import httpx
from loguru import logger

from core.anthropic import get_user_facing_error_message
from config.constants import (
    ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
    NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES,
)
from core.anthropic import iter_provider_stream_error_sse_events
from core.anthropic.emitted_sse_tracker import EmittedNativeSseTracker
from core.anthropic.native_messages_request import (
    build_base_native_anthropic_request_body,
)
from core.anthropic.native_sse_block_policy import (
    NativeSseBlockPolicyState,
    transform_native_sse_block_event,
)
from providers.base import BaseProvider, ProviderConfig
from providers.error_mapping import map_error
from providers.error_mapping import (
    map_error,
    user_visible_message_for_mapped_provider_error,
)
from providers.rate_limit import GlobalRateLimiter

ANTHROPIC_DEFAULT_MAX_TOKENS = 81920
StreamChunkMode = Literal["line", "event"]
@@ -64,25 +77,11 @@ class AnthropicMessagesTransport(BaseProvider):
    ) -> dict:
        """Build a native Anthropic request body."""
        thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
        body = request.model_dump(exclude_none=True)

        body.pop("extra_body", None)
        body.pop("original_model", None)
        body.pop("resolved_provider_model", None)

        if "thinking" in body:
            thinking_cfg = body.pop("thinking")
            if thinking_enabled and isinstance(thinking_cfg, dict):
                thinking_payload = {"type": "enabled"}
                budget_tokens = thinking_cfg.get("budget_tokens")
                if isinstance(budget_tokens, int):
                    thinking_payload["budget_tokens"] = budget_tokens
                body["thinking"] = thinking_payload

        if "max_tokens" not in body:
            body["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS

        return body
        return build_base_native_anthropic_request_body(
            request,
            default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
            thinking_enabled=thinking_enabled,
        )

    async def _send_stream_request(self, body: dict) -> httpx.Response:
        """Create a streaming messages response."""
@@ -97,30 +96,68 @@ class AnthropicMessagesTransport(BaseProvider):
    async def _raise_for_status(
        self, response: httpx.Response, *, req_tag: str
    ) -> None:
        """Raise for non-200 responses after logging the upstream body."""
        """Raise for non-200 responses after logging safe metadata (or capped body if opted in)."""
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as error:
            response_text = await self._read_error_body(response)
            if response_text:
            if self._config.log_api_error_tracebacks:
                preview, truncated = await self._read_error_body_preview(
                    response, NATIVE_MESSAGES_ERROR_BODY_LOG_CAP_BYTES
                )
                if preview:
                    text = preview.decode("utf-8", errors="replace")
                    logger.error(
                        "{}_ERROR:{} HTTP {}: {}",
                        "{}_ERROR:{} HTTP {} body_preview_bytes={} truncated={}: {}",
                        self._provider_name,
                        req_tag,
                        response.status_code,
                        response_text,
                        len(preview),
                        truncated,
                        text,
                    )
                else:
                    logger.error(
                        "{}_ERROR:{} HTTP {} (empty error body)",
                        self._provider_name,
                        req_tag,
                        response.status_code,
                    )
            else:
                cl = response.headers.get("content-length", "").strip()
                extra = f" content_length_declared={cl}" if cl.isdigit() else ""
                logger.error(
                    "{}_ERROR:{} HTTP {}{}",
                    self._provider_name,
                    req_tag,
                    response.status_code,
                    extra,
                )
            raise error

    async def _read_error_body(self, response: httpx.Response) -> str:
        """Read a response body for diagnostics."""
        aread = getattr(response, "aread", None)
        if aread is None:
            return ""
        body = await aread()
        if isinstance(body, bytes):
            return body.decode("utf-8", errors="replace")
        return str(body)

    async def _read_error_body_preview(
        self, response: httpx.Response, max_bytes: int
    ) -> tuple[bytes, bool]:
        """Read at most ``max_bytes`` from the error body for logging. Returns (preview, truncated)."""
        if max_bytes <= 0:
            return b"", False
        received = 0
        parts: list[bytes] = []
        truncated = False
        async for chunk in response.aiter_bytes(chunk_size=65_536):
            if received >= max_bytes:
                truncated = True
                break
            remaining = max_bytes - received
            take = chunk if len(chunk) <= remaining else chunk[:remaining]
            if take:
                parts.append(take)
                received += len(take)
            if len(chunk) > len(take):
                truncated = True
                break
            if received >= max_bytes:
                break
        return (b"".join(parts), truncated)

    async def _iter_sse_lines(self, response: httpx.Response) -> AsyncIterator[str]:
        """Yield raw SSE line chunks preserving local provider behavior."""

@@ -145,6 +182,8 @@ class AnthropicMessagesTransport(BaseProvider):

    def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
        """Return per-stream provider state for event transformation."""
        if self.stream_chunk_mode == "line":
            return NativeSseBlockPolicyState()
        return None

    def _transform_stream_event(

@@ -155,6 +194,10 @@ class AnthropicMessagesTransport(BaseProvider):
        thinking_enabled: bool,
    ) -> str | None:
        """Transform or drop a grouped SSE event before yielding it downstream."""
        if isinstance(state, NativeSseBlockPolicyState):
            return transform_native_sse_block_event(
                event, state, thinking_enabled=thinking_enabled
            )
        return event

    def _format_error_message(self, base_message: str, request_id: str | None) -> str:
@@ -166,14 +209,10 @@ class AnthropicMessagesTransport(BaseProvider):
    def _get_error_message(self, error: Exception, request_id: str | None) -> str:
        """Map an exception into a user-facing provider error message."""
        mapped_error = map_error(error, rate_limiter=self._global_rate_limiter)
        if getattr(mapped_error, "status_code", None) == 405:
            base_message = (
                f"Upstream provider {self._provider_name} rejected the request method "
                "or endpoint (HTTP 405)."
            )
        else:
            base_message = get_user_facing_error_message(
                mapped_error, read_timeout_s=self._config.http_read_timeout
        base_message = user_visible_message_for_mapped_provider_error(
            mapped_error,
            provider_name=self._provider_name,
            read_timeout_s=self._config.http_read_timeout,
        )
        return self._format_error_message(base_message, request_id)
@@ -185,12 +224,14 @@ class AnthropicMessagesTransport(BaseProvider):
        error_message: str,
        sent_any_event: bool,
    ) -> Iterator[str]:
        """Emit a native Anthropic error event."""
        error_event = {
            "type": "error",
            "error": {"type": "api_error", "message": error_message},
        }
        yield f"event: error\ndata: {json.dumps(error_event)}\n\n"
        """Emit the same Anthropic message lifecycle used by OpenAI-compat providers."""
        yield from iter_provider_stream_error_sse_events(
            request=request,
            input_tokens=input_tokens,
            error_message=error_message,
            sent_any_event=sent_any_event,
            log_raw_sse_events=self._config.log_raw_sse_events,
        )

    async def _iter_stream_chunks(
        self,
@@ -200,6 +241,21 @@ class AnthropicMessagesTransport(BaseProvider):
        thinking_enabled: bool,
    ) -> AsyncIterator[str]:
        """Yield stream chunks according to the provider's observable chunk shape."""
        if self.stream_chunk_mode == "line" and isinstance(
            state, NativeSseBlockPolicyState
        ):
            async for event in self._iter_sse_events(response):
                output_event = self._transform_stream_event(
                    event,
                    state,
                    thinking_enabled=thinking_enabled,
                )
                if output_event is None:
                    continue
                for line in output_event.splitlines(keepends=True):
                    yield line
            return

        if self.stream_chunk_mode == "line":
            async for chunk in self._iter_sse_lines(response):
                yield chunk
@@ -240,15 +296,28 @@ class AnthropicMessagesTransport(BaseProvider):
        response: httpx.Response | None = None
        sent_any_event = False
        state = self._new_stream_state(request, thinking_enabled=thinking_enabled)
        emitted_tracker = EmittedNativeSseTracker()

        async with self._global_rate_limiter.concurrency_slot():
            try:
                response = await self._global_rate_limiter.execute_with_retry(
                    self._send_stream_request, body
                )

                if response.status_code != 200:
                    await self._raise_for_status(response, req_tag=req_tag)
                async def _validated_stream_send() -> httpx.Response:
                    """Send request; raise inside retry loop on 429 so rate limiter can backoff."""
                    send_response = await self._send_stream_request(body)
                    if send_response.status_code == 429:
                        await send_response.aclose()
                        send_response.raise_for_status()
                    if send_response.status_code != 200:
                        try:
                            await self._raise_for_status(send_response, req_tag=req_tag)
                        finally:
                            if not send_response.is_closed:
                                await send_response.aclose()
                    return send_response

                response = await self._global_rate_limiter.execute_with_retry(
                    _validated_stream_send
                )

                async for chunk in self._iter_stream_chunks(
                    response,
@@ -256,12 +325,12 @@ class AnthropicMessagesTransport(BaseProvider):
                    thinking_enabled=thinking_enabled,
                ):
                    sent_any_event = True
                    emitted_tracker.feed(chunk)
                    yield chunk

            except Exception as error:
                logger.error(
                    "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
                )
                if not isinstance(error, httpx.HTTPStatusError):
                    self._log_stream_transport_error(tag, req_tag, error)
                error_message = self._get_error_message(error, request_id)

                if response is not None and not response.is_closed:
@@ -273,11 +342,22 @@ class AnthropicMessagesTransport(BaseProvider):
                        type(error).__name__,
                        req_tag,
                    )
                if sent_any_event:
                    for event in emitted_tracker.iter_close_unclosed_blocks():
                        yield event
                    for event in emitted_tracker.iter_midstream_error_tail(
                        error_message,
                        request=request,
                        input_tokens=input_tokens,
                        log_raw_sse_events=self._config.log_raw_sse_events,
                    ):
                        yield event
                else:
                    for event in self._emit_error_events(
                        request=request,
                        input_tokens=input_tokens,
                        error_message=error_message,
                        sent_any_event=sent_any_event,
                        sent_any_event=False,
                    ):
                        yield event
                return
providers/base.py
@@ -6,6 +6,8 @@ from typing import Any

from pydantic import BaseModel

from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT


class ProviderConfig(BaseModel):
    """Configuration for a provider.

@@ -21,9 +23,11 @@ class ProviderConfig(BaseModel):
    max_concurrency: int = 5
    http_read_timeout: float = 300.0
    http_write_timeout: float = 10.0
    http_connect_timeout: float = 2.0
    http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT
    enable_thinking: bool = True
    proxy: str = ""
    log_raw_sse_events: bool = False
    log_api_error_tracebacks: bool = False


class BaseProvider(ABC):

@@ -61,6 +65,42 @@ class BaseProvider(ABC):
        request_enabled = bool(enabled)
        return config_enabled and request_enabled

    def preflight_stream(
        self, request: Any, *, thinking_enabled: bool | None = None
    ) -> None:
        """Eagerly validate/build the upstream request before opening an SSE stream.

        Subclasses with ``_build_request_body`` (OpenAI and native) raise
        :class:`providers.exceptions.InvalidRequestError` on conversion failures.
        """
        build = getattr(self, "_build_request_body", None)
        if build is None:
            return
        build(request, thinking_enabled=thinking_enabled)

    def _log_stream_transport_error(
        self, tag: str, req_tag: str, error: Exception
    ) -> None:
        """Log streaming transport failures (metadata-only unless verbose is enabled)."""
        from loguru import logger

        if self._config.log_api_error_tracebacks:
            logger.error(
                "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
            )
            return
        response = getattr(error, "response", None)
        status_code = (
            getattr(response, "status_code", None) if response is not None else None
        )
        logger.error(
            "{}_ERROR:{} exc_type={} http_status={}",
            tag,
            req_tag,
            type(error).__name__,
            status_code,
        )

    @abstractmethod
    async def cleanup(self) -> None:
        """Release any resources held by this provider."""

@@ -75,5 +115,7 @@ class BaseProvider(ABC):
        thinking_enabled: bool | None = None,
    ) -> AsyncIterator[str]:
        """Stream response in Anthropic SSE format."""
        # Typing: abstract async generators need a yield for AsyncIterator[str]
        # inference; this branch is never executed.
        if False:
            yield ""  # Required for ty/mypy to accept abstract async generator
            yield ""
providers/deepseek/__init__.py
@@ -1,5 +1,7 @@
"""DeepSeek provider exports."""

from .client import DEEPSEEK_BASE_URL, DeepSeekProvider
from providers.defaults import DEEPSEEK_DEFAULT_BASE

__all__ = ["DEEPSEEK_BASE_URL", "DeepSeekProvider"]
from .client import DeepSeekProvider

__all__ = ["DEEPSEEK_DEFAULT_BASE", "DeepSeekProvider"]
providers/deepseek/client.py
@@ -3,7 +3,7 @@
from typing import Any

from providers.base import ProviderConfig
from providers.defaults import DEEPSEEK_BASE_URL
from providers.defaults import DEEPSEEK_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport

from .request import build_request_body

@@ -16,7 +16,7 @@ class DeepSeekProvider(OpenAIChatTransport):
        super().__init__(
            config,
            provider_name="DEEPSEEK",
            base_url=config.base_url or DEEPSEEK_BASE_URL,
            base_url=config.base_url or DEEPSEEK_DEFAULT_BASE,
            api_key=config.api_key,
        )
providers/deepseek/request.py
@@ -5,6 +5,8 @@ from typing import Any
from loguru import logger

from core.anthropic import build_base_request_body
from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError


def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:

@@ -14,10 +16,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
        getattr(request_data, "model", "?"),
        len(getattr(request_data, "messages", [])),
    )
    try:
        body = build_base_request_body(
            request_data,
            include_reasoning_content=True,
            include_thinking=thinking_enabled,
            include_reasoning_content=thinking_enabled,
        )
    except OpenAIConversionError as exc:
        raise InvalidRequestError(str(exc)) from exc

    extra_body: dict[str, Any] = {}
    request_extra = getattr(request_data, "extra_body", None)
providers/defaults.py
@@ -1,21 +1,19 @@
"""Default upstream base URLs and shared provider constants.
"""Re-exports default upstream base URLs from the config provider catalog."""

Adapters and :mod:`providers.registry` import from here to avoid duplicating
literals and to keep ``providers.registry`` free of per-adapter eager imports.
"""
from config.provider_catalog import (
    DEEPSEEK_DEFAULT_BASE,
    LLAMACPP_DEFAULT_BASE,
    LMSTUDIO_DEFAULT_BASE,
    NVIDIA_NIM_DEFAULT_BASE,
    OLLAMA_DEFAULT_BASE,
    OPENROUTER_DEFAULT_BASE,
)

# OpenAI-compatible chat (NIM, DeepSeek) and local/native provider endpoints
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
DEEPSEEK_DEFAULT_BASE = "https://api.deepseek.com"
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
OLLAMA_DEFAULT_BASE = "http://localhost:11434"

# Backward-compatible names used by existing adapter modules
NVIDIA_NIM_BASE_URL = NVIDIA_NIM_DEFAULT_BASE
DEEPSEEK_BASE_URL = DEEPSEEK_DEFAULT_BASE
OPENROUTER_BASE_URL = OPENROUTER_DEFAULT_BASE
LMSTUDIO_DEFAULT_BASE_URL = LMSTUDIO_DEFAULT_BASE
LLAMACPP_DEFAULT_BASE_URL = LLAMACPP_DEFAULT_BASE
OLLAMA_DEFAULT_BASE_URL = OLLAMA_DEFAULT_BASE
__all__ = (
    "DEEPSEEK_DEFAULT_BASE",
    "LLAMACPP_DEFAULT_BASE",
    "LMSTUDIO_DEFAULT_BASE",
    "NVIDIA_NIM_DEFAULT_BASE",
    "OLLAMA_DEFAULT_BASE",
    "OPENROUTER_DEFAULT_BASE",
)
providers/error_mapping.py
@@ -14,10 +14,30 @@ from providers.exceptions import (
from providers.rate_limit import GlobalRateLimiter


def user_visible_message_for_mapped_provider_error(
    mapped: Exception,
    *,
    provider_name: str,
    read_timeout_s: float | None,
) -> str:
    """Return the user-visible string after :func:`map_error` (405 + mapped types)."""
    if getattr(mapped, "status_code", None) == 405:
        return (
            f"Upstream provider {provider_name} rejected the request method "
            "or endpoint (HTTP 405)."
        )
    return get_user_facing_error_message(mapped, read_timeout_s=read_timeout_s)


def map_error(
    e: Exception, *, rate_limiter: GlobalRateLimiter | None = None
) -> Exception:
    """Map OpenAI or HTTPX exception to specific ProviderError."""
    """Map OpenAI or HTTPX exception to specific ProviderError.

    Streaming transports should pass their scoped limiter (``self._global_rate_limiter``)
    so reactive 429 handling applies to the correct provider. Tests may omit
    ``rate_limiter`` to use the process-wide singleton.
    """
    message = get_user_facing_error_message(e)
    limiter = rate_limiter or GlobalRateLimiter.get_instance()
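
The new helper centralizes the 405 special case that previously lived in each transport. A quick check of its branch logic; the stub exception is invented here, since real callers pass the result of map_error():

from providers.error_mapping import user_visible_message_for_mapped_provider_error

class MappedStub(Exception):
    # A bare exception with a status_code attribute suffices for the 405 branch.
    status_code = 405

msg = user_visible_message_for_mapped_provider_error(
    MappedStub(), provider_name="NIM", read_timeout_s=300.0
)
assert "HTTP 405" in msg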

providers/exceptions.py
@@ -90,8 +90,20 @@ class APIError(ProviderError):
    )


class UnknownProviderTypeError(ValueError):
class UnknownProviderTypeError(InvalidRequestError):
    """Raised when ``provider_id`` is not registered in the provider map."""

    def __init__(self, message: str) -> None:
        super().__init__(message)


class ServiceUnavailableError(ProviderError):
    """Raised when the server is not ready (e.g. app lifespan did not wire state)."""

    def __init__(self, message: str, raw_error: Any = None):
        super().__init__(
            message,
            status_code=503,
            error_type="api_error",
            raw_error=raw_error,
        )
@@ -2,7 +2,7 @@

from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import LLAMACPP_DEFAULT_BASE_URL
from providers.defaults import LLAMACPP_DEFAULT_BASE


class LlamaCppProvider(AnthropicMessagesTransport):

@@ -12,5 +12,5 @@ class LlamaCppProvider(AnthropicMessagesTransport):
        super().__init__(
            config,
            provider_name="LLAMACPP",
            default_base_url=LLAMACPP_DEFAULT_BASE_URL,
            default_base_url=LLAMACPP_DEFAULT_BASE,
        )
@@ -2,7 +2,7 @@

from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import LMSTUDIO_DEFAULT_BASE_URL
from providers.defaults import LMSTUDIO_DEFAULT_BASE


class LMStudioProvider(AnthropicMessagesTransport):

@@ -12,5 +12,5 @@ class LMStudioProvider(AnthropicMessagesTransport):
        super().__init__(
            config,
            provider_name="LMSTUDIO",
            default_base_url=LMSTUDIO_DEFAULT_BASE_URL,
            default_base_url=LMSTUDIO_DEFAULT_BASE,
        )
providers/nvidia_nim/__init__.py
@@ -1,5 +1,7 @@
"""NVIDIA NIM provider package."""

from .client import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.defaults import NVIDIA_NIM_DEFAULT_BASE

__all__ = ["NVIDIA_NIM_BASE_URL", "NvidiaNimProvider"]
from .client import NvidiaNimProvider

__all__ = ["NVIDIA_NIM_DEFAULT_BASE", "NvidiaNimProvider"]
providers/nvidia_nim/client.py
@@ -8,7 +8,7 @@ from loguru import logger

from config.nim import NimSettings
from providers.base import ProviderConfig
from providers.defaults import NVIDIA_NIM_BASE_URL
from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport

from .request import (

@@ -25,7 +25,7 @@ class NvidiaNimProvider(OpenAIChatTransport):
        super().__init__(
            config,
            provider_name="NIM",
            base_url=config.base_url or NVIDIA_NIM_BASE_URL,
            base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE,
            api_key=config.api_key,
        )
        self._nim_settings = nim_settings
providers/nvidia_nim/request.py
@@ -1,5 +1,6 @@
"""Request builder for NVIDIA NIM provider."""

from collections.abc import Callable
from copy import deepcopy
from typing import Any

@@ -7,6 +8,42 @@ from loguru import logger

from config.nim import NimSettings
from core.anthropic import build_base_request_body, set_if_not_none
from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError


def _clone_strip_extra_body(
    body: dict[str, Any],
    strip: Callable[[dict[str, Any]], bool],
) -> dict[str, Any] | None:
    """Deep-clone ``body`` and remove fields via ``strip`` on ``extra_body`` only.

    Returns ``None`` when there is no ``extra_body`` dict or ``strip`` reports no change.
    """
    cloned_body = deepcopy(body)
    extra_body = cloned_body.get("extra_body")
    if not isinstance(extra_body, dict):
        return None
    if not strip(extra_body):
        return None
    if not extra_body:
        cloned_body.pop("extra_body", None)
    return cloned_body


def _strip_reasoning_budget_fields(extra_body: dict[str, Any]) -> bool:
    removed = extra_body.pop("reasoning_budget", None) is not None
    chat_template_kwargs = extra_body.get("chat_template_kwargs")
    if (
        isinstance(chat_template_kwargs, dict)
        and chat_template_kwargs.pop("reasoning_budget", None) is not None
    ):
        removed = True
    return removed


def _strip_chat_template_field(extra_body: dict[str, Any]) -> bool:
    return extra_body.pop("chat_template", None) is not None


def _set_extra(

@@ -23,43 +60,12 @@ def _set_extra(

def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None:
    """Clone a request body and strip only reasoning_budget fields."""
    cloned_body = deepcopy(body)
    extra_body = cloned_body.get("extra_body")
    if not isinstance(extra_body, dict):
        return None

    removed = extra_body.pop("reasoning_budget", None) is not None

    chat_template_kwargs = extra_body.get("chat_template_kwargs")
    if (
        isinstance(chat_template_kwargs, dict)
        and chat_template_kwargs.pop("reasoning_budget", None) is not None
    ):
        removed = True

    if not extra_body:
        cloned_body.pop("extra_body", None)

    if not removed:
        return None

    return cloned_body
    return _clone_strip_extra_body(body, _strip_reasoning_budget_fields)


def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None:
    """Clone a request body and strip only chat_template."""
    cloned_body = deepcopy(body)
    extra_body = cloned_body.get("extra_body")
    if not isinstance(extra_body, dict):
        return None

    if extra_body.pop("chat_template", None) is None:
        return None

    if not extra_body:
        cloned_body.pop("extra_body", None)

    return cloned_body
    return _clone_strip_extra_body(body, _strip_chat_template_field)


def build_request_body(

@@ -71,10 +77,13 @@ def build_request_body(
        getattr(request_data, "model", "?"),
        len(getattr(request_data, "messages", [])),
    )
    try:
        body = build_base_request_body(
            request_data,
            include_thinking=thinking_enabled,
        )
    except OpenAIConversionError as exc:
        raise InvalidRequestError(str(exc)) from exc

    # NIM-specific max_tokens: cap against nim.max_tokens
    max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None)
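
The clone-and-strip helpers return None to mean "nothing removed, reuse the original body", which lets retry paths skip a needless re-send. A standalone check of that contract against the functions above (the sample body is invented):

from providers.nvidia_nim.request import clone_body_without_reasoning_budget

body = {
    "model": "m",
    "extra_body": {
        "reasoning_budget": 4096,
        "chat_template_kwargs": {"reasoning_budget": 4096, "thinking": True},
    },
}
stripped = clone_body_without_reasoning_budget(body)
assert stripped is not None                              # something was removed
assert "reasoning_budget" not in stripped["extra_body"]
assert body["extra_body"]["reasoning_budget"] == 4096    # original untouched (deepcopy)
assert clone_body_without_reasoning_budget({"model": "m"}) is None  # no-op -> None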

providers/nvidia_nim/voice.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""NVIDIA NIM / Riva offline ASR for voice notes (provider-owned transport)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# NVIDIA NIM Whisper model mapping: (function_id, language_code)
|
||||
_NIM_ASR_MODEL_MAP: dict[str, tuple[str, str]] = {
|
||||
"nvidia/parakeet-ctc-0.6b-zh-tw": ("8473f56d-51ef-473c-bb26-efd4f5def2bf", "zh-TW"),
|
||||
"nvidia/parakeet-ctc-0.6b-zh-cn": ("9add5ef7-322e-47e0-ad7a-5653fb8d259b", "zh-CN"),
|
||||
# function-id from NVIDIA NIM API docs (parakeet-ctc-0.6b-es).
|
||||
"nvidia/parakeet-ctc-0.6b-es": ("a9eeee8f-b509-4712-b19d-194361fa5f31", "es-US"),
|
||||
"nvidia/parakeet-ctc-0.6b-vi": ("f3dff2bb-99f9-403d-a5f1-f574a757deb0", "vi-VN"),
|
||||
"nvidia/parakeet-ctc-1.1b-asr": ("1598d209-5e27-4d3c-8079-4751568b1081", "en-US"),
|
||||
"nvidia/parakeet-ctc-0.6b-asr": ("d8dd4e9b-fbf5-4fb0-9dba-8cf436c8d965", "en-US"),
|
||||
"nvidia/parakeet-1.1b-rnnt-multilingual-asr": (
|
||||
"71203149-d3b7-4460-8231-1be2543a1fca",
|
||||
"",
|
||||
),
|
||||
"openai/whisper-large-v3": ("b702f636-f60c-4a3d-a6f4-f3568c13bd7d", "multi"),
|
||||
}
|
||||
|
||||
_RIVA_SERVER = "grpc.nvcf.nvidia.com:443"
|
||||
|
||||
|
||||
def transcribe_audio_file(
|
||||
file_path: Path,
|
||||
model: str,
|
||||
*,
|
||||
api_key: str,
|
||||
) -> str:
|
||||
"""Transcribe audio using NVIDIA NIM / Riva gRPC (offline recognition).
|
||||
|
||||
Args:
|
||||
file_path: Path to encoded audio bytes readable by Riva.
|
||||
model: Hugging Face-style NIM model id (see ``_NIM_ASR_MODEL_MAP``).
|
||||
api_key: NVIDIA API key (Bearer token); must be non-empty.
|
||||
|
||||
Returns:
|
||||
Transcript text, or ``(no speech detected)`` when empty.
|
||||
"""
|
||||
key = (api_key or "").strip()
|
||||
if not key:
|
||||
raise ValueError(
|
||||
"NVIDIA NIM transcription requires a non-empty nvidia_nim_api_key "
|
||||
"(configure NVIDIA_NIM_API_KEY or pass api_key explicitly)."
|
||||
)
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"NVIDIA NIM transcription requires the voice extra. "
|
||||
"Install with: uv sync --extra voice"
|
||||
) from e
|
||||
|
||||
model_config = _NIM_ASR_MODEL_MAP.get(model)
|
||||
if not model_config:
|
||||
raise ValueError(
|
||||
f"No NVIDIA NIM config found for model: {model}. "
|
||||
f"Supported models: {', '.join(_NIM_ASR_MODEL_MAP.keys())}"
|
||||
)
|
||||
function_id, language_code = model_config
|
||||
|
||||
auth = riva.client.Auth(
|
||||
use_ssl=True,
|
||||
uri=_RIVA_SERVER,
|
||||
metadata_args=[
|
||||
["function-id", function_id],
|
||||
["authorization", f"Bearer {key}"],
|
||||
],
|
||||
)
|
||||
|
||||
asr_service = riva.client.ASRService(auth)
|
||||
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=language_code,
|
||||
max_alternatives=1,
|
||||
verbatim_transcripts=True,
|
||||
)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
data = f.read()
|
||||
|
||||
response = asr_service.offline_recognize(data, config)
|
||||
|
||||
transcript = ""
|
||||
results = getattr(response, "results", None)
|
||||
if results and results[0].alternatives:
|
||||
transcript = results[0].alternatives[0].transcript
|
||||
|
||||
logger.debug(f"NIM transcription: {len(transcript)} chars")
|
||||
return transcript or "(no speech detected)"
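
Typical call shape, assuming the voice extra is installed and a real key is in the environment; the audio file name is illustrative:

import os
from pathlib import Path

from providers.nvidia_nim.voice import transcribe_audio_file

text = transcribe_audio_file(
    Path("note.ogg"),  # hypothetical voice-note file
    "openai/whisper-large-v3",
    api_key=os.environ["NVIDIA_NIM_API_KEY"],
)
print(text)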

providers/ollama/__init__.py
@@ -1,5 +1,7 @@
"""Ollama provider package."""

from .client import OLLAMA_BASE_URL, OllamaProvider
from providers.defaults import OLLAMA_DEFAULT_BASE

__all__ = ["OLLAMA_BASE_URL", "OllamaProvider"]
from .client import OllamaProvider

__all__ = ["OLLAMA_DEFAULT_BASE", "OllamaProvider"]
providers/ollama/client.py
@@ -6,8 +6,6 @@ from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import OLLAMA_DEFAULT_BASE

OLLAMA_BASE_URL = OLLAMA_DEFAULT_BASE


class OllamaProvider(AnthropicMessagesTransport):
    """Ollama provider using native Anthropic Messages API."""

@@ -16,7 +14,7 @@ class OllamaProvider(AnthropicMessagesTransport):
        super().__init__(
            config,
            provider_name="OLLAMA",
            default_base_url=OLLAMA_BASE_URL,
            default_base_url=OLLAMA_DEFAULT_BASE,
        )
        self._api_key = config.api_key or "ollama"
providers/openrouter/__init__.py
@@ -1,5 +1,7 @@
"""OpenRouter provider - Anthropic-compatible native transport."""

from .client import OPENROUTER_BASE_URL, OpenRouterProvider
from providers.defaults import OPENROUTER_DEFAULT_BASE

__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"]
from .client import OpenRouterProvider

__all__ = ["OPENROUTER_DEFAULT_BASE", "OpenRouterProvider"]
providers/openrouter/client.py
@@ -2,34 +2,25 @@

from __future__ import annotations

import json
import uuid
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any

from core.anthropic import SSEBuilder, append_request_id
from core.anthropic import append_request_id, iter_provider_stream_error_sse_events
from core.anthropic.native_sse_block_policy import (
    NativeSseBlockPolicyState,
    is_terminal_openrouter_done_event,
    parse_native_sse_event,
    transform_native_sse_block_event,
)
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
from providers.base import ProviderConfig
from providers.defaults import OPENROUTER_BASE_URL
from providers.defaults import OPENROUTER_DEFAULT_BASE

from .request import build_request_body

_ANTHROPIC_VERSION = "2023-06-01"


@dataclass
class _SSEFilterState:
    """Track Anthropic content block index remapping while filtering thinking."""

    next_index: int = 0
    index_map: dict[int, int] = field(default_factory=dict)
    dropped_indexes: set[int] = field(default_factory=set)
    open_block_types: dict[int, str] = field(default_factory=dict)
    closed_indexes: set[int] = field(default_factory=set)
    message_stopped: bool = False


class OpenRouterProvider(AnthropicMessagesTransport):
    """OpenRouter provider using the native Anthropic-compatible messages API."""

@@ -39,7 +30,7 @@ class OpenRouterProvider(AnthropicMessagesTransport):
        super().__init__(
            config,
            provider_name="OPENROUTER",
            default_base_url=OPENROUTER_BASE_URL,
            default_base_url=OPENROUTER_DEFAULT_BASE,
        )

    def _build_request_body(

@@ -60,163 +51,9 @@ class OpenRouterProvider(AnthropicMessagesTransport):
            "anthropic-version": _ANTHROPIC_VERSION,
        }

    @staticmethod
    def _format_sse_event(event_name: str | None, data_text: str) -> str:
        """Format an SSE event from its event name and data payload."""
        lines: list[str] = []
        if event_name:
            lines.append(f"event: {event_name}")
        lines.extend(f"data: {line}" for line in data_text.splitlines())
        return "\n".join(lines) + "\n\n"

    @staticmethod
    def _parse_sse_event(event: str) -> tuple[str | None, str]:
        """Extract the event name and raw data payload from an SSE event."""
        event_name = None
        data_lines: list[str] = []
        for line in event.strip().splitlines():
            if line.startswith("event:"):
                event_name = line[6:].strip()
            elif line.startswith("data:"):
                data_lines.append(line[5:].lstrip())
        return event_name, "\n".join(data_lines)

    @staticmethod
    def _is_terminal_done_event(event_name: str | None, data_text: str) -> bool:
        """Return whether an event is OpenAI-style terminal noise."""
        return (event_name is None or event_name in {"data", "done"}) and (
            data_text.strip().upper() == "[DONE]"
        )

    @staticmethod
    def _remap_index(
        payload: dict[str, Any], state: _SSEFilterState, *, create: bool
    ) -> int | None:
        """Return the downstream index for a content block event."""
        upstream_index = payload.get("index")
        if not isinstance(upstream_index, int):
            return None
        if upstream_index in state.dropped_indexes:
            return None
        mapped_index = state.index_map.get(upstream_index)
        if mapped_index is None and create:
            mapped_index = state.next_index
            state.index_map[upstream_index] = mapped_index
            state.next_index += 1
        return mapped_index

    def _close_open_blocks_before(
        self, state: _SSEFilterState, upstream_index: int
    ) -> str:
        """Close overlapping upstream blocks before starting a new block."""
        events: list[str] = []
        for open_upstream_index in list(state.open_block_types):
            if open_upstream_index == upstream_index:
                continue
            mapped_index = state.index_map.get(open_upstream_index)
            if mapped_index is None:
                continue
            payload = {"type": "content_block_stop", "index": mapped_index}
            events.append(
                self._format_sse_event("content_block_stop", json.dumps(payload))
            )
            state.closed_indexes.add(open_upstream_index)
            state.open_block_types.pop(open_upstream_index, None)
        return "".join(events)

    @staticmethod
    def _should_drop_block_type(block_type: Any, *, thinking_enabled: bool) -> bool:
        if not isinstance(block_type, str):
            return False
        if block_type.startswith("redacted_thinking"):
            return True
        return not thinking_enabled and "thinking" in block_type

    def _transform_sse_payload(
        self,
        event: str,
        state: _SSEFilterState,
        *,
        thinking_enabled: bool,
    ) -> str | None:
        """Normalize OpenRouter SSE events and enforce local thinking policy."""
        event_name, data_text = self._parse_sse_event(event)
        if not event_name or not data_text:
            return event

        try:
            payload = json.loads(data_text)
        except json.JSONDecodeError:
            return event

        if event_name == "content_block_start":
            block = payload.get("content_block")
            if not isinstance(block, dict):
                return event
            block_type = block.get("type")
            upstream_index = payload.get("index")
            if self._should_drop_block_type(
                block_type, thinking_enabled=thinking_enabled
            ):
                if isinstance(upstream_index, int):
                    state.dropped_indexes.add(upstream_index)
                return None

            mapped_index = self._remap_index(payload, state, create=True)
            if mapped_index is not None:
                payload["index"] = mapped_index
                if isinstance(upstream_index, int) and isinstance(block_type, str):
                    prefix = self._close_open_blocks_before(state, upstream_index)
                    state.open_block_types[upstream_index] = block_type
                    return prefix + self._format_sse_event(
                        event_name, json.dumps(payload)
                    )
                return self._format_sse_event(event_name, json.dumps(payload))
            return None if not thinking_enabled else event

        if event_name == "content_block_delta":
            delta = payload.get("delta")
            if not isinstance(delta, dict):
                return event
            delta_type = delta.get("type")
            if self._should_drop_block_type(
                delta_type, thinking_enabled=thinking_enabled
            ):
                return None

            mapped_index = self._remap_index(payload, state, create=False)
            if mapped_index is not None:
                payload["index"] = mapped_index
                return self._format_sse_event(event_name, json.dumps(payload))
            if payload.get("index") in state.dropped_indexes:
                return None
            if not thinking_enabled:
                return None

        if event_name == "content_block_stop":
            upstream_index = payload.get("index")
            if (
                isinstance(upstream_index, int)
                and upstream_index in state.closed_indexes
            ):
                state.closed_indexes.discard(upstream_index)
                return None
            mapped_index = self._remap_index(payload, state, create=False)
            if mapped_index is not None:
                payload["index"] = mapped_index
                if isinstance(upstream_index, int):
                    state.open_block_types.pop(upstream_index, None)
                return self._format_sse_event(event_name, json.dumps(payload))
            if payload.get("index") in state.dropped_indexes:
                return None
            if not thinking_enabled:
                return None

        return event

    def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
        """Create per-stream state for thinking block filtering."""
        return _SSEFilterState()
        return NativeSseBlockPolicyState()

    def _transform_stream_event(
        self,

@@ -226,23 +63,17 @@ class OpenRouterProvider(AnthropicMessagesTransport):
        thinking_enabled: bool,
    ) -> str | None:
        """Drop provider-specific terminal noise and hidden thinking events."""
        if isinstance(state, _SSEFilterState):
            event_name, data_text = self._parse_sse_event(event)
            if state.message_stopped or self._is_terminal_done_event(
        if isinstance(state, NativeSseBlockPolicyState):
            event_name, data_text = parse_native_sse_event(event)
            if state.message_stopped or is_terminal_openrouter_done_event(
                event_name, data_text
            ):
                return None
            if event_name == "message_stop":
                state.message_stopped = True

        if thinking_enabled:
            if isinstance(state, _SSEFilterState):
                return self._transform_sse_payload(
                    event, state, thinking_enabled=thinking_enabled
                )
            return event
        if isinstance(state, _SSEFilterState):
            return self._transform_sse_payload(
        if isinstance(state, NativeSseBlockPolicyState):
            return transform_native_sse_block_event(
                event, state, thinking_enabled=thinking_enabled
            )
        return event

@@ -260,9 +91,10 @@ class OpenRouterProvider(AnthropicMessagesTransport):
        sent_any_event: bool,
    ) -> Iterator[str]:
        """Emit the Anthropic SSE error shape expected by Claude clients."""
        sse = SSEBuilder(f"msg_{uuid.uuid4()}", request.model, input_tokens)
        if not sent_any_event:
            yield sse.message_start()
        yield from sse.emit_error(error_message)
        yield sse.message_delta("end_turn", 1)
        yield sse.message_stop()
        yield from iter_provider_stream_error_sse_events(
            request=request,
            input_tokens=input_tokens,
            error_message=error_message,
            sent_any_event=sent_any_event,
            log_raw_sse_events=self._config.log_raw_sse_events,
        )
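
The [DONE] sentinel is OpenAI-style framing that Anthropic SSE consumers must never see. The removed _is_terminal_done_event above shows the predicate this commit moves into the shared helper; a standalone restatement (the shared is_terminal_openrouter_done_event in core.anthropic is assumed to behave equivalently):

def is_terminal_done(event_name: str | None, data_text: str) -> bool:
    # Matches bare data frames and "data"/"done" events whose payload is [DONE].
    return (event_name is None or event_name in {"data", "done"}) and (
        data_text.strip().upper() == "[DONE]"
    )

assert is_terminal_done(None, "[DONE]")
assert is_terminal_done("done", " [done] ")
assert not is_terminal_done("message_stop", "{}")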

providers/openrouter/request.py
@@ -2,136 +2,18 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from loguru import logger
from pydantic import BaseModel

OPENROUTER_DEFAULT_MAX_TOKENS = 81920

_REQUEST_FIELDS = (
    "model",
    "messages",
    "system",
    "max_tokens",
    "stop_sequences",
    "stream",
    "temperature",
    "top_p",
    "top_k",
    "metadata",
    "tools",
    "tool_choice",
    "thinking",
    "extra_body",
    "original_model",
    "resolved_provider_model",
from config.constants import (
    ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS as OPENROUTER_DEFAULT_MAX_TOKENS,
)

_INTERNAL_FIELDS = {
    "thinking",
    "extra_body",
    "original_model",
    "resolved_provider_model",
}
_THINKING_HISTORY_BLOCK_TYPES = {"thinking", "redacted_thinking"}


def _serialize_value(value: Any) -> Any:
    """Convert Pydantic models and lightweight objects into JSON-ready values."""
    if isinstance(value, BaseModel):
        return value.model_dump(exclude_none=True)
    if isinstance(value, dict):
        return {
            key: _serialize_value(item)
            for key, item in value.items()
            if item is not None
        }
    if isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray):
        return [_serialize_value(item) for item in value]
    if value is None or isinstance(value, str | int | float | bool):
        return value
    if hasattr(value, "__dict__"):
        return {
            key: _serialize_value(item)
            for key, item in vars(value).items()
            if not key.startswith("_") and item is not None
        }
    return value


def _dump_request_fields(request_data: Any) -> dict[str, Any]:
    """Extract the public request fields forwarded to OpenRouter."""
    if isinstance(request_data, BaseModel):
        return request_data.model_dump(exclude_none=True)

    dumped: dict[str, Any] = {}
    for field in _REQUEST_FIELDS:
        value = getattr(request_data, field, None)
        if value is not None:
            dumped[field] = _serialize_value(value)
    return dumped


def _strip_unsigned_thinking_history(messages: Any) -> Any:
    """Remove assistant thinking history blocks that OpenRouter cannot replay."""
    if not isinstance(messages, list):
        return messages

    sanitized_messages: list[Any] = []
    for message in messages:
        if not isinstance(message, dict):
            sanitized_messages.append(message)
            continue

        content = message.get("content")
        if not isinstance(content, list):
            sanitized_messages.append(message)
            continue

        sanitized_content = [
            block
            for block in content
            if not (
                isinstance(block, dict)
                and block.get("type") in _THINKING_HISTORY_BLOCK_TYPES
                and not isinstance(block.get("signature"), str)
from core.anthropic.native_messages_request import (
    OpenRouterExtraBodyError,
    build_openrouter_native_request_body,
)
            )
        ]

        sanitized_message = dict(message)
        sanitized_message["content"] = sanitized_content or ""
        sanitized_messages.append(sanitized_message)

    return sanitized_messages


def _normalize_system_prompt(system: Any) -> Any:
    """Flatten Claude SDK system blocks for OpenRouter's native endpoint."""
    if not isinstance(system, list):
        return system

    text_parts: list[str] = []
    for block in system:
        if not isinstance(block, dict):
            continue
        if block.get("type") == "text" and isinstance(block.get("text"), str):
            text_parts.append(block["text"])
    return "\n\n".join(text_parts).strip() if text_parts else system


def _apply_reasoning(body: dict[str, Any], thinking_cfg: Any) -> None:
    """Map Anthropic thinking controls onto OpenRouter reasoning controls."""
    reasoning = body.setdefault("reasoning", {"enabled": True})
    if not isinstance(reasoning, dict):
        return
    reasoning.setdefault("enabled", True)
    if not isinstance(thinking_cfg, dict):
        return
    budget_tokens = thinking_cfg.get("budget_tokens")
    if isinstance(budget_tokens, int):
        reasoning.setdefault("max_tokens", budget_tokens)
from providers.exceptions import InvalidRequestError


def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:

@@ -142,27 +24,14 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
        len(getattr(request_data, "messages", [])),
    )

    dumped_request = _dump_request_fields(request_data)
    request_extra = dumped_request.pop("extra_body", None)
    thinking_cfg = dumped_request.get("thinking")
    body = {
        key: value
        for key, value in dumped_request.items()
        if key not in _INTERNAL_FIELDS
    }

    if isinstance(request_extra, dict):
        body.update(request_extra)

    body["messages"] = _strip_unsigned_thinking_history(body.get("messages"))
    if "system" in body:
        body["system"] = _normalize_system_prompt(body["system"])
    body["stream"] = True
    if body.get("max_tokens") is None:
        body["max_tokens"] = OPENROUTER_DEFAULT_MAX_TOKENS

    if thinking_enabled:
        _apply_reasoning(body, thinking_cfg)
    try:
        body = build_openrouter_native_request_body(
            request_data,
            thinking_enabled=thinking_enabled,
            default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS,
        )
    except OpenRouterExtraBodyError as exc:
        raise InvalidRequestError(str(exc)) from exc

    logger.debug(
        "OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}",
providers/openai_compat.py
@@ -21,14 +21,40 @@ from core.anthropic import (
    SSEBuilder,
    ThinkTagParser,
    append_request_id,
    get_user_facing_error_message,
    map_stop_reason,
)
from providers.base import BaseProvider, ProviderConfig
from providers.error_mapping import map_error
from providers.error_mapping import (
    map_error,
    user_visible_message_for_mapped_provider_error,
)
from providers.rate_limit import GlobalRateLimiter


def _iter_heuristic_tool_use_sse(
    sse: SSEBuilder, tool_use: dict[str, Any]
) -> Iterator[str]:
    """Emit SSE for one heuristic tool_use block (closes open text/thinking first)."""
    if tool_use.get("name") == "Task" and isinstance(tool_use.get("input"), dict):
        task_input = tool_use["input"]
        if task_input.get("run_in_background") is not False:
            task_input["run_in_background"] = False
    yield from sse.close_content_blocks()
    block_idx = sse.blocks.allocate_index()
    yield sse.content_block_start(
        block_idx,
        "tool_use",
        id=tool_use["id"],
        name=tool_use["name"],
    )
    yield sse.content_block_delta(
        block_idx,
        "input_json_delta",
        json.dumps(tool_use["input"]),
    )
    yield sse.content_block_stop(block_idx)


class OpenAIChatTransport(BaseProvider):
    """Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, DeepSeek, …)."""

@@ -113,6 +139,22 @@ class OpenAIChatTransport(BaseProvider):
        )
        return stream, retry_body

    def _emit_tool_arg_delta(
        self, sse: SSEBuilder, tc_index: int, args: str
    ) -> Iterator[str]:
        """Emit one argument fragment for a started tool block (Task buffer or raw JSON)."""
        if not args:
            return
        state = sse.blocks.tool_states.get(tc_index)
        if state is None:
            return
        if state.name == "Task":
            parsed = sse.blocks.buffer_task_args(tc_index, args)
            if parsed is not None:
                yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
            return
        yield sse.emit_tool_delta(tc_index, args)

    def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
        """Process a single tool call delta and yield SSE events."""
        tc_index = tc.get("index", 0)
@@ -121,34 +163,42 @@ class OpenAIChatTransport(BaseProvider):

        fn_delta = tc.get("function", {})
        incoming_name = fn_delta.get("name")
        arguments = fn_delta.get("arguments", "")
        arguments = fn_delta.get("arguments", "") or ""

        if tc.get("id") is not None:
            sse.blocks.set_stream_tool_id(tc_index, tc.get("id"))

        if incoming_name is not None:
            sse.blocks.register_tool_name(tc_index, incoming_name)

        state = sse.blocks.tool_states.get(tc_index)
        if state is None or not state.started:
            name = state.name if state else ""
            if name or tc.get("id"):
                tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
                yield sse.start_tool_block(tc_index, tool_id, name)
        resolved_id = (state.tool_id if state and state.tool_id else None) or tc.get(
            "id"
        )
        resolved_name = (state.name if state else "") or ""

        args = arguments
        if args:
            state = sse.blocks.tool_states.get(tc_index)
            if state is None or not state.started:
                tool_id = tc.get("id") or f"tool_{uuid.uuid4()}"
                name = (state.name if state else None) or "tool_call"
                yield sse.start_tool_block(tc_index, tool_id, name)
        state = sse.blocks.tool_states.get(tc_index)
        if not state or not state.started:
            name_ok = bool((resolved_name or "").strip())
            if name_ok:
                tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}"
                display_name = (resolved_name or "").strip() or "tool_call"
                yield sse.start_tool_block(tc_index, tool_id, display_name)
                state = sse.blocks.tool_states[tc_index]
                if state.pre_start_args:
                    pre = state.pre_start_args
                    state.pre_start_args = ""
                    yield from self._emit_tool_arg_delta(sse, tc_index, pre)

            current_name = state.name if state else ""
            if current_name == "Task":
                parsed = sse.blocks.buffer_task_args(tc_index, args)
                if parsed is not None:
                    yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
        state = sse.blocks.tool_states.get(tc_index)
        if not arguments:
            return
        if state is None or not state.started:
            state = sse.blocks.ensure_tool_state(tc_index)
            if not (resolved_name or "").strip():
                state.pre_start_args += arguments
                return

        yield sse.emit_tool_delta(tc_index, args)
        yield from self._emit_tool_arg_delta(sse, tc_index, arguments)

    def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]:
        """Emit buffered Task args as a single JSON delta (best-effort)."""
|
||||
|
|
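The rework above defers starting a tool block until a non-empty name is known: argument fragments that arrive first are parked in `state.pre_start_args` and replayed through `_emit_tool_arg_delta` once the block starts. A minimal sketch of that buffer-then-flush ordering, assuming only the ToolState fields visible in this hunk (name, started, pre_start_args):

from dataclasses import dataclass

@dataclass
class ToolStateSketch:
    name: str = ""
    started: bool = False
    pre_start_args: str = ""

state = ToolStateSketch()
state.pre_start_args += '{"file_path": '   # args delta arrives before the name
state.name = "Read"                        # a later delta supplies the name
state.started = True                       # block starts; flush the buffer first
pending, state.pre_start_args = state.pre_start_args, ""
emitted = [pending, '"notes.txt"}']        # then stream the fresh fragment
assert "".join(emitted) == '{"file_path": "notes.txt"}'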
@@ -181,7 +231,12 @@ class OpenAIChatTransport(BaseProvider):
         """Shared streaming implementation."""
         tag = self._provider_name
         message_id = f"msg_{uuid.uuid4()}"
-        sse = SSEBuilder(message_id, request.model, input_tokens)
+        sse = SSEBuilder(
+            message_id,
+            request.model,
+            input_tokens,
+            log_raw_events=self._config.log_raw_sse_events,
+        )

         body = self._build_request_body(request, thinking_enabled=thinking_enabled)
         thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
@@ -201,8 +256,6 @@ class OpenAIChatTransport(BaseProvider):
         heuristic_parser = HeuristicToolParser()
         finish_reason = None
         usage_info = None
-        error_occurred = False
-        error_message = ""

         async with self._global_rate_limiter.concurrency_slot():
             try:
@@ -258,26 +311,10 @@ class OpenAIChatTransport(BaseProvider):
                             yield sse.emit_text_delta(filtered_text)

                         for tool_use in detected_tools:
-                            for event in sse.close_content_blocks():
+                            for event in _iter_heuristic_tool_use_sse(
+                                sse, tool_use
+                            ):
                                 yield event
-
-                            block_idx = sse.blocks.allocate_index()
-                            if tool_use.get("name") == "Task" and isinstance(
-                                tool_use.get("input"), dict
-                            ):
-                                tool_use["input"]["run_in_background"] = False
-                            yield sse.content_block_start(
-                                block_idx,
-                                "tool_use",
-                                id=tool_use["id"],
-                                name=tool_use["name"],
-                            )
-                            yield sse.content_block_delta(
-                                block_idx,
-                                "input_json_delta",
-                                json.dumps(tool_use["input"]),
-                            )
-                            yield sse.content_block_stop(block_idx)

                         # Handle native tool calls
                         if delta.tool_calls:
@@ -298,17 +335,12 @@ class OpenAIChatTransport(BaseProvider):
             except (asyncio.CancelledError, GeneratorExit):
                 raise
             except Exception as e:
-                logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
+                self._log_stream_transport_error(tag, req_tag, e)
                 mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
-                error_occurred = True
-                if getattr(mapped_e, "status_code", None) == 405:
-                    base_message = (
-                        f"Upstream provider {tag} rejected the request method "
-                        "or endpoint (HTTP 405)."
-                    )
-                else:
-                    base_message = get_user_facing_error_message(
-                        mapped_e, read_timeout_s=self._config.http_read_timeout
-                    )
+                base_message = user_visible_message_for_mapped_provider_error(
+                    mapped_e,
+                    provider_name=tag,
+                    read_timeout_s=self._config.http_read_timeout,
+                )
                 error_message = append_request_id(base_message, request_id)
                 logger.info(
@@ -317,10 +349,13 @@ class OpenAIChatTransport(BaseProvider):
                     type(e).__name__,
                     req_tag,
                 )
-                for event in sse.close_content_blocks():
+                for event in sse.close_all_blocks():
                     yield event
                 for event in sse.emit_error(error_message):
                     yield event
+                yield sse.message_delta("end_turn", 1)
+                yield sse.message_stop()
+                return

         # Flush remaining content
         remaining = think_parser.flush()
@@ -338,32 +373,16 @@ class OpenAIChatTransport(BaseProvider):
                 yield sse.emit_text_delta(remaining.content)

         for tool_use in heuristic_parser.flush():
-            for event in sse.close_content_blocks():
+            for event in _iter_heuristic_tool_use_sse(sse, tool_use):
                 yield event

-            block_idx = sse.blocks.allocate_index()
-            yield sse.content_block_start(
-                block_idx,
-                "tool_use",
-                id=tool_use["id"],
-                name=tool_use["name"],
-            )
-            if tool_use.get("name") == "Task" and isinstance(
-                tool_use.get("input"), dict
-            ):
-                tool_use["input"]["run_in_background"] = False
-            yield sse.content_block_delta(
-                block_idx,
-                "input_json_delta",
-                json.dumps(tool_use["input"]),
-            )
-            yield sse.content_block_stop(block_idx)
-
-        if (
-            not error_occurred
-            and sse.blocks.text_index == -1
-            and not sse.blocks.tool_states
-        ):
+        has_started_tool = any(s.started for s in sse.blocks.tool_states.values())
+        has_content_blocks = (
+            sse.blocks.text_index != -1
+            or sse.blocks.thinking_index != -1
+            or has_started_tool
+        )
+        if not has_content_blocks:
             for event in sse.ensure_text_block():
                 yield event
             yield sse.emit_text_delta(" ")
@@ -3,14 +3,16 @@
 import asyncio
 import random
 import time
-from collections import deque
 from collections.abc import AsyncIterator, Callable
 from contextlib import asynccontextmanager
 from typing import Any, ClassVar, TypeVar

 import httpx
 import openai
 from loguru import logger

+from core.rate_limit import StrictSlidingWindowLimiter
+
 T = TypeVar("T")
@@ -51,10 +53,10 @@ class GlobalRateLimiter:
         self._rate_limit = rate_limit
         self._rate_window = float(rate_window)
         self._max_concurrency = max_concurrency
-        # Monotonic timestamps of the last granted slots.
-        self._request_times: deque[float] = deque()
+        self._proactive_limiter = StrictSlidingWindowLimiter(
+            self._rate_limit, self._rate_window
+        )
         self._blocked_until: float = 0
         self._lock = asyncio.Lock()
         self._concurrency_sem = asyncio.Semaphore(max_concurrency)
         self._initialized = True
@@ -107,7 +109,6 @@ class GlobalRateLimiter:
             logger.info(
                 "Rebuilding provider rate limiter for updated scope '{}'", scope
            )
-        if scope not in cls._scoped_instances or existing:
             cls._scoped_instances[scope] = cls(
                 rate_limit=desired_rate_limit,
                 rate_window=desired_rate_window,
@@ -150,27 +151,7 @@ class GlobalRateLimiter:
         Guarantees: at most `self._rate_limit` acquisitions in any interval of length
         `self._rate_window` (seconds).
         """
-        while True:
-            wait_time = 0.0
-            async with self._lock:
-                now = time.monotonic()
-                cutoff = now - self._rate_window
-
-                while self._request_times and self._request_times[0] <= cutoff:
-                    self._request_times.popleft()
-
-                if len(self._request_times) < self._rate_limit:
-                    self._request_times.append(now)
-                    return
-
-                oldest = self._request_times[0]
-                wait_time = max(0.0, (oldest + self._rate_window) - now)
-
-            # Sleep outside the lock so other tasks can continue to queue.
-            if wait_time > 0:
-                await asyncio.sleep(wait_time)
-            else:
-                await asyncio.sleep(0)
+        await self._proactive_limiter.acquire()

     def set_blocked(self, seconds: float = 60) -> None:
         """
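`acquire` now delegates to `core.rate_limit.StrictSlidingWindowLimiter`. For reference, a standalone sketch of the same semantics the removed inline code implemented (at most `limit` grants in any `window`-second interval); the class name here is hypothetical, the logic mirrors the deleted lines:

import asyncio
import time
from collections import deque


class SlidingWindowSketch:
    """At most ``limit`` acquisitions in any ``window``-second interval."""

    def __init__(self, limit: int, window: float) -> None:
        self._limit = limit
        self._window = window
        self._granted: deque[float] = deque()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.monotonic()
                # Drop grants that have aged out of the window.
                while self._granted and self._granted[0] <= now - self._window:
                    self._granted.popleft()
                if len(self._granted) < self._limit:
                    self._granted.append(now)
                    return
                wait = (self._granted[0] + self._window) - now
            # Sleep outside the lock so other waiters can queue behind it.
            await asyncio.sleep(max(wait, 0.0))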
@@ -263,6 +244,24 @@ class GlobalRateLimiter:
                 )
                 self.set_blocked(delay)
                 await asyncio.sleep(delay)
+            except httpx.HTTPStatusError as e:
+                if e.response.status_code != 429:
+                    raise
+                last_exc = e
+                if attempt >= max_retries:
+                    logger.warning(
+                        f"HTTP 429 retry exhausted after {max_retries} retries"
+                    )
+                    break
+
+                delay = min(base_delay * (2**attempt), max_delay)
+                delay += random.uniform(0, jitter)
+                logger.warning(
+                    f"HTTP 429 from upstream, attempt {attempt + 1}/{max_retries + 1}. "
+                    f"Retrying in {delay:.1f}s..."
+                )
+                self.set_blocked(delay)
+                await asyncio.sleep(delay)

         assert last_exc is not None
         raise last_exc
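The new 429 branch uses exponential backoff capped at `max_delay` plus uniform jitter, while also blocking the proactive limiter for the same period. A worked example, assuming base_delay=1.0, max_delay=30.0, jitter=1.0 (the actual defaults are defined outside this hunk):

# attempt 0: min(1.0 * 2**0, 30.0) + U(0, 1) -> 1.0..2.0 s
# attempt 1: min(1.0 * 2**1, 30.0) + U(0, 1) -> 2.0..3.0 s
# attempt 4: min(1.0 * 2**4, 30.0) + U(0, 1) -> 16.0..17.0 s
# attempt 6: min(1.0 * 2**6, 30.0) + U(0, 1) -> 30.0..31.0 s (capped)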
@@ -3,108 +3,20 @@
 from __future__ import annotations

 from collections.abc import Callable, MutableMapping
-from dataclasses import dataclass
-from typing import Literal

-from config.provider_ids import SUPPORTED_PROVIDER_IDS
+from config.provider_catalog import (
+    PROVIDER_CATALOG,
+    SUPPORTED_PROVIDER_IDS,
+    ProviderDescriptor,
+)
 from config.settings import Settings
 from providers.base import BaseProvider, ProviderConfig
-from providers.defaults import (
-    DEEPSEEK_DEFAULT_BASE,
-    LLAMACPP_DEFAULT_BASE,
-    LMSTUDIO_DEFAULT_BASE,
-    NVIDIA_NIM_DEFAULT_BASE,
-    OLLAMA_DEFAULT_BASE,
-    OPENROUTER_DEFAULT_BASE,
-)
 from providers.exceptions import AuthenticationError, UnknownProviderTypeError

-TransportType = Literal["openai_chat", "anthropic_messages"]
 ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]


-@dataclass(frozen=True, slots=True)
-class ProviderDescriptor:
-    """Metadata for building :class:`ProviderConfig` and factory wiring."""
-
-    provider_id: str
-    transport_type: TransportType
-    capabilities: tuple[str, ...]
-    credential_env: str | None = None
-    credential_url: str | None = None
-    # If set, read API key from this attribute on ``Settings`` (e.g. nvidia_nim_api_key).
-    credential_attr: str | None = None
-    # If set, use this fixed key for local adapters (e.g. lm-studio, llamacpp).
-    static_credential: str | None = None
-    default_base_url: str | None = None
-    base_url_attr: str | None = None
-    proxy_attr: str | None = None
-
-
-PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
-    "nvidia_nim": ProviderDescriptor(
-        provider_id="nvidia_nim",
-        transport_type="openai_chat",
-        credential_env="NVIDIA_NIM_API_KEY",
-        credential_url="https://build.nvidia.com/settings/api-keys",
-        credential_attr="nvidia_nim_api_key",
-        default_base_url=NVIDIA_NIM_DEFAULT_BASE,
-        proxy_attr="nvidia_nim_proxy",
-        capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
-    ),
-    "open_router": ProviderDescriptor(
-        provider_id="open_router",
-        transport_type="anthropic_messages",
-        credential_env="OPENROUTER_API_KEY",
-        credential_url="https://openrouter.ai/keys",
-        credential_attr="open_router_api_key",
-        default_base_url=OPENROUTER_DEFAULT_BASE,
-        proxy_attr="open_router_proxy",
-        capabilities=("chat", "streaming", "tools", "thinking", "native_anthropic"),
-    ),
-    "deepseek": ProviderDescriptor(
-        provider_id="deepseek",
-        transport_type="openai_chat",
-        credential_env="DEEPSEEK_API_KEY",
-        credential_url="https://platform.deepseek.com/api_keys",
-        credential_attr="deepseek_api_key",
-        default_base_url=DEEPSEEK_DEFAULT_BASE,
-        capabilities=("chat", "streaming", "thinking"),
-    ),
-    "lmstudio": ProviderDescriptor(
-        provider_id="lmstudio",
-        transport_type="anthropic_messages",
-        static_credential="lm-studio",
-        default_base_url=LMSTUDIO_DEFAULT_BASE,
-        base_url_attr="lm_studio_base_url",
-        proxy_attr="lmstudio_proxy",
-        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
-    ),
-    "llamacpp": ProviderDescriptor(
-        provider_id="llamacpp",
-        transport_type="anthropic_messages",
-        static_credential="llamacpp",
-        default_base_url=LLAMACPP_DEFAULT_BASE,
-        base_url_attr="llamacpp_base_url",
-        proxy_attr="llamacpp_proxy",
-        capabilities=("chat", "streaming", "tools", "native_anthropic", "local"),
-    ),
-    "ollama": ProviderDescriptor(
-        provider_id="ollama",
-        transport_type="anthropic_messages",
-        static_credential="ollama",
-        default_base_url=OLLAMA_DEFAULT_BASE,
-        base_url_attr="ollama_base_url",
-        capabilities=(
-            "chat",
-            "streaming",
-            "tools",
-            "thinking",
-            "native_anthropic",
-            "local",
-        ),
-    ),
-}
+# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``).
+PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG


 def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
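With the descriptor dataclass and the table moved to `config.provider_catalog`, this module now only re-exports the catalog under its old name. A usage sketch, grounded in the entries removed above:

from config.provider_catalog import PROVIDER_CATALOG

desc = PROVIDER_CATALOG["deepseek"]
assert desc.transport_type == "openai_chat"
assert "thinking" in desc.capabilities
assert desc.credential_env == "DEEPSEEK_API_KEY"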
@@ -119,25 +31,25 @@ def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProv
     return OpenRouterProvider(config)


-def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider:
     from providers.deepseek import DeepSeekProvider

     return DeepSeekProvider(config)


-def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider:
     from providers.lmstudio import LMStudioProvider

     return LMStudioProvider(config)


-def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider:
     from providers.llamacpp import LlamaCppProvider

     return LlamaCppProvider(config)


-def _create_ollama(config: ProviderConfig, settings: Settings) -> BaseProvider:
+def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider:
     from providers.ollama import OllamaProvider

     return OllamaProvider(config)
@@ -208,6 +120,8 @@ def build_provider_config(
         http_connect_timeout=settings.http_connect_timeout,
         enable_thinking=settings.enable_model_thinking,
         proxy=proxy,
+        log_raw_sse_events=settings.log_raw_sse_events,
+        log_api_error_tracebacks=settings.log_api_error_tracebacks,
     )
@@ -1,11 +1,13 @@
 """
 Claude Code Proxy - Entry Point

-Minimal entry point that imports the app from the api module.
+Minimal entry point that builds the ASGI app via :func:`api.app.create_app`.
 Run with: uv run uvicorn server:app --host 0.0.0.0 --port 8082 --timeout-graceful-shutdown 5
 """

-from api.app import app, create_app
+from api.app import create_app
+
+app = create_app()

 __all__ = ["app", "create_app"]
@@ -98,7 +98,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
         "providers.registry.ProviderRegistry",
         "provider id and Settings",
         "configured BaseProvider instance",
-        "503 for missing credentials; ValueError for unknown provider",
+        "503 for missing credentials; invalid_request_error for unknown provider",
         ("tests/api/test_dependencies.py", "tests/providers/test_registry.py"),
         ("test_configured_provider_models_stream_successfully",),
     ),
@@ -19,6 +19,14 @@ import httpx
 import pytest

 from config.provider_ids import SUPPORTED_PROVIDER_IDS
+from core.anthropic.stream_contracts import (
+    SSEEvent,
+    assert_anthropic_stream_contract,
+    event_index,
+    has_tool_use,
+    parse_sse_lines,
+    text_content,
+)
 from messaging.handler import ClaudeMessageHandler
 from messaging.models import IncomingMessage
 from messaging.platforms.base import MessagingPlatform
@@ -26,13 +34,6 @@ from messaging.session import SessionStore
 from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers
 from smoke.lib.server import RunningServer, start_server
 from smoke.lib.skips import fail_missing_env
-from smoke.lib.sse import (
-    SSEEvent,
-    assert_anthropic_stream_contract,
-    has_tool_use,
-    parse_sse_lines,
-    text_content,
-)


 @dataclass(slots=True)
@@ -590,14 +591,14 @@ def assistant_content_from_events(events: list[SSEEvent]) -> list[dict[str, Any]
     block_order: list[int] = []
     for event in events:
         if event.event == "content_block_start":
-            index = _event_index(event)
+            index = event_index(event)
             block = event.data.get("content_block", {})
             if isinstance(block, dict):
                 blocks[index] = dict(block)
                 block_order.append(index)
             continue
         if event.event == "content_block_delta":
-            index = _event_index(event)
+            index = event_index(event)
             block = blocks.get(index)
             delta = event.data.get("delta", {})
             if not isinstance(block, dict) or not isinstance(delta, dict):
@@ -663,12 +664,6 @@ def default_cli_events(session_id: str) -> list[dict[str, Any]]:
     ]


-def _event_index(event: SSEEvent) -> int:
-    value = event.data.get("index")
-    assert isinstance(value, int), event.data
-    return value
-
-
 def assert_product_stream(events: list[SSEEvent]) -> None:
     assert_anthropic_stream_contract(events)
     assert text_content(events).strip() or has_tool_use(events), (
@@ -6,9 +6,10 @@ from typing import Any

 import httpx

+from core.anthropic.stream_contracts import SSEEvent, parse_sse_lines
+
 from .config import SmokeConfig, auth_headers, redacted
 from .server import RunningServer
-from .sse import SSEEvent, parse_sse_lines


 def message_payload(
@@ -5,7 +5,7 @@ from __future__ import annotations
 import httpx
 import pytest

-from smoke.lib.sse import SSEEvent
+from core.anthropic.stream_contracts import SSEEvent

 UPSTREAM_UNAVAILABLE_MARKERS = (
     "connection refused",
@@ -1,29 +0,0 @@
-"""SSE parsing and Anthropic stream assertions for smoke tests.
-
-Canonical implementation lives in :mod:`core.anthropic.stream_contracts`; this
-module re-exports it for smoke and historical import paths.
-"""
-
-from __future__ import annotations
-
-from core.anthropic.stream_contracts import (
-    SSEEvent,
-    assert_anthropic_stream_contract,
-    event_names,
-    has_tool_use,
-    parse_sse_lines,
-    parse_sse_text,
-    text_content,
-    thinking_content,
-)
-
-__all__ = [
-    "SSEEvent",
-    "assert_anthropic_stream_contract",
-    "event_names",
-    "has_tool_use",
-    "parse_sse_lines",
-    "parse_sse_text",
-    "text_content",
-    "thinking_content",
-]
@@ -5,6 +5,10 @@ import time
 import httpx
 import pytest

+from core.anthropic.stream_contracts import (
+    assert_anthropic_stream_contract,
+    text_content,
+)
 from smoke.lib.config import SmokeConfig, auth_headers
 from smoke.lib.http import collect_message_stream, message_payload
 from smoke.lib.server import start_server
@@ -12,7 +16,6 @@ from smoke.lib.skips import (
     skip_if_upstream_unavailable_events,
     skip_if_upstream_unavailable_exception,
 )
-from smoke.lib.sse import assert_anthropic_stream_contract, text_content

 pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")]
@@ -2,11 +2,14 @@ from __future__ import annotations

 import pytest

+from core.anthropic.stream_contracts import (
+    assert_anthropic_stream_contract,
+    has_tool_use,
+)
 from smoke.lib.config import SmokeConfig
 from smoke.lib.http import collect_message_stream, message_payload
 from smoke.lib.server import start_server
 from smoke.lib.skips import skip_if_upstream_unavailable_events
-from smoke.lib.sse import assert_anthropic_stream_contract, has_tool_use

 pytestmark = [pytest.mark.live, pytest.mark.smoke_target("tools")]
|
|||
|
|
@ -24,12 +24,13 @@ def test_voice_transcription_backend_when_explicitly_enabled(
|
|||
wav_path = tmp_path / "smoke-tone.wav"
|
||||
_write_tone_wav(wav_path)
|
||||
try:
|
||||
text = transcribe_audio(
|
||||
wav_path,
|
||||
"audio/wav",
|
||||
whisper_model=smoke_config.settings.whisper_model,
|
||||
whisper_device=smoke_config.settings.whisper_device,
|
||||
)
|
||||
t_kw: dict[str, str] = {
|
||||
"whisper_model": smoke_config.settings.whisper_model,
|
||||
"whisper_device": smoke_config.settings.whisper_device,
|
||||
}
|
||||
if smoke_config.settings.whisper_device == "nvidia_nim":
|
||||
t_kw["nvidia_nim_api_key"] = smoke_config.settings.nvidia_nim_api_key
|
||||
text = transcribe_audio(wav_path, "audio/wav", **t_kw)
|
||||
except ImportError as exc:
|
||||
pytest.skip(str(exc))
|
||||
assert isinstance(text, str)
|
||||
|
|
|
|||
|
|
@@ -3,6 +3,11 @@ from __future__ import annotations
 import httpx
 import pytest

+from core.anthropic.stream_contracts import (
+    assert_anthropic_stream_contract,
+    parse_sse_lines,
+    text_content,
+)
 from smoke.lib.config import ProviderModel, SmokeConfig, auth_headers
 from smoke.lib.e2e import (
     ConversationDriver,
@@ -16,11 +21,6 @@ from smoke.lib.skips import (
     skip_if_upstream_unavailable_events,
     skip_if_upstream_unavailable_exception,
 )
-from smoke.lib.sse import (
-    assert_anthropic_stream_contract,
-    parse_sse_lines,
-    text_content,
-)

 pytestmark = [pytest.mark.live, pytest.mark.smoke_target("providers")]
@@ -55,6 +55,7 @@ def test_voice_nim_backend_e2e(smoke_config: SmokeConfig, tmp_path: Path) -> Non
         "audio/wav",
         whisper_model=smoke_config.settings.whisper_model,
         whisper_device="nvidia_nim",
+        nvidia_nim_api_key=smoke_config.settings.nvidia_nim_api_key,
     )

     assert isinstance(text, str)
tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+"""Test suite package."""
tests/api/test_anthropic_request_passthrough.py (new file, 184 lines)
@@ -0,0 +1,184 @@
+"""Pydantic passthrough of Anthropic protocol fields (e.g. ``cache_control``)."""
+
+from __future__ import annotations
+
+from api.models.anthropic import (
+    ContentBlockServerToolUse,
+    ContentBlockText,
+    ContentBlockWebSearchToolResult,
+    Message,
+    MessagesRequest,
+    SystemContent,
+    Tool,
+)
+from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
+from core.anthropic.native_messages_request import (
+    build_base_native_anthropic_request_body,
+)
+
+
+def test_cache_control_preserved_on_parsed_user_text_system_and_tool() -> None:
+    raw = {
+        "model": "m",
+        "max_tokens": 20,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "hi",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            }
+        ],
+        "system": [
+            {
+                "type": "text",
+                "text": "be brief",
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+        "tools": [
+            {
+                "name": "alpha",
+                "input_schema": {"type": "object"},
+                "cache_control": {"type": "ephemeral"},
+            }
+        ],
+    }
+    req = MessagesRequest.model_validate(raw)
+    block = req.messages[0].content[0]
+    assert isinstance(block, ContentBlockText)
+    assert block.model_dump()["cache_control"] == {"type": "ephemeral"}
+    s0 = req.system[0] if isinstance(req.system, list) else None
+    assert isinstance(s0, SystemContent)
+    assert s0.model_dump()["cache_control"] == {"type": "ephemeral"}
+    t0 = req.tools[0] if req.tools else None
+    assert isinstance(t0, Tool)
+    assert t0.model_dump()["cache_control"] == {"type": "ephemeral"}
+
+
+def test_build_base_native_body_includes_cache_control() -> None:
+    req = MessagesRequest(
+        model="m",
+        max_tokens=20,
+        messages=[
+            Message(
+                role="user",
+                content=[
+                    ContentBlockText.model_validate(
+                        {
+                            "type": "text",
+                            "text": "x",
+                            "cache_control": {"type": "ephemeral"},
+                        }
+                    )
+                ],
+            )
+        ],
+        system=[
+            SystemContent.model_validate(
+                {
+                    "type": "text",
+                    "text": "s",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            )
+        ],
+        tools=[
+            Tool.model_validate(
+                {
+                    "name": "n",
+                    "input_schema": {"type": "object"},
+                    "cache_control": {"type": "ephemeral"},
+                }
+            )
+        ],
+    )
+    body = build_base_native_anthropic_request_body(
+        req,
+        default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+        thinking_enabled=False,
+    )
+    assert body["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"}
+    assert body["system"][0]["cache_control"] == {"type": "ephemeral"}
+    assert body["tools"][0]["cache_control"] == {"type": "ephemeral"}
+
+
+def test_pydantic_discriminator_still_distinguishes_blocks() -> None:
+    m = Message.model_validate(
+        {
+            "role": "user",
+            "content": [{"type": "text", "text": "a", "z": 1}],
+        }
+    )
+    b = m.content[0]
+    assert isinstance(b, ContentBlockText)
+    assert b.model_dump()["z"] == 1
+
+
+def test_server_tool_assistant_blocks_round_trip_in_native_body() -> None:
+    """Local server-tool responses must parse as valid history for a follow-up request."""
+    raw = {
+        "model": "m",
+        "max_tokens": 20,
+        "messages": [
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "server_tool_use",
+                        "id": "srvtoolu_1",
+                        "name": "web_search",
+                        "input": {"query": "q"},
+                    },
+                    {
+                        "type": "web_search_tool_result",
+                        "tool_use_id": "srvtoolu_1",
+                        "content": [
+                            {
+                                "type": "web_search_result",
+                                "title": "T",
+                                "url": "https://example.com",
+                            }
+                        ],
+                    },
+                ],
+            }
+        ],
+        "mcp_servers": [{"type": "url", "url": "https://example.com/mcp"}],
+    }
+    req = MessagesRequest.model_validate(raw)
+    assert len(req.messages) == 1
+    blocks = req.messages[0].content
+    assert isinstance(blocks, list)
+    assert isinstance(blocks[0], ContentBlockServerToolUse)
+    assert isinstance(blocks[1], ContentBlockWebSearchToolResult)
+    body = build_base_native_anthropic_request_body(
+        req,
+        default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+        thinking_enabled=False,
+    )
+    assert body["mcp_servers"][0]["type"] == "url"
+    assert body["messages"][0]["content"][0]["type"] == "server_tool_use"
+    assert body["messages"][0]["content"][1]["type"] == "web_search_tool_result"
+
+
+def test_native_body_preserves_context_and_output_config() -> None:
+    raw = {
+        "model": "m",
+        "max_tokens": 20,
+        "messages": [{"role": "user", "content": "x"}],
+        "context_management": {"edits": [{"type": "clear"}]},
+        "output_config": {"some": "hint"},
+    }
+    req = MessagesRequest.model_validate(raw)
+    body = build_base_native_anthropic_request_body(
+        req,
+        default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
+        thinking_enabled=False,
+    )
+    assert body["context_management"] == raw["context_management"]
+    assert body["output_config"] == raw["output_config"]
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.app import app
|
||||
from api.app import create_app
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
app = create_app()
|
||||
|
||||
# Mock provider
|
||||
mock_provider = MagicMock(spec=NvidiaNimProvider)
|
||||
|
||||
|
|
@@ -171,7 +173,7 @@ def test_generic_exception_returns_500(client: TestClient):


 def test_generic_exception_with_status_code(client: TestClient):
-    """Generic exception with status_code attribute uses that status (getattr fallback)."""
+    """Unexpected errors always map to HTTP 500 (ignore ad-hoc status_code attrs)."""

     class ExceptionWithStatus(RuntimeError):
         def __init__(self, msg: str, status_code: int = 500):
@@ -191,7 +193,7 @@ def test_generic_exception_with_status_code(client: TestClient):
             "stream": True,
         },
     )
-    assert response.status_code == 502
+    assert response.status_code == 500
     mock_provider.stream_response = _mock_stream_response
@@ -17,6 +17,15 @@ _RUNTIME_EXTRAS = {
     "nvidia_nim_api_key": "",
     "claude_cli_bin": "claude",
     "uses_process_anthropic_auth_token": lambda: False,
+    "messaging_rate_limit": 1,
+    "messaging_rate_window": 1.0,
+    "max_message_log_entries_per_chat": None,
+    "debug_platform_edits": False,
+    "debug_subagent_stack": False,
+    "log_api_error_tracebacks": False,
+    "log_raw_messaging_content": False,
+    "log_raw_cli_diagnostics": False,
+    "log_messaging_error_details": False,
 }
@@ -86,6 +95,47 @@ def test_create_app_provider_error_handler_returns_anthropic_format():
     assert body["error"]["type"] == "authentication_error"


+def test_create_app_provider_error_default_logs_exclude_provider_message():
+    """Provider errors must not log exc.message by default."""
+    from api.app import create_app
+    from providers.exceptions import AuthenticationError
+
+    app = create_app()
+    secret = "provider-upstream-secret-detail"
+
+    @app.get("/raise_provider_secret")
+    async def _raise():
+        raise AuthenticationError(secret)
+
+    api_app_mod = importlib.import_module("api.app")
+    settings = _app_settings(
+        messaging_platform="telegram",
+        telegram_bot_token=None,
+        allowed_telegram_user_id=None,
+        discord_bot_token=None,
+        allowed_discord_channels=None,
+        allowed_dir="",
+        claude_workspace="./agent_workspace",
+        host="127.0.0.1",
+        port=8082,
+        log_file="server.log",
+        log_api_error_tracebacks=False,
+    )
+    with (
+        patch.object(api_app_mod, "get_settings", return_value=settings),
+        patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
+        patch.object(api_app_mod.logger, "error") as log_err,
+    ):
+        with TestClient(app) as client:
+            resp = client.get("/raise_provider_secret")
+            assert resp.status_code == 401
+
+    blob = " ".join(str(a) for c in log_err.call_args_list for a in c.args)
+    blob += repr([c.kwargs for c in log_err.call_args_list])
+    assert secret not in blob
+    assert "authentication_error" in blob
+
+
 def test_create_app_general_exception_handler_returns_500():
     from api.app import create_app
@@ -120,6 +170,50 @@ def test_create_app_general_exception_handler_returns_500():
     assert body["error"]["type"] == "api_error"


+def test_create_app_general_exception_default_logs_exclude_exception_message():
+    """Unhandled errors must not log exception text by default (may echo user content)."""
+    from api.app import create_app
+
+    app = create_app()
+
+    secret = "user-provided-secret-token-xyzzy"
+
+    @app.get("/raise_secret")
+    async def _raise_secret():
+        raise ValueError(secret)
+
+    api_app_mod = importlib.import_module("api.app")
+    settings = _app_settings(
+        messaging_platform="telegram",
+        telegram_bot_token=None,
+        allowed_telegram_user_id=None,
+        discord_bot_token=None,
+        allowed_discord_channels=None,
+        allowed_dir="",
+        claude_workspace="./agent_workspace",
+        host="127.0.0.1",
+        port=8082,
+        log_file="server.log",
+        log_api_error_tracebacks=False,
+    )
+    with (
+        patch.object(api_app_mod, "get_settings", return_value=settings),
+        patch.object(ProviderRegistry, "cleanup", new=AsyncMock()),
+        patch.object(api_app_mod.logger, "error") as log_err,
+    ):
+        with TestClient(app, raise_server_exceptions=False) as client:
+            resp = client.get("/raise_secret")
+            assert resp.status_code == 500
+
+    flattened: list[str] = []
+    for call in log_err.call_args_list:
+        flattened.extend(str(arg) for arg in call.args)
+        flattened.append(repr(call.kwargs))
+    blob = " ".join(flattened)
+    assert secret not in blob
+    assert "ValueError" in blob
+
+
 @pytest.mark.parametrize(
     "messaging_enabled", [True, False], ids=["with_platform", "no_platform"]
 )
@@ -2,10 +2,12 @@ from unittest.mock import patch

 from fastapi.testclient import TestClient

-from api.app import app
+from api.app import create_app
 from api.dependencies import get_settings
 from config.settings import Settings

+app = create_app()
+

 def test_anthropic_auth_token_required_and_accepts_x_api_key():
     client = TestClient(app)
@@ -16,7 +16,7 @@ from api.dependencies import (
 )
 from config.nim import NimSettings
 from providers.deepseek import DeepSeekProvider
-from providers.exceptions import UnknownProviderTypeError
+from providers.exceptions import ServiceUnavailableError, UnknownProviderTypeError
 from providers.lmstudio import LMStudioProvider
 from providers.nvidia_nim import NvidiaNimProvider
 from providers.ollama import OllamaProvider
@@ -40,7 +40,7 @@ def _make_mock_settings(**overrides):
     mock.nim = NimSettings()
     mock.http_read_timeout = 300.0
     mock.http_write_timeout = 10.0
-    mock.http_connect_timeout = 2.0
+    mock.http_connect_timeout = 10.0
     mock.enable_model_thinking = True
     for key, value in overrides.items():
         setattr(mock, key, value)
@@ -434,18 +434,17 @@ def test_resolve_provider_per_app_uses_separate_registries() -> None:
     assert p1 is not p2


-def test_resolve_provider_lazily_installs_registry() -> None:
-    """If app has no provider_registry, one is created on app.state."""
-    with patch("api.dependencies.get_settings") as mock_settings:
-        mock_settings.return_value = _make_mock_settings()
-        app = SimpleNamespace(state=State())
-        assert getattr(app.state, "provider_registry", None) is None
-        resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
-        reg = app.state.provider_registry
-        assert reg is not None
-        p2 = resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)
-        assert p2 is reg.get("nvidia_nim", settings)  # same registry instance
+def test_resolve_provider_missing_registry_raises_service_unavailable() -> None:
+    """HTTP apps must install app.state.provider_registry (e.g. via AppRuntime)."""
+    settings = _make_mock_settings()
+    app = SimpleNamespace(state=State())
+    assert getattr(app.state, "provider_registry", None) is None
+    with pytest.raises(
+        ServiceUnavailableError, match="Provider registry is not configured"
+    ):
+        resolve_provider("nvidia_nim", app=cast(Starlette, app), settings=settings)


 def test_resolve_provider_unrelated_value_error_is_not_unknown_provider_log() -> None:
@@ -3,10 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from fastapi.testclient import TestClient

-from api.app import app
+from api.app import create_app
 from api.dependencies import get_settings
 from config.settings import Settings

+app = create_app()
+

 @pytest.fixture
 def client():
@@ -103,6 +105,17 @@ def test_create_message_empty_messages_returns_400(client):
     assert "cannot be empty" in data.get("error", {}).get("message", "")


+def test_count_tokens_empty_messages_returns_400(client):
+    """POST /v1/messages/count_tokens with messages: [] matches messages validation."""
+    payload = {"model": "claude-3-sonnet", "messages": []}
+    response = client.post("/v1/messages/count_tokens", json=payload)
+    assert response.status_code == 400
+    data = response.json()
+    assert data.get("type") == "error"
+    assert data.get("error", {}).get("type") == "invalid_request_error"
+    assert "cannot be empty" in data.get("error", {}).get("message", "")
+
+
 def test_count_tokens_endpoint(client):
     payload = {
         "model": "claude-3-sonnet",
tests/api/test_runtime_safe_logging.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+"""Tests for safe default logging in :mod:`api.runtime`."""
+
+import importlib
+import logging
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from tests.api.test_app_lifespan_and_errors import _app_settings
+
+
+@pytest.mark.asyncio
+async def test_messaging_start_failure_default_logs_exclude_traceback(caplog):
+    api_runtime_mod = importlib.import_module("api.runtime")
+    settings = _app_settings(
+        messaging_platform="telegram",
+        telegram_bot_token="t",
+        allowed_telegram_user_id="1",
+        discord_bot_token=None,
+        allowed_discord_channels=None,
+        allowed_dir="",
+        claude_workspace="./agent_workspace",
+        host="127.0.0.1",
+        port=8082,
+        log_file="server.log",
+        log_api_error_tracebacks=False,
+    )
+    runtime = api_runtime_mod.AppRuntime(app=MagicMock(), settings=settings)
+
+    with (
+        patch(
+            "messaging.platforms.factory.create_messaging_platform",
+            side_effect=RuntimeError("SECRET_RUNTIME_DETAIL"),
+        ),
+        caplog.at_level(logging.ERROR),
+    ):
+        await runtime._start_messaging_if_configured()
+
+    blob = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_RUNTIME_DETAIL" not in blob
+    assert "exc_type=RuntimeError" in blob
+
+
+@pytest.mark.asyncio
+async def test_best_effort_default_logs_exclude_exception_text(caplog):
+    api_runtime_mod = importlib.import_module("api.runtime")
+
+    async def boom():
+        raise ValueError("SECRET_SHUTDOWN")
+
+    with caplog.at_level(logging.WARNING):
+        await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=False)
+
+    blob = " | ".join(r.getMessage() for r in caplog.records)
+    assert "SECRET_SHUTDOWN" not in blob
+    assert "exc_type=ValueError" in blob
+
+
+@pytest.mark.asyncio
+async def test_best_effort_verbose_includes_exception_text(caplog):
+    api_runtime_mod = importlib.import_module("api.runtime")
+
+    async def boom():
+        raise ValueError("VISIBLE_SHUTDOWN")
+
+    with caplog.at_level(logging.WARNING):
+        await api_runtime_mod.best_effort("test_step", boom(), log_verbose_errors=True)
+
+    blob = " | ".join(r.getMessage() for r in caplog.records)
+    assert "VISIBLE_SHUTDOWN" in blob
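These tests pin down `api.runtime.best_effort`: by default it logs only the step name and exception type, and only with `log_verbose_errors=True` does it include the exception text. A sketch of a compatible helper; the signature is taken from the tests above, while the body is an assumption:

import logging
from collections.abc import Coroutine
from typing import Any

logger = logging.getLogger(__name__)


async def best_effort(
    step: str, coro: Coroutine[Any, Any, Any], *, log_verbose_errors: bool = False
) -> None:
    try:
        await coro
    except Exception as exc:  # shutdown/cleanup paths must never raise
        if log_verbose_errors:
            logger.warning("%s failed: %s", step, exc)
        else:
            # Metadata only: exception text may echo request-derived data.
            logger.warning("%s failed exc_type=%s", step, type(exc).__name__)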
tests/api/test_safe_logging.py (new file, 208 lines)
@@ -0,0 +1,208 @@
+"""Tests that API and SSE logging avoid raw sensitive payloads by default."""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi import HTTPException
+
+from api import services as services_mod
+from api.models.anthropic import Message, MessagesRequest
+from api.services import ClaudeProxyService
+from config.settings import Settings
+from core.anthropic.sse import SSEBuilder
+
+
+def test_create_message_skips_full_payload_debug_log_by_default():
+    settings = Settings()
+    assert settings.log_raw_api_payloads is False
+    mock_provider = MagicMock()
+
+    async def fake_stream(*_a, **_kw):
+        yield "event: ping\ndata: {}\n\n"
+
+    mock_provider.stream_response = fake_stream
+    service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+
+    request = MessagesRequest(
+        model="claude-3-haiku-20240307",
+        max_tokens=10,
+        messages=[Message(role="user", content="secret-user-text")],
+    )
+
+    with patch.object(services_mod.logger, "debug") as mock_debug:
+        service.create_message(request)
+
+    full_payload_calls = [
+        c
+        for c in mock_debug.call_args_list
+        if c.args and str(c.args[0]) == "FULL_PAYLOAD [{}]: {}"
+    ]
+    assert not full_payload_calls
+
+
+def test_create_message_logs_full_payload_when_opt_in():
+    settings = Settings()
+    settings.log_raw_api_payloads = True
+    mock_provider = MagicMock()
+
+    async def fake_stream(*_a, **_kw):
+        yield "event: ping\ndata: {}\n\n"
+
+    mock_provider.stream_response = fake_stream
+    service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+    request = MessagesRequest(
+        model="claude-3-haiku-20240307",
+        max_tokens=10,
+        messages=[Message(role="user", content="visible")],
+    )
+
+    with patch.object(services_mod.logger, "debug") as mock_debug:
+        service.create_message(request)
+
+    keys = [c.args[0] for c in mock_debug.call_args_list if c.args]
+    assert any(k == "FULL_PAYLOAD [{}]: {}" for k in keys)
+
+
+def test_sse_builder_default_debug_has_no_serialized_json_content():
+    with patch("core.anthropic.sse.logger.debug") as mock_debug:
+        sse = SSEBuilder("msg_x", "m", 1, log_raw_events=False)
+        sse.message_start()
+
+    assert mock_debug.call_count == 1
+    message = str(mock_debug.call_args)
+    assert "serialized_bytes=" in message
+    assert "role" not in message
+    assert "assistant" not in message
+
+
+def test_sse_builder_raw_logging_includes_event_body_when_enabled():
+    with patch("core.anthropic.sse.logger.debug") as mock_debug:
+        sse = SSEBuilder("msg_x", "m", 1, log_raw_events=True)
+        sse.message_start()
+
+    assert mock_debug.call_count == 1
+    message = str(mock_debug.call_args)
+    assert "message_start" in message
+    assert "role" in message
+
+
+def _flatten_log_calls(mock_log) -> str:
+    parts: list[str] = []
+    for call in mock_log.call_args_list:
+        parts.extend(str(arg) for arg in call.args)
+        parts.append(repr(call.kwargs))
+    return " ".join(parts)
+
+
+def test_create_message_unexpected_error_default_logs_exclude_exception_text():
+    settings = Settings()
+    assert settings.log_api_error_tracebacks is False
+    secret = "upstream-secret-token-abc"
+
+    mock_provider = MagicMock()
+
+    def stream_boom(*_a, **_kw):
+        raise RuntimeError(secret)
+
+    mock_provider.stream_response = stream_boom
+    service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+    request = MessagesRequest(
+        model="claude-3-haiku-20240307",
+        max_tokens=10,
+        messages=[Message(role="user", content="hi")],
+    )
+
+    with (
+        patch.object(services_mod.logger, "error") as log_err,
+        pytest.raises(HTTPException),
+    ):
+        service.create_message(request)
+
+    blob = _flatten_log_calls(log_err)
+    assert secret not in blob
+    assert "RuntimeError" in blob
+
+
+def test_create_message_unexpected_error_always_returns_500():
+    """Non-provider failures must not leak arbitrary status_code attributes."""
+
+    class WeirdError(Exception):
+        status_code = 418
+
+    settings = Settings()
+    mock_provider = MagicMock()
+
+    def stream_boom(*_a, **_kw):
+        raise WeirdError("no")
+
+    mock_provider.stream_response = stream_boom
+    service = ClaudeProxyService(settings, provider_getter=lambda _: mock_provider)
+    request = MessagesRequest(
+        model="claude-3-haiku-20240307",
+        max_tokens=10,
+        messages=[Message(role="user", content="hi")],
+    )
+
+    with pytest.raises(HTTPException) as excinfo:
+        service.create_message(request)
+
+    assert excinfo.value.status_code == 500
+
+
+def test_parse_cli_event_error_logs_metadata_by_default():
+    """CLI parser must not log raw error text unless LOG_RAW_CLI_DIAGNOSTICS is on."""
+    from messaging.event_parser import parse_cli_event
+
+    secret = "user-secret-parser-leak-xyz"
+    with patch("messaging.event_parser.logger.info") as log_info:
+        parse_cli_event(
+            {"type": "error", "error": {"message": secret}}, log_raw_cli=False
+        )
+    flat = " ".join(str(c) for c in log_info.call_args_list)
+    assert secret not in flat
+    assert "message_chars" in flat
+
+
+def test_parse_cli_event_error_logs_text_when_log_raw_cli_enabled():
+    from messaging.event_parser import parse_cli_event
+
+    secret = "visible-cli-parser-msg"
+    with patch("messaging.event_parser.logger.info") as log_info:
+        parse_cli_event(
+            {"type": "error", "error": {"message": secret}}, log_raw_cli=True
+        )
+    flat = " ".join(str(c) for c in log_info.call_args_list)
+    assert secret in flat
+
+
+def test_count_tokens_unexpected_error_default_logs_exclude_exception_text():
+    settings = Settings()
+    assert settings.log_api_error_tracebacks is False
+    secret = "count-tokens-leak-xyz"
+
+    def boom(*_a, **_kw):
+        raise ValueError(secret)
+
+    service = ClaudeProxyService(
+        settings,
+        provider_getter=lambda _: MagicMock(),
+        token_counter=boom,
+    )
+    from api.models.anthropic import TokenCountRequest
+
+    req = TokenCountRequest(
+        model="claude-3-haiku-20240307",
+        messages=[Message(role="user", content="x")],
+    )
+
+    with (
+        patch.object(services_mod.logger, "error") as log_err,
+        pytest.raises(HTTPException),
+    ):
+        service.count_tokens(req)
+
+    blob = _flatten_log_calls(log_err)
+    assert secret not in blob
+    assert "ValueError" in blob
tests/api/test_validation_log.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+"""Tests for validation log summaries (metadata only)."""
+
+from api.validation_log import summarize_request_validation_body
+
+
+def test_summarize_lists_block_metadata_without_echoing_string_content():
+    body = {
+        "messages": [
+            {
+                "role": "user",
+                "content": "secret user phrase",
+            }
+        ],
+        "tools": [{"name": "web_search", "type": "web_search_20250305"}],
+    }
+    summary, tool_names = summarize_request_validation_body(body)
+    assert summary == [
+        {
+            "role": "user",
+            "content_kind": "str",
+            "content_length": 18,
+        }
+    ]
+    assert tool_names == ["web_search"]
+    blob = repr(summary) + repr(tool_names)
+    assert "secret" not in blob
+
+
+def test_summarize_handles_non_dict_messages_and_missing_tools():
+    body = {"messages": ["not_a_dict"]}
+    summary, tool_names = summarize_request_validation_body(body)
+    assert summary == [{"message_kind": "str"}]
+    assert tool_names == []
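These tests fully specify the summarizer's observable behavior. A sketch of an implementation consistent with them (the real one lives in api.validation_log; the body below is an assumption derived from the assertions):

from typing import Any


def summarize_sketch(body: dict[str, Any]) -> tuple[list[dict[str, Any]], list[str]]:
    """Return per-message metadata (never content) plus declared tool names."""
    summary: list[dict[str, Any]] = []
    for msg in body.get("messages", []):
        if not isinstance(msg, dict):
            summary.append({"message_kind": type(msg).__name__})
            continue
        content = msg.get("content")
        entry: dict[str, Any] = {
            "role": msg.get("role"),
            "content_kind": type(content).__name__,
        }
        if isinstance(content, str):
            entry["content_length"] = len(content)
        summary.append(entry)
    tools = [t.get("name") for t in body.get("tools", []) if isinstance(t, dict)]
    return summary, tools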
Some files were not shown because too many files have changed in this diff.