diff --git a/.env.example b/.env.example index 5983188..3dd8593 100644 --- a/.env.example +++ b/.env.example @@ -4,6 +4,8 @@ NVIDIA_NIM_API_KEY="" # OpenRouter Config OPENROUTER_API_KEY="" +# Diagnostic rollback only: "anthropic" (native /messages, default) | "openai" (chat completions) +OPENROUTER_TRANSPORT="anthropic" # DeepSeek Config @@ -55,7 +57,7 @@ HTTP_CONNECT_TIMEOUT=2 ANTHROPIC_AUTH_TOKEN= -# Messaging Platform: "telegram" | "discord" +# Messaging Platform: "telegram" | "discord" | "none" MESSAGING_PLATFORM="discord" MESSAGING_RATE_LIMIT=1 MESSAGING_RATE_WINDOW=1 @@ -88,6 +90,7 @@ ALLOWED_DISCORD_CHANNELS="" # Agent Config CLAUDE_WORKSPACE="./agent_workspace" ALLOWED_DIR="" +CLAUDE_CLI_BIN="claude" FAST_PREFIX_DETECTION=true ENABLE_NETWORK_PROBE_MOCK=true ENABLE_TITLE_GENERATION_SKIP=true diff --git a/AGENTS.md b/AGENTS.md index d43b17c..6b528d9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -31,7 +31,7 @@ - **Performance**: Use list accumulation for strings (not `+=` in loops), cache env vars at init, prefer iterative over recursive when stack depth matters. - **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code. - **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue. -- **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working. +- **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required. ## COGNITIVE WORKFLOW diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..77aae0f --- /dev/null +++ b/PLAN.md @@ -0,0 +1,88 @@ +# Architecture Plan + +This document is the baseline architecture guide referenced by `AGENTS.md` and +`CLAUDE.md`. It records the intended dependency direction and the migration +target for keeping the project modular as providers, clients, and smoke tests +grow. + +## Current Product Shape + +`free-claude-code` is an Anthropic-compatible proxy with optional messaging +workers: + +- `api/` owns the HTTP routes, request orchestration, model routing, auth, and + server lifecycle. +- `providers/` owns upstream model adapters, request conversion, stream + conversion, provider rate limiting, and provider error mapping. +- `messaging/` owns Discord and Telegram adapters, command handling, tree + threading, session persistence, transcript rendering, and voice intake. +- `cli/` owns package entrypoints and managed Claude CLI subprocess sessions. +- `config/` owns environment-backed settings and logging setup. +- `smoke/` owns opt-in product smoke scenarios and the public coverage + inventory used by contract tests. + +## Intended Dependency Direction + +The repo should preserve this dependency order: + +```mermaid +flowchart TD + config[config] --> api[api] + config --> providers[providers] + config --> messaging[messaging] + core[core.anthropic] --> api + core --> providers + core --> messaging + providers --> api + cli --> messaging + messaging --> api +``` + +The practical rule is simpler than the graph: shared protocol helpers belong in +neutral core modules, not under a provider package. Provider adapters may depend +on the neutral protocol layer, but API and messaging code should not import +provider internals. 
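+
+A concrete instance of that rule from this change: API code takes shared
+helpers from the neutral core package instead of importing provider internals
+(a minimal sketch using imports introduced by this migration):
+
+```python
+# Allowed: api/ depends on the neutral protocol layer.
+from core.anthropic import extract_text_from_content
+
+# Avoided: api/ reaching into a provider package (the pre-migration path).
+# from providers.common.text import extract_text_from_content
+```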
+ +## Target Boundaries + +- `core/anthropic/`: Anthropic protocol helpers, stream primitives, content + extraction, token estimation, user-facing error strings, and request + conversion utilities shared across API, providers, messaging, and tests. +- `api/runtime.py`: application composition, optional messaging startup, + session store restoration, and cleanup ownership. +- `providers/`: provider descriptors, credential resolution, transport + factories, scoped rate limiters, upstream request builders, and stream + transformers. +- `messaging/`: platform-neutral orchestration split from command dispatch, + rendering, voice handling, and persistence. +- `cli/`: typed Claude CLI runner config, subprocess management, and packaged + user-facing entrypoints. + +## Smoke Coverage Policy + +Default CI stays deterministic and runs `uv run pytest` against `tests/`. +Product smoke lives under `smoke/` and is enabled with `FCC_LIVE_SMOKE=1`. +Smoke runs should use `-n 0` unless a scenario is explicitly known to be safe +under xdist. + +Live smoke has two valid skip classes: + +- `missing_env`: credentials, local services, binaries, or explicit opt-in flags + are absent. +- `upstream_unavailable`: real providers, bot APIs, or local model servers are + unreachable. + +`product_failure` and `harness_bug` are regressions. When a provider is +explicitly selected by `FCC_SMOKE_PROVIDER_MATRIX`, missing configuration should +fail instead of being silently skipped. + +## Refactor Rules + +- Keep public request/response shapes stable while moving internals. +- Complete module migrations in one change: update imports to the new owner and + remove old compatibility shims unless preserving a published interface is + explicitly required. +- Lock behavior with focused tests before moving shared protocol or runtime + code. +- Run checks in this order: `uv run ruff format`, `uv run ruff check`, + `uv run ty check`, `uv run pytest`. diff --git a/api/app.py b/api/app.py index cf5877c..f6677e3 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,5 @@ """FastAPI application factory and configuration.""" -import asyncio import os from contextlib import asynccontextmanager @@ -14,6 +13,7 @@ from providers.exceptions import ProviderError from .dependencies import cleanup_provider from .routes import router +from .runtime import AppRuntime, warn_if_process_auth_token # Opt-in to future behavior for python-telegram-bot os.environ["PTB_TIMEDELTA"] = "1" @@ -23,174 +23,22 @@ _settings = get_settings() configure_logging(_settings.log_file) -_SHUTDOWN_TIMEOUT_S = 5.0 - - -async def _best_effort( - name: str, awaitable, timeout_s: float = _SHUTDOWN_TIMEOUT_S -) -> None: - """Run a shutdown step with timeout; never raise to callers.""" - try: - await asyncio.wait_for(awaitable, timeout=timeout_s) - except TimeoutError: - logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)") - except Exception as e: - logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}") - - def _warn_if_process_auth_token(settings) -> None: - """Warn when server auth was implicitly inherited from the shell.""" - uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None) - if callable(uses_process_token) and uses_process_token(): - logger.warning( - "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in " - "a configured .env file. The proxy will require that token. Add " - "ANTHROPIC_AUTH_TOKEN= to .env to disable proxy auth, or set the " - "same token in .env to make server auth explicit." 
- ) + """Compatibility wrapper for tests importing the old app helper.""" + warn_if_process_auth_token(settings) @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan manager.""" - settings = get_settings() - logger.info("Starting Claude Code Proxy...") - _warn_if_process_auth_token(settings) - - # Initialize messaging platform if configured - messaging_platform = None - message_handler = None - cli_manager = None - - try: - # Use the messaging factory to create the right platform - from messaging.platforms.factory import create_messaging_platform - - messaging_platform = create_messaging_platform( - platform_type=settings.messaging_platform, - bot_token=settings.telegram_bot_token, - allowed_user_id=settings.allowed_telegram_user_id, - discord_bot_token=settings.discord_bot_token, - allowed_discord_channels=settings.allowed_discord_channels, - ) - - if messaging_platform: - from cli.manager import CLISessionManager - from messaging.handler import ClaudeMessageHandler - from messaging.session import SessionStore - - # Setup workspace - CLI runs in allowed_dir if set (e.g. project root) - workspace = ( - os.path.abspath(settings.allowed_dir) - if settings.allowed_dir - else os.getcwd() - ) - os.makedirs(workspace, exist_ok=True) - - # Session data stored in agent_workspace - data_path = os.path.abspath(settings.claude_workspace) - os.makedirs(data_path, exist_ok=True) - - api_url = f"http://{settings.host}:{settings.port}/v1" - allowed_dirs = [workspace] if settings.allowed_dir else [] - plans_dir_abs = os.path.abspath( - os.path.join(settings.claude_workspace, "plans") - ) - plans_directory = os.path.relpath(plans_dir_abs, workspace) - cli_manager = CLISessionManager( - workspace_path=workspace, - api_url=api_url, - allowed_dirs=allowed_dirs, - plans_directory=plans_directory, - ) - - # Initialize session store - session_store = SessionStore( - storage_path=os.path.join(data_path, "sessions.json") - ) - - # Create and register message handler - message_handler = ClaudeMessageHandler( - platform=messaging_platform, - cli_manager=cli_manager, - session_store=session_store, - ) - - # Restore tree state if available - saved_trees = session_store.get_all_trees() - if saved_trees: - logger.info(f"Restoring {len(saved_trees)} conversation trees...") - from messaging.trees.queue_manager import TreeQueueManager - - message_handler.replace_tree_queue( - TreeQueueManager.from_dict( - { - "trees": saved_trees, - "node_to_tree": session_store.get_node_mapping(), - }, - queue_update_callback=message_handler.update_queue_positions, - node_started_callback=message_handler.mark_node_processing, - ) - ) - # Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart - if message_handler.tree_queue.cleanup_stale_nodes() > 0: - # Sync back and save - tree_data = message_handler.tree_queue.to_dict() - session_store.sync_from_tree_data( - tree_data["trees"], tree_data["node_to_tree"] - ) - - # Wire up the handler - messaging_platform.on_message(message_handler.handle_message) - - # Start the platform - await messaging_platform.start() - logger.info( - f"{messaging_platform.name} platform started with message handler" - ) - - except ImportError as e: - logger.warning(f"Messaging module import error: {e}") - except Exception as e: - logger.error(f"Failed to start messaging platform: {e}") - import traceback - - logger.error(traceback.format_exc()) - - # Store in app state for access in routes - app.state.messaging_platform = messaging_platform - 
app.state.message_handler = message_handler - app.state.cli_manager = cli_manager + runtime = AppRuntime.for_app( + app, settings=get_settings(), provider_cleanup=cleanup_provider + ) + await runtime.startup() yield - # Cleanup - if message_handler and hasattr(message_handler, "session_store"): - try: - message_handler.session_store.flush_pending_save() - except Exception as e: - logger.warning(f"Session store flush on shutdown: {e}") - logger.info("Shutdown requested, cleaning up...") - if messaging_platform: - await _best_effort("messaging_platform.stop", messaging_platform.stop()) - if cli_manager: - await _best_effort("cli_manager.stop_all", cli_manager.stop_all()) - await _best_effort("cleanup_provider", cleanup_provider()) - - # Ensure background limiter worker doesn't keep the loop alive. - try: - from messaging.limiter import MessagingRateLimiter - - await _best_effort( - "MessagingRateLimiter.shutdown_instance", - MessagingRateLimiter.shutdown_instance(), - timeout_s=2.0, - ) - except Exception: - # Limiter may never have been imported/initialized. - pass - - logger.info("Server shut down cleanly") + await runtime.shutdown() def create_app() -> FastAPI: diff --git a/api/dependencies.py b/api/dependencies.py index bf54cd5..ede7588 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -5,8 +5,8 @@ from loguru import logger from config.settings import Settings from config.settings import get_settings as _get_settings +from core.anthropic import get_user_facing_error_message from providers.base import BaseProvider -from providers.common import get_user_facing_error_message from providers.exceptions import AuthenticationError from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry @@ -24,6 +24,7 @@ def get_provider_for_type(provider_type: str) -> BaseProvider: Providers are cached in the registry and reused across requests. """ + should_log_init = provider_type not in _providers try: provider = ProviderRegistry(_providers).get(provider_type, get_settings()) except AuthenticationError as e: @@ -37,7 +38,7 @@ def get_provider_for_type(provider_type: str) -> BaseProvider: ", ".join(f"'{key}'" for key in PROVIDER_DESCRIPTORS), ) raise - if provider_type in _providers: + if should_log_init: logger.info("Provider initialized: {}", provider_type) return provider diff --git a/api/detection.py b/api/detection.py index 7df048a..9f3f33f 100644 --- a/api/detection.py +++ b/api/detection.py @@ -4,7 +4,7 @@ Detects quota checks, title generation, prefix detection, suggestion mode, and filepath extraction requests to enable fast-path responses. 
""" -from providers.common.text import extract_text_from_content +from core.anthropic import extract_text_from_content from .models.anthropic import MessagesRequest diff --git a/api/model_router.py b/api/model_router.py index 8e755e1..789d02e 100644 --- a/api/model_router.py +++ b/api/model_router.py @@ -19,6 +19,18 @@ class ResolvedModel: provider_model_ref: str +@dataclass(frozen=True, slots=True) +class RoutedMessagesRequest: + request: MessagesRequest + resolved: ResolvedModel + + +@dataclass(frozen=True, slots=True) +class RoutedTokenCountRequest: + request: TokenCountRequest + resolved: ResolvedModel + + class ModelRouter: """Resolve incoming Claude model names to configured provider/model pairs.""" @@ -40,19 +52,21 @@ class ModelRouter: provider_model_ref=provider_model_ref, ) - def resolve_messages_request(self, request: MessagesRequest) -> MessagesRequest: - """Return a routed copy of a MessagesRequest.""" - original_model = request.original_model or request.model - resolved = self.resolve(original_model) + def resolve_messages_request( + self, request: MessagesRequest + ) -> RoutedMessagesRequest: + """Return an internal routed request context.""" + resolved = self.resolve(request.model) routed = request.model_copy(deep=True) - routed.original_model = resolved.original_model - routed.resolved_provider_model = resolved.provider_model_ref routed.model = resolved.provider_model - return routed + return RoutedMessagesRequest(request=routed, resolved=resolved) def resolve_token_count_request( self, request: TokenCountRequest - ) -> TokenCountRequest: - """Return a token-count request copy with provider model name applied.""" + ) -> RoutedTokenCountRequest: + """Return an internal token-count request context.""" resolved = self.resolve(request.model) - return request.model_copy(update={"model": resolved.provider_model}, deep=True) + routed = request.model_copy( + update={"model": resolved.provider_model}, deep=True + ) + return RoutedTokenCountRequest(request=routed, resolved=resolved) diff --git a/api/models/anthropic.py b/api/models/anthropic.py index 9758295..a8bbef3 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -103,8 +103,6 @@ class MessagesRequest(BaseModel): tool_choice: dict[str, Any] | None = None thinking: ThinkingConfig | None = None extra_body: dict[str, Any] | None = None - original_model: str | None = None - resolved_provider_model: str | None = None class TokenCountRequest(BaseModel): diff --git a/api/request_utils.py b/api/request_utils.py index 2e8c154..1a6f149 100644 --- a/api/request_utils.py +++ b/api/request_utils.py @@ -1,101 +1,5 @@ -"""Request utility functions for API route handlers. +"""Backward-compatible token counting import for API route handlers.""" -Contains token counting for API requests. -""" - -import json - -import tiktoken -from loguru import logger - -from providers.common import get_block_attr - -ENCODER = tiktoken.get_encoding("cl100k_base") +from core.anthropic import get_token_count __all__ = ["get_token_count"] - - -def get_token_count( - messages: list, - system: str | list | None = None, - tools: list | None = None, -) -> int: - """Estimate token count for a request. - - Uses tiktoken cl100k_base encoding to estimate token usage. - Includes system prompt, messages, tools, and per-message overhead. 
- """ - total_tokens = 0 - - if system: - if isinstance(system, str): - total_tokens += len(ENCODER.encode(system)) - elif isinstance(system, list): - for block in system: - text = get_block_attr(block, "text", "") - if text: - total_tokens += len(ENCODER.encode(str(text))) - total_tokens += 4 # System block formatting overhead - - for msg in messages: - if isinstance(msg.content, str): - total_tokens += len(ENCODER.encode(msg.content)) - elif isinstance(msg.content, list): - for block in msg.content: - b_type = get_block_attr(block, "type") or None - - if b_type == "text": - text = get_block_attr(block, "text", "") - total_tokens += len(ENCODER.encode(str(text))) - elif b_type == "thinking": - thinking = get_block_attr(block, "thinking", "") - total_tokens += len(ENCODER.encode(str(thinking))) - elif b_type == "tool_use": - name = get_block_attr(block, "name", "") - inp = get_block_attr(block, "input", {}) - block_id = get_block_attr(block, "id", "") - total_tokens += len(ENCODER.encode(str(name))) - total_tokens += len(ENCODER.encode(json.dumps(inp))) - total_tokens += len(ENCODER.encode(str(block_id))) - total_tokens += 15 - elif b_type == "image": - source = get_block_attr(block, "source") - if isinstance(source, dict): - data = source.get("data") or source.get("base64") or "" - if data: - total_tokens += max(85, len(data) // 3000) - else: - total_tokens += 765 - else: - total_tokens += 765 - elif b_type == "tool_result": - content = get_block_attr(block, "content", "") - tool_use_id = get_block_attr(block, "tool_use_id", "") - if isinstance(content, str): - total_tokens += len(ENCODER.encode(content)) - else: - total_tokens += len(ENCODER.encode(json.dumps(content))) - total_tokens += len(ENCODER.encode(str(tool_use_id))) - total_tokens += 8 - else: - logger.debug( - "Unexpected block type %r, falling back to json/str encoding", - b_type, - ) - try: - total_tokens += len(ENCODER.encode(json.dumps(block))) - except TypeError, ValueError: - total_tokens += len(ENCODER.encode(str(block))) - - if tools: - for tool in tools: - tool_str = ( - tool.name + (tool.description or "") + json.dumps(tool.input_schema) - ) - total_tokens += len(ENCODER.encode(tool_str)) - - total_tokens += len(messages) * 4 - if tools: - total_tokens += len(tools) * 5 - - return max(1, total_tokens) diff --git a/api/runtime.py b/api/runtime.py new file mode 100644 index 0000000..bb6d8c8 --- /dev/null +++ b/api/runtime.py @@ -0,0 +1,198 @@ +"""Application runtime composition and lifecycle ownership.""" + +from __future__ import annotations + +import asyncio +import os +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from fastapi import FastAPI +from loguru import logger + +from config.settings import Settings, get_settings + +from .dependencies import cleanup_provider + +_SHUTDOWN_TIMEOUT_S = 5.0 + + +async def best_effort( + name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S +) -> None: + """Run a shutdown step with timeout; never raise to callers.""" + try: + await asyncio.wait_for(awaitable, timeout=timeout_s) + except TimeoutError: + logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)") + except Exception as e: + logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}") + + +def warn_if_process_auth_token(settings: Settings) -> None: + """Warn when server auth was implicitly inherited from the shell.""" + uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None) + if 
callable(uses_process_token) and uses_process_token(): + logger.warning( + "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in " + "a configured .env file. The proxy will require that token. Add " + "ANTHROPIC_AUTH_TOKEN= to .env to disable proxy auth, or set the " + "same token in .env to make server auth explicit." + ) + + +@dataclass(slots=True) +class AppRuntime: + """Own optional messaging, CLI, session, and provider runtime resources.""" + + app: FastAPI + settings: Settings + provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider + messaging_platform: Any = None + message_handler: Any = None + cli_manager: Any = None + + @classmethod + def for_app( + cls, + app: FastAPI, + settings: Settings | None = None, + provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider, + ) -> AppRuntime: + return cls( + app=app, + settings=settings or get_settings(), + provider_cleanup=provider_cleanup, + ) + + async def startup(self) -> None: + logger.info("Starting Claude Code Proxy...") + warn_if_process_auth_token(self.settings) + await self._start_messaging_if_configured() + self._publish_state() + + async def shutdown(self) -> None: + if self.message_handler and hasattr(self.message_handler, "session_store"): + try: + self.message_handler.session_store.flush_pending_save() + except Exception as e: + logger.warning(f"Session store flush on shutdown: {e}") + + logger.info("Shutdown requested, cleaning up...") + if self.messaging_platform: + await best_effort("messaging_platform.stop", self.messaging_platform.stop()) + if self.cli_manager: + await best_effort("cli_manager.stop_all", self.cli_manager.stop_all()) + await best_effort("cleanup_provider", self.provider_cleanup()) + await self._shutdown_limiter() + logger.info("Server shut down cleanly") + + async def _start_messaging_if_configured(self) -> None: + try: + from messaging.platforms.factory import create_messaging_platform + + self.messaging_platform = create_messaging_platform( + platform_type=self.settings.messaging_platform, + bot_token=self.settings.telegram_bot_token, + allowed_user_id=self.settings.allowed_telegram_user_id, + discord_bot_token=self.settings.discord_bot_token, + allowed_discord_channels=self.settings.allowed_discord_channels, + ) + + if self.messaging_platform: + await self._start_message_handler() + + except ImportError as e: + logger.warning(f"Messaging module import error: {e}") + except Exception as e: + logger.error(f"Failed to start messaging platform: {e}") + import traceback + + logger.error(traceback.format_exc()) + + async def _start_message_handler(self) -> None: + from cli.manager import CLISessionManager + from messaging.handler import ClaudeMessageHandler + from messaging.session import SessionStore + + workspace = ( + os.path.abspath(self.settings.allowed_dir) + if self.settings.allowed_dir + else os.getcwd() + ) + os.makedirs(workspace, exist_ok=True) + + data_path = os.path.abspath(self.settings.claude_workspace) + os.makedirs(data_path, exist_ok=True) + + api_url = f"http://{self.settings.host}:{self.settings.port}/v1" + allowed_dirs = [workspace] if self.settings.allowed_dir else [] + plans_dir_abs = os.path.abspath( + os.path.join(self.settings.claude_workspace, "plans") + ) + plans_directory = os.path.relpath(plans_dir_abs, workspace) + self.cli_manager = CLISessionManager( + workspace_path=workspace, + api_url=api_url, + allowed_dirs=allowed_dirs, + plans_directory=plans_directory, + claude_bin=getattr(self.settings, "claude_cli_bin", "claude"), + ) + + 
session_store = SessionStore( + storage_path=os.path.join(data_path, "sessions.json") + ) + self.message_handler = ClaudeMessageHandler( + platform=self.messaging_platform, + cli_manager=self.cli_manager, + session_store=session_store, + ) + self._restore_tree_state(session_store) + + self.messaging_platform.on_message(self.message_handler.handle_message) + await self.messaging_platform.start() + logger.info( + f"{self.messaging_platform.name} platform started with message handler" + ) + + def _restore_tree_state(self, session_store: Any) -> None: + saved_trees = session_store.get_all_trees() + if not saved_trees: + return + + logger.info(f"Restoring {len(saved_trees)} conversation trees...") + from messaging.trees.queue_manager import TreeQueueManager + + self.message_handler.replace_tree_queue( + TreeQueueManager.from_dict( + { + "trees": saved_trees, + "node_to_tree": session_store.get_node_mapping(), + }, + queue_update_callback=self.message_handler.update_queue_positions, + node_started_callback=self.message_handler.mark_node_processing, + ) + ) + if self.message_handler.tree_queue.cleanup_stale_nodes() > 0: + tree_data = self.message_handler.tree_queue.to_dict() + session_store.sync_from_tree_data( + tree_data["trees"], tree_data["node_to_tree"] + ) + + def _publish_state(self) -> None: + self.app.state.messaging_platform = self.messaging_platform + self.app.state.message_handler = self.message_handler + self.app.state.cli_manager = self.cli_manager + + async def _shutdown_limiter(self) -> None: + try: + from messaging.limiter import MessagingRateLimiter + + await best_effort( + "MessagingRateLimiter.shutdown_instance", + MessagingRateLimiter.shutdown_instance(), + timeout_s=2.0, + ) + except Exception: + pass diff --git a/api/services.py b/api/services.py index 6138577..8705bf5 100644 --- a/api/services.py +++ b/api/services.py @@ -12,8 +12,8 @@ from fastapi.responses import StreamingResponse from loguru import logger from config.settings import Settings +from core.anthropic import get_user_facing_error_message from providers.base import BaseProvider -from providers.common import get_user_facing_error_message from providers.exceptions import InvalidRequestError, ProviderError from .model_router import ModelRouter @@ -48,35 +48,32 @@ class ClaudeProxyService: if not request_data.messages: raise InvalidRequestError("messages cannot be empty") - routed_request = self._model_router.resolve_messages_request(request_data) + routed = self._model_router.resolve_messages_request(request_data) - optimized = try_optimizations(routed_request, self._settings) + optimized = try_optimizations(routed.request, self._settings) if optimized is not None: return optimized logger.debug("No optimization matched, routing to provider") - provider_type = ( - routed_request.resolved_provider_model or self._settings.model - ).split("/", 1)[0] - provider = self._provider_getter(provider_type) + provider = self._provider_getter(routed.resolved.provider_id) request_id = f"req_{uuid.uuid4().hex[:12]}" logger.info( "API_REQUEST: request_id={} model={} messages={}", request_id, - routed_request.model, - len(routed_request.messages), + routed.request.model, + len(routed.request.messages), ) logger.debug( - "FULL_PAYLOAD [{}]: {}", request_id, routed_request.model_dump() + "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump() ) input_tokens = self._token_counter( - routed_request.messages, routed_request.system, routed_request.tools + routed.request.messages, routed.request.system, routed.request.tools ) 
return StreamingResponse( provider.stream_response( - routed_request, + routed.request, input_tokens=input_tokens, request_id=request_id, ), @@ -102,17 +99,15 @@ class ClaudeProxyService: request_id = f"req_{uuid.uuid4().hex[:12]}" with logger.contextualize(request_id=request_id): try: - routed_request = self._model_router.resolve_token_count_request( - request_data - ) + routed = self._model_router.resolve_token_count_request(request_data) tokens = self._token_counter( - routed_request.messages, routed_request.system, routed_request.tools + routed.request.messages, routed.request.system, routed.request.tools ) logger.info( "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}", request_id, - routed_request.model, - len(routed_request.messages), + routed.request.model, + len(routed.request.messages), tokens, ) return TokenCountResponse(input_tokens=tokens) diff --git a/claude-pick b/claude-pick index 792c6bd..409188b 100755 --- a/claude-pick +++ b/claude-pick @@ -13,11 +13,15 @@ OPENROUTER_MODELS_URL="https://openrouter.ai/api/v1/models" DEFAULT_LM_STUDIO_BASE_URL="http://localhost:1234/v1" DEFAULT_LLAMACPP_BASE_URL="http://localhost:8080/v1" -if ! command -v python3 >/dev/null 2>&1; then - echo "Error: python3 is required." >&2 +if ! command -v uv >/dev/null 2>&1; then + echo "Error: uv is required." >&2 exit 1 fi +run_python() { + uv run python "$@" +} + read_env_value() { local key="$1" [[ -f "$ENV_FILE" ]] || return 0 @@ -41,7 +45,7 @@ if ! command -v fzf >/dev/null 2>&1; then fi parse_models_from_json() { - python3 -c ' + run_python -c ' import json, sys try: payload = json.load(sys.stdin) @@ -61,7 +65,7 @@ get_nvidia_models() { exit 1 fi - python3 -c ' + run_python -c ' import json, sys with open(sys.argv[1], "r", encoding="utf-8") as f: payload = json.load(f) diff --git a/cli/manager.py b/cli/manager.py index 48d3849..b267633 100644 --- a/cli/manager.py +++ b/cli/manager.py @@ -28,6 +28,7 @@ class CLISessionManager: api_url: str, allowed_dirs: list[str] | None = None, plans_directory: str | None = None, + claude_bin: str = "claude", ): """ Initialize the session manager. 
@@ -42,6 +43,7 @@ class CLISessionManager: self.api_url = api_url self.allowed_dirs = allowed_dirs or [] self.plans_directory = plans_directory + self.claude_bin = claude_bin self._sessions: dict[str, CLISession] = {} self._pending_sessions: dict[str, CLISession] = {} @@ -76,6 +78,7 @@ class CLISessionManager: api_url=self.api_url, allowed_dirs=self.allowed_dirs, plans_directory=self.plans_directory, + claude_bin=self.claude_bin, ) self._pending_sessions[temp_id] = new_session logger.info(f"Created new session: {temp_id}") diff --git a/cli/session.py b/cli/session.py index 8847414..db18f6c 100644 --- a/cli/session.py +++ b/cli/session.py @@ -4,6 +4,7 @@ import asyncio import json import os from collections.abc import AsyncGenerator +from dataclasses import dataclass, field from typing import Any from loguru import logger @@ -11,6 +12,17 @@ from loguru import logger from .process_registry import register_pid, unregister_pid +@dataclass(frozen=True, slots=True) +class ClaudeCliConfig: + """Configuration for a managed Claude CLI subprocess.""" + + workspace_path: str + api_url: str + allowed_dirs: list[str] = field(default_factory=list) + plans_directory: str | None = None + claude_bin: str = "claude" + + class CLISession: """Manages a single persistent Claude Code CLI subprocess.""" @@ -20,11 +32,20 @@ class CLISession: api_url: str, allowed_dirs: list[str] | None = None, plans_directory: str | None = None, + claude_bin: str = "claude", ): - self.workspace = os.path.normpath(os.path.abspath(workspace_path)) - self.api_url = api_url - self.allowed_dirs = [os.path.normpath(d) for d in (allowed_dirs or [])] - self.plans_directory = plans_directory + self.config = ClaudeCliConfig( + workspace_path=os.path.normpath(os.path.abspath(workspace_path)), + api_url=api_url, + allowed_dirs=[os.path.normpath(d) for d in (allowed_dirs or [])], + plans_directory=plans_directory, + claude_bin=claude_bin, + ) + self.workspace = self.config.workspace_path + self.api_url = self.config.api_url + self.allowed_dirs = self.config.allowed_dirs + self.plans_directory = self.config.plans_directory + self.claude_bin = self.config.claude_bin self.process: asyncio.subprocess.Process | None = None self.current_session_id: str | None = None self._is_busy = False @@ -67,7 +88,7 @@ class CLISession: # Build command if session_id and not session_id.startswith("pending_"): cmd = [ - "claude", + self.claude_bin, "--resume", session_id, ] @@ -84,7 +105,7 @@ class CLISession: logger.info(f"Resuming Claude session {session_id}") else: cmd = [ - "claude", + self.claude_bin, "-p", prompt, "--output-format", diff --git a/config/env.example b/config/env.example index 37d64f5..3dd8593 100644 --- a/config/env.example +++ b/config/env.example @@ -16,6 +16,10 @@ DEEPSEEK_API_KEY="" LM_STUDIO_BASE_URL="http://localhost:1234/v1" +# Llama.cpp Config (local provider, no API key required) +LLAMACPP_BASE_URL="http://localhost:8080/v1" + + # All Claude model requests are mapped to these models, plain model is fallback # Format: provider_type/model/name # Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp" @@ -25,7 +29,19 @@ MODEL_HAIKU= MODEL="nvidia_nim/z-ai/glm4.7" +# Thinking output +# Global switch for provider reasoning requests and Claude thinking blocks. +# Set false to suppress thinking across NIM, OpenRouter, LM Studio, and llama.cpp. 
+ENABLE_THINKING=true + + # Provider config +# Per-provider proxy support: http and socks5, example: "http://username:password@host:port" +NVIDIA_NIM_PROXY="" +OPENROUTER_PROXY="" +LMSTUDIO_PROXY="" +LLAMACPP_PROXY="" + PROVIDER_RATE_LIMIT=40 PROVIDER_RATE_WINDOW=60 PROVIDER_MAX_CONCURRENCY=5 @@ -37,7 +53,11 @@ HTTP_WRITE_TIMEOUT=10 HTTP_CONNECT_TIMEOUT=2 -# Messaging Platform: "telegram" | "discord" +# Optional server API key (Anthropic-style) +ANTHROPIC_AUTH_TOKEN= + + +# Messaging Platform: "telegram" | "discord" | "none" MESSAGING_PLATFORM="discord" MESSAGING_RATE_LIMIT=1 MESSAGING_RATE_WINDOW=1 @@ -70,6 +90,7 @@ ALLOWED_DISCORD_CHANNELS="" # Agent Config CLAUDE_WORKSPACE="./agent_workspace" ALLOWED_DIR="" +CLAUDE_CLI_BIN="claude" FAST_PREFIX_DETECTION=true ENABLE_NETWORK_PROBE_MOCK=true ENABLE_TITLE_GENERATION_SKIP=true diff --git a/config/settings.py b/config/settings.py index b346f41..fbfb08a 100644 --- a/config/settings.py +++ b/config/settings.py @@ -99,7 +99,7 @@ class Settings(BaseSettings): deepseek_api_key: str = Field(default="", validation_alias="DEEPSEEK_API_KEY") # ==================== Messaging Platform Selection ==================== - # Valid: "telegram" | "discord" + # Valid: "telegram" | "discord" | "none" messaging_platform: str = Field( default="discord", validation_alias="MESSAGING_PLATFORM" ) @@ -195,6 +195,7 @@ class Settings(BaseSettings): ) claude_workspace: str = "./agent_workspace" allowed_dir: str = "" + claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN") # ==================== Server ==================== host: str = "0.0.0.0" @@ -249,6 +250,15 @@ class Settings(BaseSettings): ) return v + @field_validator("messaging_platform") + @classmethod + def validate_messaging_platform(cls, v: str) -> str: + if v not in ("telegram", "discord", "none"): + raise ValueError( + f"messaging_platform must be 'telegram', 'discord', or 'none', got {v!r}" + ) + return v + @field_validator("model", "model_opus", "model_sonnet", "model_haiku") @classmethod def validate_model_format(cls, v: str | None) -> str | None: diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..716557d --- /dev/null +++ b/core/__init__.py @@ -0,0 +1 @@ +"""Neutral shared application core.""" diff --git a/core/anthropic/__init__.py b/core/anthropic/__init__.py new file mode 100644 index 0000000..9af5987 --- /dev/null +++ b/core/anthropic/__init__.py @@ -0,0 +1,29 @@ +"""Anthropic protocol helpers shared across API, providers, and integrations.""" + +from .content import extract_text_from_content, get_block_attr, get_block_type +from .conversion import AnthropicToOpenAIConverter, build_base_request_body +from .errors import append_request_id, get_user_facing_error_message +from .sse import ContentBlockManager, SSEBuilder, map_stop_reason +from .thinking import ContentChunk, ContentType, ThinkTagParser +from .tokens import get_token_count +from .tools import HeuristicToolParser +from .utils import set_if_not_none + +__all__ = [ + "AnthropicToOpenAIConverter", + "ContentBlockManager", + "ContentChunk", + "ContentType", + "HeuristicToolParser", + "SSEBuilder", + "ThinkTagParser", + "append_request_id", + "build_base_request_body", + "extract_text_from_content", + "get_block_attr", + "get_block_type", + "get_token_count", + "get_user_facing_error_message", + "map_stop_reason", + "set_if_not_none", +] diff --git a/core/anthropic/content.py b/core/anthropic/content.py new file mode 100644 index 0000000..e16aaaf --- /dev/null +++ 
b/core/anthropic/content.py @@ -0,0 +1,31 @@ +"""Content block helpers for Anthropic-compatible payloads.""" + +from typing import Any + + +def get_block_attr(block: Any, attr: str, default: Any = None) -> Any: + """Get an attribute from a Pydantic model, lightweight object, or dict.""" + if hasattr(block, attr): + return getattr(block, attr) + if isinstance(block, dict): + return block.get(attr, default) + return default + + +def get_block_type(block: Any) -> str | None: + """Return a content block type when present.""" + return get_block_attr(block, "type") + + +def extract_text_from_content(content: Any) -> str: + """Extract concatenated text from message content.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + text = get_block_attr(block, "text", "") + if isinstance(text, str) and text: + parts.append(text) + return "".join(parts) + return "" diff --git a/providers/common/message_converter.py b/core/anthropic/conversion.py similarity index 82% rename from providers/common/message_converter.py rename to core/anthropic/conversion.py index 5e2184f..444df77 100644 --- a/providers/common/message_converter.py +++ b/core/anthropic/conversion.py @@ -3,23 +3,12 @@ import json from typing import Any - -def get_block_attr(block: Any, attr: str, default: Any = None) -> Any: - """Get attribute from object or dict.""" - if hasattr(block, attr): - return getattr(block, attr) - if isinstance(block, dict): - return block.get(attr, default) - return default - - -def get_block_type(block: Any) -> str | None: - """Get block type from object or dict.""" - return get_block_attr(block, "type") +from .content import get_block_attr, get_block_type +from .utils import set_if_not_none class AnthropicToOpenAIConverter: - """Converts Anthropic message format to OpenAI format.""" + """Convert Anthropic message format to OpenAI-compatible format.""" @staticmethod def convert_messages( @@ -29,12 +18,6 @@ class AnthropicToOpenAIConverter: include_reasoning_for_openrouter: bool = False, include_reasoning_content: bool = False, ) -> list[dict[str, Any]]: - """Convert a list of Anthropic messages to OpenAI format. - - When reasoning_content preservation is enabled, assistant messages with - thinking blocks get reasoning_content added for provider multi-turn - reasoning continuation. 
- """ result = [] for msg in messages: @@ -70,7 +53,6 @@ class AnthropicToOpenAIConverter: include_reasoning_for_openrouter: bool = False, include_reasoning_content: bool = False, ) -> list[dict[str, Any]]: - """Convert assistant message blocks, preserving interleaved thinking+text order.""" content_parts: list[str] = [] thinking_parts: list[str] = [] tool_calls: list[dict[str, Any]] = [] @@ -106,9 +88,6 @@ class AnthropicToOpenAIConverter: ) content_str = "\n\n".join(content_parts) - - # Ensure content is never an empty string for assistant messages - # NIM (especially Mistral models) requires non-empty content if there are no tool calls if not content_str and not tool_calls: content_str = " " @@ -125,7 +104,6 @@ class AnthropicToOpenAIConverter: @staticmethod def _convert_user_message(content: list[Any]) -> list[dict[str, Any]]: - """Convert user message blocks (including tool results), preserving order.""" result: list[dict[str, Any]] = [] text_parts: list[str] = [] @@ -162,7 +140,6 @@ class AnthropicToOpenAIConverter: @staticmethod def convert_tools(tools: list[Any]) -> list[dict[str, Any]]: - """Convert Anthropic tools to OpenAI format.""" return [ { "type": "function", @@ -177,7 +154,6 @@ class AnthropicToOpenAIConverter: @staticmethod def convert_tool_choice(tool_choice: Any) -> Any: - """Convert Anthropic tool_choice to OpenAI-compatible format.""" if not isinstance(tool_choice, dict): return tool_choice @@ -197,10 +173,9 @@ class AnthropicToOpenAIConverter: @staticmethod def convert_system_prompt(system: Any) -> dict[str, str] | None: - """Convert Anthropic system prompt to OpenAI format.""" if isinstance(system, str): return {"role": "system", "content": system} - elif isinstance(system, list): + if isinstance(system, list): text_parts = [ get_block_attr(block, "text", "") for block in system @@ -219,14 +194,7 @@ def build_base_request_body( include_reasoning_for_openrouter: bool = False, include_reasoning_content: bool = False, ) -> dict[str, Any]: - """Build the common parts of an OpenAI-format request body. - - Handles message conversion, system prompt, max_tokens, temperature, - top_p, stop sequences, tools, and tool_choice. Provider-specific - parameters (extra_body, penalties, NIM settings) are added by callers. - """ - from providers.common.utils import set_if_not_none - + """Build the common parts of an OpenAI-format request body.""" messages = AnthropicToOpenAIConverter.convert_messages( request_data.messages, include_thinking=include_thinking, diff --git a/core/anthropic/errors.py b/core/anthropic/errors.py new file mode 100644 index 0000000..38e3d12 --- /dev/null +++ b/core/anthropic/errors.py @@ -0,0 +1,53 @@ +"""User-facing error formatting shared by API, providers, and integrations.""" + +import httpx +import openai + + +def get_user_facing_error_message( + e: Exception, + *, + read_timeout_s: float | None = None, +) -> str: + """Return a readable, non-empty error message for users.""" + message = str(e).strip() + if message: + return message + + if isinstance(e, httpx.ReadTimeout): + if read_timeout_s is not None: + return f"Provider request timed out after {read_timeout_s:g}s." + return "Provider request timed out." + if isinstance(e, httpx.ConnectTimeout): + return "Could not connect to provider." + if isinstance(e, TimeoutError): + if read_timeout_s is not None: + return f"Provider request timed out after {read_timeout_s:g}s." + return "Request timed out." 
+ + name = type(e).__name__ + status_code = getattr(e, "status_code", None) + if isinstance(e, openai.RateLimitError) or name == "RateLimitError": + return "Provider rate limit reached. Please retry shortly." + if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError": + return "Provider authentication failed. Check API key." + if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError": + return "Invalid request sent to provider." + if name == "OverloadedError": + return "Provider is currently overloaded. Please retry." + if name == "APIError": + if status_code in (502, 503, 504): + return "Provider is temporarily unavailable. Please retry." + return "Provider API request failed." + if name.endswith("ProviderError") or name == "ProviderError": + return "Provider request failed." + + return "Provider request failed unexpectedly." + + +def append_request_id(message: str, request_id: str | None) -> str: + """Append request_id suffix when available.""" + base = message.strip() or "Provider request failed unexpectedly." + if request_id: + return f"{base} (request_id={request_id})" + return base diff --git a/providers/common/sse_builder.py b/core/anthropic/sse.py similarity index 80% rename from providers/common/sse_builder.py rename to core/anthropic/sse.py index 9dd2183..7998839 100644 --- a/providers/common/sse_builder.py +++ b/core/anthropic/sse.py @@ -3,7 +3,6 @@ import json from collections.abc import Iterator from dataclasses import dataclass, field -from typing import Any from loguru import logger @@ -15,7 +14,6 @@ except Exception: ENCODER = None -# Map OpenAI finish_reason to Anthropic stop_reason STOP_REASON_MAP = { "stop": "end_turn", "length": "max_tokens", @@ -35,7 +33,7 @@ def map_stop_reason(openai_reason: str | None) -> str: class ToolCallState: """State for a single streaming tool call.""" - block_index: int # -1 if not yet allocated + block_index: int tool_id: str name: str contents: list[str] = field(default_factory=list) @@ -46,7 +44,7 @@ class ToolCallState: @dataclass class ContentBlockManager: - """Manages content block indices and state.""" + """Manage content block indices and state.""" next_index: int = 0 thinking_index: int = -1 @@ -56,17 +54,11 @@ class ContentBlockManager: tool_states: dict[int, ToolCallState] = field(default_factory=dict) def allocate_index(self) -> int: - """Allocate and return the next block index.""" idx = self.next_index self.next_index += 1 return idx def register_tool_name(self, index: int, name: str) -> None: - """Register or merge a streaming tool name fragment. - - Handles providers that stream names as fragments and those that - resend the full name on every chunk. - """ if index not in self.tool_states: self.tool_states[index] = ToolCallState( block_index=-1, tool_id="", name=name @@ -80,11 +72,6 @@ class ContentBlockManager: state.name = prev + name def buffer_task_args(self, index: int, args: str) -> dict | None: - """Buffer Task tool args and return parsed JSON when complete. - - Returns the parsed (and patched) args dict once the buffer forms - valid JSON, or None if still accumulating. - """ state = self.tool_states.get(index) if state is None or state.task_args_emitted: return None @@ -103,7 +90,6 @@ class ContentBlockManager: return args_json def flush_task_arg_buffers(self) -> list[tuple[int, str]]: - """Flush any remaining Task arg buffers. 
Returns (tool_index, json_str) pairs.""" results: list[tuple[int, str]] = [] for tool_index, state in list(self.tool_states.items()): if not state.task_arg_buffer or state.task_args_emitted: @@ -142,15 +128,12 @@ class SSEBuilder: self._accumulated_text_parts: list[str] = [] self._accumulated_reasoning_parts: list[str] = [] - def _format_event(self, event_type: str, data: dict[str, Any]) -> str: - """Format as SSE string.""" + def _format_event(self, event_type: str, data: dict) -> str: event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip()) return event_str - # Message lifecycle events def message_start(self) -> str: - """Generate message_start event.""" usage = {"input_tokens": self.input_tokens, "output_tokens": 1} return self._format_event( "message_start", @@ -170,7 +153,6 @@ class SSEBuilder: ) def message_delta(self, stop_reason: str, output_tokens: int) -> str: - """Generate message_delta event with stop reason.""" return self._format_event( "message_delta", { @@ -184,13 +166,10 @@ class SSEBuilder: ) def message_stop(self) -> str: - """Generate message_stop event.""" return self._format_event("message_stop", {"type": "message_stop"}) - # Content block events def content_block_start(self, index: int, block_type: str, **kwargs) -> str: - """Generate content_block_start event.""" - content_block: dict[str, Any] = {"type": block_type} + content_block: dict = {"type": block_type} if block_type == "thinking": content_block["thinking"] = kwargs.get("thinking", "") elif block_type == "text": @@ -210,8 +189,7 @@ class SSEBuilder: ) def content_block_delta(self, index: int, delta_type: str, content: str) -> str: - """Generate content_block_delta event.""" - delta: dict[str, Any] = {"type": delta_type} + delta: dict = {"type": delta_type} if delta_type == "thinking_delta": delta["thinking"] = content elif delta_type == "text_delta": @@ -229,7 +207,6 @@ class SSEBuilder: ) def content_block_stop(self, index: int) -> str: - """Generate content_block_stop event.""" return self._format_event( "content_block_stop", { @@ -238,45 +215,35 @@ class SSEBuilder: }, ) - # High-level helpers for thinking blocks def start_thinking_block(self) -> str: - """Start a thinking block, allocating index.""" self.blocks.thinking_index = self.blocks.allocate_index() self.blocks.thinking_started = True return self.content_block_start(self.blocks.thinking_index, "thinking") def emit_thinking_delta(self, content: str) -> str: - """Emit thinking content delta.""" self._accumulated_reasoning_parts.append(content) return self.content_block_delta( self.blocks.thinking_index, "thinking_delta", content ) def stop_thinking_block(self) -> str: - """Stop the current thinking block.""" self.blocks.thinking_started = False return self.content_block_stop(self.blocks.thinking_index) - # High-level helpers for text blocks def start_text_block(self) -> str: - """Start a text block, allocating index.""" self.blocks.text_index = self.blocks.allocate_index() self.blocks.text_started = True return self.content_block_start(self.blocks.text_index, "text") def emit_text_delta(self, content: str) -> str: - """Emit text content delta.""" self._accumulated_text_parts.append(content) return self.content_block_delta(self.blocks.text_index, "text_delta", content) def stop_text_block(self) -> str: - """Stop the current text block.""" self.blocks.text_started = False return self.content_block_stop(self.blocks.text_index) - # High-level helpers for tool blocks def 
start_tool_block(self, tool_index: int, tool_id: str, name: str) -> str: - """Start a tool_use block.""" block_idx = self.blocks.allocate_index() if tool_index in self.blocks.tool_states: state = self.blocks.tool_states[tool_index] @@ -293,7 +260,6 @@ class SSEBuilder: return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name) def emit_tool_delta(self, tool_index: int, partial_json: str) -> str: - """Emit tool input delta.""" state = self.blocks.tool_states[tool_index] state.contents.append(partial_json) return self.content_block_delta( @@ -301,34 +267,28 @@ class SSEBuilder: ) def stop_tool_block(self, tool_index: int) -> str: - """Stop a tool block.""" block_idx = self.blocks.tool_states[tool_index].block_index return self.content_block_stop(block_idx) - # State management helpers def ensure_thinking_block(self) -> Iterator[str]: - """Ensure a thinking block is started, closing text block if needed.""" if self.blocks.text_started: yield self.stop_text_block() if not self.blocks.thinking_started: yield self.start_thinking_block() def ensure_text_block(self) -> Iterator[str]: - """Ensure a text block is started, closing thinking block if needed.""" if self.blocks.thinking_started: yield self.stop_thinking_block() if not self.blocks.text_started: yield self.start_text_block() def close_content_blocks(self) -> Iterator[str]: - """Close thinking and text blocks (before tool calls).""" if self.blocks.thinking_started: yield self.stop_thinking_block() if self.blocks.text_started: yield self.stop_text_block() def close_all_blocks(self) -> Iterator[str]: - """Close all open blocks (thinking, text, tools).""" if self.blocks.thinking_started: yield self.stop_thinking_block() if self.blocks.text_started: @@ -337,54 +297,45 @@ class SSEBuilder: if state.started: yield self.stop_tool_block(tool_index) - # Error handling def emit_error(self, error_message: str) -> Iterator[str]: - """Emit an error as a text block.""" error_index = self.blocks.allocate_index() yield self.content_block_start(error_index, "text") yield self.content_block_delta(error_index, "text_delta", error_message) yield self.content_block_stop(error_index) - # Accumulated content access @property def accumulated_text(self) -> str: - """Get accumulated text content.""" return "".join(self._accumulated_text_parts) @property def accumulated_reasoning(self) -> str: - """Get accumulated reasoning content.""" return "".join(self._accumulated_reasoning_parts) def estimate_output_tokens(self) -> int: - """Estimate output tokens from accumulated content.""" accumulated_text = self.accumulated_text accumulated_reasoning = self.accumulated_reasoning if ENCODER: text_tokens = len(ENCODER.encode(accumulated_text)) reasoning_tokens = len(ENCODER.encode(accumulated_reasoning)) - # Tool calls are harder to tokenize exactly without reconstruction, but we can approximate - # by tokenizing the json dumps of tool contents tool_tokens = 0 started_tool_count = 0 for state in self.blocks.tool_states.values(): tool_tokens += len(ENCODER.encode(state.name)) tool_tokens += len(ENCODER.encode("".join(state.contents))) - tool_tokens += 15 # Control tokens overhead per tool + tool_tokens += 15 if state.started: started_tool_count += 1 - # Per-block overhead (~4 tokens per content block) block_count = ( (1 if accumulated_reasoning else 0) + (1 if accumulated_text else 0) + started_tool_count ) - block_overhead = block_count * 4 - - return text_tokens + reasoning_tokens + tool_tokens + block_overhead + return text_tokens + reasoning_tokens + 
tool_tokens + (block_count * 4)
 
         text_tokens = len(accumulated_text) // 4
         reasoning_tokens = len(accumulated_reasoning) // 4
-        tool_tokens = sum(1 for s in self.blocks.tool_states.values() if s.started) * 50
+        tool_tokens = (
+            sum(1 for state in self.blocks.tool_states.values() if state.started) * 50
+        )
         return text_tokens + reasoning_tokens + tool_tokens
diff --git a/providers/common/think_parser.py b/core/anthropic/thinking.py
similarity index 60%
rename from providers/common/think_parser.py
rename to core/anthropic/thinking.py
index 2deed99..fa793d1 100644
--- a/providers/common/think_parser.py
+++ b/core/anthropic/thinking.py
@@ -1,4 +1,4 @@
-"""Think tag parser for extracting reasoning content from responses."""
+"""Streaming parser for provider-emitted thinking tags."""
 
 from collections.abc import Iterator
 from dataclasses import dataclass
@@ -22,15 +22,13 @@ class ContentChunk:
 
 class ThinkTagParser:
     """
-    Streaming parser for <think>...</think> tags.
+    Streaming parser for ``<think>...</think>`` tags.
 
     Handles partial tags at chunk boundaries by buffering.
    """
 
     OPEN_TAG = "<think>"
     CLOSE_TAG = "</think>"
-    OPEN_TAG_LEN = 7
-    CLOSE_TAG_LEN = 8
 
     def __init__(self):
         self._buffer: str = ""
@@ -42,13 +40,7 @@ class ThinkTagParser:
         return self._in_think_tag
 
     def feed(self, content: str) -> Iterator[ContentChunk]:
-        """
-        Feed content and yield parsed chunks.
-
-        Handles partial tags by buffering content near potential tag boundaries.
-        Uses an iterative loop instead of mutual recursion to avoid stack overflow
-        on inputs with many consecutive think tags.
-        """
+        """Feed content and yield parsed chunks."""
         self._buffer += content
 
         while self._buffer:
@@ -61,7 +53,6 @@
             if chunk:
                 yield chunk
             elif len(self._buffer) == prev_len:
-                # No progress: waiting for more data
                 break
 
     def _parse_outside_think(self) -> ContentChunk | None:
         think_start = self._buffer.find(self.OPEN_TAG)
         orphan_close = self._buffer.find(self.CLOSE_TAG)
 
-        # Handle orphan </think> - strip it (Step Fun AI sends reasoning via
-        # reasoning_content but may leak closing tags in content)
         if orphan_close != -1 and (think_start == -1 or orphan_close < think_start):
             pre_orphan = self._buffer[:orphan_close]
-            self._buffer = self._buffer[orphan_close + self.CLOSE_TAG_LEN :]
+            self._buffer = self._buffer[orphan_close + len(self.CLOSE_TAG) :]
             if pre_orphan:
                 return ContentChunk(ContentType.TEXT, pre_orphan)
-            # Buffer shrunk; the feed() loop will continue parsing
             return None
 
         if think_start == -1:
-            # No tag found - check for partial tag at end
-            # We buffer any trailing '<' and subsequent characters that could be part of <think> or </think>
             last_bracket = self._buffer.rfind("<")
             if last_bracket != -1:
                 potential_tag = self._buffer[last_bracket:]
                 tag_len = len(potential_tag)
-                # Check if could be partial <think> or </think>
                 if (
-                    tag_len < self.OPEN_TAG_LEN
+                    tag_len < len(self.OPEN_TAG)
                     and self.OPEN_TAG.startswith(potential_tag)
                 ) or (
-                    tag_len < self.CLOSE_TAG_LEN
+                    tag_len < len(self.CLOSE_TAG)
                     and self.CLOSE_TAG.startswith(potential_tag)
                 ):
                     emit = self._buffer[:last_bracket]
@@ -100,35 +85,28 @@
                     return ContentChunk(ContentType.TEXT, emit)
                 return None
 
-            # No partial tag found or it's irrelevant
             emit = self._buffer
             self._buffer = ""
             if emit:
                 return ContentChunk(ContentType.TEXT, emit)
             return None
-        else:
-            # Found <think> tag
-            pre_think = self._buffer[:think_start]
-            self._buffer = self._buffer[think_start + self.OPEN_TAG_LEN :]
-            self._in_think_tag = True
-            if pre_think:
-                return ContentChunk(ContentType.TEXT, pre_think)
-            # Buffer shrunk (consumed <think>); the feed() loop will continue
-            # parsing inside the think tag on the next iteration
-            return None
+
+        pre_think = self._buffer[:think_start]
+        self._buffer = self._buffer[think_start + len(self.OPEN_TAG) :]
+        self._in_think_tag = True
+        if pre_think:
+            return ContentChunk(ContentType.TEXT, pre_think)
+        return None
 
     def _parse_inside_think(self) -> ContentChunk | None:
         """Parse content inside think tags."""
         think_end = self._buffer.find(self.CLOSE_TAG)
 
         if think_end == -1:
-            # No closing tag - check for partial at end
             last_bracket = self._buffer.rfind("<")
-            if (
-                last_bracket != -1
-                and len(self._buffer) - last_bracket < self.CLOSE_TAG_LEN
+            if last_bracket != -1 and len(self._buffer) - last_bracket < len(
+                self.CLOSE_TAG
             ):
-                # Check if the partial string could be the start of </think>
                 potential_tag = self._buffer[last_bracket:]
                 if self.CLOSE_TAG.startswith(potential_tag):
                     emit = self._buffer[:last_bracket]
@@ -142,16 +120,13 @@
             if emit:
                 return ContentChunk(ContentType.THINKING, emit)
             return None
-        else:
-            # Found </think> tag
-            thinking_content = self._buffer[:think_end]
-            self._buffer = self._buffer[think_end + self.CLOSE_TAG_LEN :]
-            self._in_think_tag = False
-            if thinking_content:
-                return ContentChunk(ContentType.THINKING, thinking_content)
-            # Buffer shrunk (consumed </think>); the feed() loop will continue
-            # parsing outside the think tag on the next iteration
-            return None
+
+        thinking_content = self._buffer[:think_end]
+        self._buffer = self._buffer[think_end + len(self.CLOSE_TAG) :]
+        self._in_think_tag = False
+        if thinking_content:
+            return ContentChunk(ContentType.THINKING, thinking_content)
+        return None
 
     def flush(self) -> ContentChunk | None:
         """Flush any remaining buffered content."""
diff --git a/core/anthropic/tokens.py b/core/anthropic/tokens.py
new file mode 100644
index 0000000..dfa1115
--- /dev/null
+++ b/core/anthropic/tokens.py
@@ -0,0 +1,92 @@
+"""Token estimation for Anthropic-compatible requests."""
+
+import json
+
+import tiktoken
+from loguru import logger
+
+from .content import get_block_attr
+
+ENCODER = tiktoken.get_encoding("cl100k_base")
+
+
+def get_token_count(
+    messages: list,
+    system: str | list | None = None,
+    tools: list | None = None,
+) -> int:
+    """Estimate token count for a request."""
+    total_tokens = 0
+
+    if system:
+        if isinstance(system, str):
+            total_tokens += len(ENCODER.encode(system))
+        elif isinstance(system, list):
+            for block in system:
+                text = get_block_attr(block, "text", "")
+                if text:
+                    total_tokens += len(ENCODER.encode(str(text)))
+                total_tokens += 4
+
+    for msg in messages:
+        if isinstance(msg.content, str):
+            total_tokens += len(ENCODER.encode(msg.content))
+        elif isinstance(msg.content, list):
+            for block in msg.content:
+                b_type = get_block_attr(block, "type") or None
+
+                if b_type == "text":
+                    text = get_block_attr(block, "text", "")
+                    total_tokens += len(ENCODER.encode(str(text)))
+                elif b_type == "thinking":
+                    thinking = get_block_attr(block, "thinking", "")
+                    total_tokens += len(ENCODER.encode(str(thinking)))
+                elif b_type == "tool_use":
+                    name = get_block_attr(block, "name", "")
+                    inp = get_block_attr(block, "input", {})
+                    block_id = get_block_attr(block, "id", "")
+                    total_tokens += len(ENCODER.encode(str(name)))
+                    total_tokens += len(ENCODER.encode(json.dumps(inp)))
+                    total_tokens += len(ENCODER.encode(str(block_id)))
+                    total_tokens += 15
+                elif b_type == "image":
+                    source = get_block_attr(block, "source")
+                    if isinstance(source, dict):
+                        data = source.get("data") or source.get("base64") or ""
+                        if data:
diff --git a/providers/common/heuristic_tool_parser.py b/core/anthropic/tools.py
similarity index 59%
rename from providers/common/heuristic_tool_parser.py
rename to core/anthropic/tools.py
index e338db5..0751d9d 100644
--- a/providers/common/heuristic_tool_parser.py
+++ b/core/anthropic/tools.py
@@ -1,3 +1,5 @@
+"""Heuristic parser for text-emitted tool calls."""
+
 import re
 import uuid
 from enum import Enum
@@ -5,9 +7,6 @@
 from typing import Any

 from loguru import logger

-# Some OpenAI-compatible backends/models occasionally leak internal sentinel tokens
-# into `delta.content` (e.g. "<|tool_call_end|>"). These should never be shown to
-# end users, and they can disrupt downstream parsing if left in place.
 _CONTROL_TOKEN_RE = re.compile(r"<\|[^|>]{1,80}\|>")
 _CONTROL_TOKEN_START = "<|"
 _CONTROL_TOKEN_END = "|>"
@@ -21,14 +20,13 @@ class ParserState(Enum):

 class HeuristicToolParser:
     """
-    Stateful parser that detects raw text tool calls in the format:
-    ● <function=name><parameter=key>value</parameter>...
+    Stateful parser for raw text tool calls.

-    This is used as a fallback for models that emit tool calls as text
-    instead of using the structured API.
+    Some OpenAI-compatible models emit tool calls as text rather than structured
+    chunks. This parser converts the common ``● <function=...>`` form into
+    Anthropic-style ``tool_use`` blocks.
     """

-    # Class-level compiled patterns (compiled once, not per instance)
     _FUNC_START_PATTERN = re.compile(r"●\s*<function=([^>]+)>")
     _PARAM_PATTERN = re.compile(
         r"<parameter=([^>]+)>(.*?)(?:</parameter>|$)", re.DOTALL
     )
@@ -42,17 +40,9 @@ class HeuristicToolParser:
         self._current_parameters = {}

     def _strip_control_tokens(self, text: str) -> str:
-        # Remove complete sentinel tokens. If a token is split across chunks it
-        # will be removed once the buffer contains the full token.
         return _CONTROL_TOKEN_RE.sub("", text)

     def _split_incomplete_control_token_tail(self) -> str:
-        """
-        If the buffer ends with an incomplete "<|...|>" sentinel token, keep that
-        fragment in the buffer and return the safe-to-emit prefix.
-
-        This prevents leaking raw sentinel fragments to the user when streaming.
-        """
         start = self._buffer.rfind(_CONTROL_TOKEN_START)
         if start == -1:
             return ""
@@ -65,13 +55,7 @@ class HeuristicToolParser:
         return prefix

     def feed(self, text: str) -> tuple[str, list[dict[str, Any]]]:
-        """
-        Feed text into the parser.
-        Returns a tuple of (filtered_text, detected_tool_calls).
-
-        filtered_text: Text that should be passed through as normal message content.
-        detected_tools: List of Anthropic-format tool_use blocks.
- """ + """Feed text and return safe text plus detected tool calls.""" self._buffer += text self._buffer = self._strip_control_tokens(self._buffer) detected_tools = [] @@ -79,15 +63,12 @@ class HeuristicToolParser: while True: if self._state == ParserState.TEXT: - # Look for the trigger character if "●" in self._buffer: idx = self._buffer.find("●") filtered_output_parts.append(self._buffer[:idx]) self._buffer = self._buffer[idx:] self._state = ParserState.MATCHING_FUNCTION else: - # Avoid emitting an incomplete "<|...|>" sentinel fragment if the - # token got split across streaming chunks. safe_prefix = self._split_incomplete_control_token_tail() if safe_prefix: filtered_output_parts.append(safe_prefix) @@ -98,46 +79,30 @@ class HeuristicToolParser: break if self._state == ParserState.MATCHING_FUNCTION: - # We need enough buffer to match the function tag - # e.g. "● " match = self._FUNC_START_PATTERN.search(self._buffer) if match: self._current_function_name = match.group(1).strip() self._current_tool_id = f"toolu_heuristic_{uuid.uuid4().hex[:8]}" self._current_parameters = {} - - # Consume the function start from buffer self._buffer = self._buffer[match.end() :] self._state = ParserState.PARSING_PARAMETERS logger.debug( "Heuristic bypass: Detected start of tool call '{}'", self._current_function_name, ) + elif len(self._buffer) > 100: + filtered_output_parts.append(self._buffer[0]) + self._buffer = self._buffer[1:] + self._state = ParserState.TEXT else: - # If we have "●" but not the full tag yet, wait for more data - # Unless the buffer has grown too large without a match - if len(self._buffer) > 100: - # Probably not a tool call, treat as text - filtered_output_parts.append(self._buffer[0]) - self._buffer = self._buffer[1:] - self._state = ParserState.TEXT - else: - break + break if self._state == ParserState.PARSING_PARAMETERS: - # Look for parameters. We look for to know a param is complete. - # Or wait for another " in param_match.group(0): - # Detect any content before the parameter match and preserve it pre_match_text = self._buffer[: param_match.start()] if pre_match_text: filtered_output_parts.append(pre_match_text) @@ -149,32 +114,19 @@ class HeuristicToolParser: else: break - # Heuristic for completion: - # 1. We have at least one param and we see a character that doesn't belong to the format - # 2. Significant pause (not handled here, handled by caller via flush if needed) - # 3. Another ● character (start of NEXT tool call) - if "●" in self._buffer: - # Next tool call starting or something else, close current - # But first, capture any text before the ● idx = self._buffer.find("●") if idx > 0: filtered_output_parts.append(self._buffer[:idx]) self._buffer = self._buffer[idx:] finished_tool_call = True elif len(self._buffer) > 0 and not self._buffer.strip().startswith("<"): - # We have text that doesn't look like a tag, and we already parsed some or are in param state - # Let's see if we have trailing param starts if " list[dict[str, Any]]: - """ - Flush any remaining tool calls in the buffer. 
- """ + """Flush any remaining tool call in the buffer.""" self._buffer = self._strip_control_tokens(self._buffer) detected_tools = [] if self._state == ParserState.PARSING_PARAMETERS: - # Try to extract any partial parameters remaining in buffer - # Even without partial_matches = re.finditer( r"]+)>(.*)$", self._buffer, re.DOTALL ) - for m in partial_matches: - key = m.group(1).strip() - val = m.group(2).strip() + for match in partial_matches: + key = match.group(1).strip() + val = match.group(2).strip() self._current_parameters[key] = val detected_tools.append( diff --git a/providers/common/utils.py b/core/anthropic/utils.py similarity index 55% rename from providers/common/utils.py rename to core/anthropic/utils.py index 7f290ea..84d33de 100644 --- a/providers/common/utils.py +++ b/core/anthropic/utils.py @@ -1,9 +1,9 @@ -"""Shared utility helpers for provider request builders.""" +"""Small shared protocol utility helpers.""" from typing import Any def set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None: - """Set body[key] = value only when value is not None.""" + """Set ``body[key]`` only when value is not None.""" if value is not None: body[key] = value diff --git a/messaging/command_dispatcher.py b/messaging/command_dispatcher.py new file mode 100644 index 0000000..b1188b9 --- /dev/null +++ b/messaging/command_dispatcher.py @@ -0,0 +1,48 @@ +"""Command parsing and dispatch for messaging handlers.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Protocol + +from .models import IncomingMessage + +CommandHandler = Callable[[IncomingMessage], Awaitable[None]] + + +class SupportsMessagingCommands(Protocol): + async def _handle_clear_command(self, incoming: IncomingMessage) -> None: ... + + async def _handle_stop_command(self, incoming: IncomingMessage) -> None: ... + + async def _handle_stats_command(self, incoming: IncomingMessage) -> None: ... 
+ + +def parse_command_base(text: str | None) -> str: + """Return the slash command without bot mention suffix.""" + parts = (text or "").strip().split() + cmd = parts[0] if parts else "" + return cmd.split("@", 1)[0] if cmd else "" + + +def message_kind_for_command(command_base: str) -> str: + """Return the persistence kind for an incoming message.""" + return "command" if command_base.startswith("/") else "content" + + +async def dispatch_command( + handler: SupportsMessagingCommands, + incoming: IncomingMessage, + command_base: str, +) -> bool: + """Dispatch a known command and return whether it was handled.""" + commands: dict[str, CommandHandler] = { + "/clear": handler._handle_clear_command, + "/stop": handler._handle_stop_command, + "/stats": handler._handle_stats_command, + } + command = commands.get(command_base) + if command is None: + return False + await command(incoming) + return True diff --git a/messaging/handler.py b/messaging/handler.py index be617a8..52290e0 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -12,8 +12,13 @@ import time from loguru import logger -from providers.common import get_user_facing_error_message +from core.anthropic import get_user_facing_error_message +from .command_dispatcher import ( + dispatch_command, + message_kind_for_command, + parse_command_base, +) from .commands import ( handle_clear_command, handle_stats_command, @@ -155,36 +160,23 @@ class ClaudeMessageHandler: async def _handle_message_impl(self, incoming: IncomingMessage) -> None: """Implementation of handle_message with context bound.""" - # Check for commands - parts = (incoming.text or "").strip().split() - cmd = parts[0] if parts else "" - cmd_base = cmd.split("@", 1)[0] if cmd else "" + cmd_base = parse_command_base(incoming.text) # Record incoming message ID for best-effort UI clearing (/clear), even if # we later ignore this message (status/command/etc). 
try: if incoming.message_id is not None: - kind = "command" if cmd_base.startswith("/") else "content" self.session_store.record_message_id( incoming.platform, incoming.chat_id, str(incoming.message_id), direction="in", - kind=kind, + kind=message_kind_for_command(cmd_base), ) except Exception as e: logger.debug(f"Failed to record incoming message_id: {e}") - if cmd_base == "/clear": - await self._handle_clear_command(incoming) - return - - if cmd_base == "/stop": - await self._handle_stop_command(incoming) - return - - if cmd_base == "/stats": - await self._handle_stats_command(incoming) + if await dispatch_command(self, incoming, cmd_base): return # Filter out status messages (our own messages) diff --git a/messaging/platforms/discord.py b/messaging/platforms/discord.py index 97661c9..3222c20 100644 --- a/messaging/platforms/discord.py +++ b/messaging/platforms/discord.py @@ -14,11 +14,11 @@ from typing import Any, cast from loguru import logger -from providers.common import get_user_facing_error_message +from core.anthropic import get_user_facing_error_message from ..models import IncomingMessage from ..rendering.discord_markdown import format_status_discord -from ..voice import PendingVoiceRegistry, VoiceTranscriptionService +from ..voice_pipeline import VoiceNotePipeline from .base import MessagingPlatform AUDIO_EXTENSIONS = (".ogg", ".mp4", ".mp3", ".wav", ".m4a") @@ -116,8 +116,7 @@ class DiscordPlatform(MessagingPlatform): self._connected = False self._limiter: Any | None = None self._start_task: asyncio.Task | None = None - self._pending_voice = PendingVoiceRegistry() - self._voice_transcription = VoiceTranscriptionService() + self._voice_pipeline = VoiceNotePipeline() async def _handle_client_message(self, message: Any) -> None: """Adapter entry point used by the internal discord client.""" @@ -127,17 +126,19 @@ class DiscordPlatform(MessagingPlatform): self, chat_id: str, voice_msg_id: str, status_msg_id: str ) -> None: """Register a voice note as pending transcription.""" - await self._pending_voice.register(chat_id, voice_msg_id, status_msg_id) + await self._voice_pipeline.register_pending( + chat_id, voice_msg_id, status_msg_id + ) async def cancel_pending_voice( self, chat_id: str, reply_id: str ) -> tuple[str, str] | None: """Cancel a pending voice transcription. 
Returns (voice_msg_id, status_msg_id) if found.""" - return await self._pending_voice.cancel(chat_id, reply_id) + return await self._voice_pipeline.cancel_pending(chat_id, reply_id) async def _is_voice_still_pending(self, chat_id: str, voice_msg_id: str) -> bool: """Check if a voice note is still pending (not cancelled).""" - return await self._pending_voice.is_pending(chat_id, voice_msg_id) + return await self._voice_pipeline.is_pending(chat_id, voice_msg_id) def _get_audio_attachment(self, message: Any) -> Any | None: """Return first audio attachment, or None.""" @@ -198,7 +199,7 @@ class DiscordPlatform(MessagingPlatform): try: await attachment.save(str(tmp_path)) - transcribed = await self._voice_transcription.transcribe( + transcribed = await self._voice_pipeline.transcribe( tmp_path, ct, whisper_model=settings.whisper_model, @@ -209,7 +210,7 @@ class DiscordPlatform(MessagingPlatform): await self.queue_delete_message(channel_id, str(status_msg_id)) return True - await self._pending_voice.complete( + await self._voice_pipeline.complete( channel_id, message_id, str(status_msg_id) ) diff --git a/messaging/platforms/factory.py b/messaging/platforms/factory.py index f260c3b..4b4bf27 100644 --- a/messaging/platforms/factory.py +++ b/messaging/platforms/factory.py @@ -24,6 +24,10 @@ def create_messaging_platform( Returns: Configured MessagingPlatform instance, or None if not configured. """ + if platform_type == "none": + logger.info("Messaging platform disabled by configuration") + return None + if platform_type == "telegram": bot_token = kwargs.get("bot_token") if not bot_token: diff --git a/messaging/platforms/telegram.py b/messaging/platforms/telegram.py index 9eadc4c..602b622 100644 --- a/messaging/platforms/telegram.py +++ b/messaging/platforms/telegram.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any from loguru import logger -from providers.common import get_user_facing_error_message +from core.anthropic import get_user_facing_error_message if TYPE_CHECKING: from telegram import Update @@ -27,7 +27,7 @@ if TYPE_CHECKING: from ..models import IncomingMessage from ..rendering.telegram_markdown import escape_md_v2, format_status -from ..voice import PendingVoiceRegistry, VoiceTranscriptionService +from ..voice_pipeline import VoiceNotePipeline from .base import MessagingPlatform # Optional import - python-telegram-bot may not be installed @@ -83,24 +83,25 @@ class TelegramPlatform(MessagingPlatform): self._connected = False self._limiter: Any | None = None # Will be MessagingRateLimiter # Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id) - self._pending_voice = PendingVoiceRegistry() - self._voice_transcription = VoiceTranscriptionService() + self._voice_pipeline = VoiceNotePipeline() async def _register_pending_voice( self, chat_id: str, voice_msg_id: str, status_msg_id: str ) -> None: """Register a voice note as pending transcription (for /clear reply during transcription).""" - await self._pending_voice.register(chat_id, voice_msg_id, status_msg_id) + await self._voice_pipeline.register_pending( + chat_id, voice_msg_id, status_msg_id + ) async def cancel_pending_voice( self, chat_id: str, reply_id: str ) -> tuple[str, str] | None: """Cancel a pending voice transcription. 
Returns (voice_msg_id, status_msg_id) if found.""" - return await self._pending_voice.cancel(chat_id, reply_id) + return await self._voice_pipeline.cancel_pending(chat_id, reply_id) async def _is_voice_still_pending(self, chat_id: str, voice_msg_id: str) -> bool: """Check if a voice note is still pending (not cancelled).""" - return await self._pending_voice.is_pending(chat_id, voice_msg_id) + return await self._voice_pipeline.is_pending(chat_id, voice_msg_id) async def start(self) -> None: """Initialize and connect to Telegram.""" @@ -597,7 +598,7 @@ class TelegramPlatform(MessagingPlatform): tg_file = await context.bot.get_file(voice.file_id) await tg_file.download_to_drive(custom_path=str(tmp_path)) - transcribed = await self._voice_transcription.transcribe( + transcribed = await self._voice_pipeline.transcribe( tmp_path, voice.mime_type or "audio/ogg", whisper_model=settings.whisper_model, @@ -608,7 +609,7 @@ class TelegramPlatform(MessagingPlatform): await self.queue_delete_message(chat_id, str(status_msg_id)) return - await self._pending_voice.complete(chat_id, message_id, str(status_msg_id)) + await self._voice_pipeline.complete(chat_id, message_id, str(status_msg_id)) incoming = IncomingMessage( text=transcribed, diff --git a/messaging/trees/processor.py b/messaging/trees/processor.py index d9255b4..5fdcca6 100644 --- a/messaging/trees/processor.py +++ b/messaging/trees/processor.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable from loguru import logger -from providers.common import get_user_facing_error_message +from core.anthropic import get_user_facing_error_message from .data import MessageNode, MessageState, MessageTree diff --git a/messaging/trees/queue_manager.py b/messaging/trees/queue_manager.py index 14b0f10..fbf7029 100644 --- a/messaging/trees/queue_manager.py +++ b/messaging/trees/queue_manager.py @@ -14,14 +14,6 @@ from .data import MessageNode, MessageState, MessageTree from .processor import TreeQueueProcessor from .repository import TreeRepository -# Backward compatibility: re-export moved classes -__all__ = [ - "MessageNode", - "MessageState", - "MessageTree", - "TreeQueueManager", -] - class TreeQueueManager: """ diff --git a/messaging/voice_pipeline.py b/messaging/voice_pipeline.py new file mode 100644 index 0000000..4c4fc85 --- /dev/null +++ b/messaging/voice_pipeline.py @@ -0,0 +1,53 @@ +"""Platform-neutral voice note pipeline.""" + +from __future__ import annotations + +from pathlib import Path + +from .voice import PendingVoiceRegistry, VoiceTranscriptionService + + +class VoiceNotePipeline: + """Coordinate pending voice state and transcription backend execution.""" + + def __init__( + self, + *, + registry: PendingVoiceRegistry | None = None, + transcription: VoiceTranscriptionService | None = None, + ) -> None: + self._registry = registry or PendingVoiceRegistry() + self._transcription = transcription or VoiceTranscriptionService() + + async def register_pending( + self, chat_id: str, voice_msg_id: str, status_msg_id: str + ) -> None: + await self._registry.register(chat_id, voice_msg_id, status_msg_id) + + async def cancel_pending( + self, chat_id: str, reply_id: str + ) -> tuple[str, str] | None: + return await self._registry.cancel(chat_id, reply_id) + + async def is_pending(self, chat_id: str, voice_msg_id: str) -> bool: + return await self._registry.is_pending(chat_id, voice_msg_id) + + async def complete( + self, chat_id: str, voice_msg_id: str, status_msg_id: str + ) -> None: + await self._registry.complete(chat_id, voice_msg_id, 
status_msg_id) + + async def transcribe( + self, + file_path: Path, + mime_type: str, + *, + whisper_model: str, + whisper_device: str, + ) -> str: + return await self._transcription.transcribe( + file_path, + mime_type, + whisper_model=whisper_model, + whisper_device=whisper_device, + ) diff --git a/providers/anthropic_messages.py b/providers/anthropic_messages.py index acb0db5..a884a03 100644 --- a/providers/anthropic_messages.py +++ b/providers/anthropic_messages.py @@ -9,8 +9,9 @@ from typing import Any, Literal import httpx from loguru import logger +from core.anthropic import get_user_facing_error_message from providers.base import BaseProvider, ProviderConfig -from providers.common import get_user_facing_error_message, map_error +from providers.error_mapping import map_error from providers.rate_limit import GlobalRateLimiter ANTHROPIC_DEFAULT_MAX_TOKENS = 81920 diff --git a/providers/common/__init__.py b/providers/common/__init__.py deleted file mode 100644 index 4e6c128..0000000 --- a/providers/common/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Shared provider utilities used by NIM, OpenRouter, and LM Studio.""" - -from .error_mapping import append_request_id, get_user_facing_error_message, map_error -from .heuristic_tool_parser import HeuristicToolParser -from .message_converter import ( - AnthropicToOpenAIConverter, - build_base_request_body, - get_block_attr, - get_block_type, -) -from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason -from .think_parser import ContentChunk, ContentType, ThinkTagParser -from .utils import set_if_not_none - -__all__ = [ - "AnthropicToOpenAIConverter", - "ContentBlockManager", - "ContentChunk", - "ContentType", - "HeuristicToolParser", - "SSEBuilder", - "ThinkTagParser", - "append_request_id", - "build_base_request_body", - "get_block_attr", - "get_block_type", - "get_user_facing_error_message", - "map_error", - "map_stop_reason", - "set_if_not_none", -] diff --git a/providers/common/error_mapping.py b/providers/common/error_mapping.py deleted file mode 100644 index d0c0d06..0000000 --- a/providers/common/error_mapping.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Error mapping for OpenAI-compatible providers (NIM, OpenRouter, LM Studio).""" - -import httpx -import openai - -from providers.exceptions import ( - APIError, - AuthenticationError, - InvalidRequestError, - OverloadedError, - ProviderError, - RateLimitError, -) -from providers.rate_limit import GlobalRateLimiter - - -def get_user_facing_error_message( - e: Exception, - *, - read_timeout_s: float | None = None, -) -> str: - """Return a readable, non-empty error message for users.""" - message = str(e).strip() - if message: - return message - - if isinstance(e, httpx.ReadTimeout): - if read_timeout_s is not None: - return f"Provider request timed out after {read_timeout_s:g}s." - return "Provider request timed out." - if isinstance(e, httpx.ConnectTimeout): - return "Could not connect to provider." - if isinstance(e, TimeoutError): - if read_timeout_s is not None: - return f"Provider request timed out after {read_timeout_s:g}s." - return "Request timed out." - - if isinstance(e, (RateLimitError, openai.RateLimitError)): - return "Provider rate limit reached. Please retry shortly." - if isinstance(e, (AuthenticationError, openai.AuthenticationError)): - return "Provider authentication failed. Check API key." - if isinstance(e, (InvalidRequestError, openai.BadRequestError)): - return "Invalid request sent to provider." 
- if isinstance(e, OverloadedError): - return "Provider is currently overloaded. Please retry." - if isinstance(e, APIError): - if e.status_code in (502, 503, 504): - return "Provider is temporarily unavailable. Please retry." - return "Provider API request failed." - if isinstance(e, ProviderError): - return "Provider request failed." - - return "Provider request failed unexpectedly." - - -def append_request_id(message: str, request_id: str | None) -> str: - """Append request_id suffix when available.""" - base = message.strip() or "Provider request failed unexpectedly." - if request_id: - return f"{base} (request_id={request_id})" - return base - - -def map_error( - e: Exception, *, rate_limiter: GlobalRateLimiter | None = None -) -> Exception: - """Map OpenAI or HTTPX exception to specific ProviderError.""" - message = get_user_facing_error_message(e) - limiter = rate_limiter or GlobalRateLimiter.get_instance() - - # Map OpenAI Specific Errors - if isinstance(e, openai.AuthenticationError): - return AuthenticationError(message, raw_error=str(e)) - if isinstance(e, openai.RateLimitError): - # Trigger global rate limit block - limiter.set_blocked(60) # Default 60s cooldown - return RateLimitError(message, raw_error=str(e)) - if isinstance(e, openai.BadRequestError): - return InvalidRequestError(message, raw_error=str(e)) - if isinstance(e, openai.InternalServerError): - raw_message = str(e) - if "overloaded" in raw_message.lower() or "capacity" in raw_message.lower(): - return OverloadedError(message, raw_error=raw_message) - return APIError(message, status_code=500, raw_error=str(e)) - if isinstance(e, openai.APIError): - return APIError( - message, status_code=getattr(e, "status_code", 500), raw_error=str(e) - ) - - # Map raw HTTPX Errors - if isinstance(e, httpx.HTTPStatusError): - status = e.response.status_code - if status in (401, 403): - return AuthenticationError(message, raw_error=str(e)) - if status == 429: - limiter.set_blocked(60) - return RateLimitError(message, raw_error=str(e)) - if status == 400: - return InvalidRequestError(message, raw_error=str(e)) - if status >= 500: - if status in (502, 503, 504): - return OverloadedError(message, raw_error=str(e)) - return APIError(message, status_code=status, raw_error=str(e)) - return APIError(message, status_code=status, raw_error=str(e)) - - return e diff --git a/providers/common/text.py b/providers/common/text.py deleted file mode 100644 index ce89627..0000000 --- a/providers/common/text.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Shared text extraction utilities.""" - -from typing import Any - - -def extract_text_from_content(content: Any) -> str: - """Extract concatenated text from message content (str or list of content blocks).""" - if isinstance(content, str): - return content - if isinstance(content, list): - parts = [] - for block in content: - text = getattr(block, "text", "") - if text and isinstance(text, str): - parts.append(text) - return "".join(parts) - return "" diff --git a/providers/deepseek/request.py b/providers/deepseek/request.py index e274be7..67cabf9 100644 --- a/providers/deepseek/request.py +++ b/providers/deepseek/request.py @@ -4,7 +4,7 @@ from typing import Any from loguru import logger -from providers.common.message_converter import build_base_request_body +from core.anthropic import build_base_request_body def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: diff --git a/providers/error_mapping.py b/providers/error_mapping.py new file mode 100644 index 0000000..3826a84 --- 
/dev/null +++ b/providers/error_mapping.py @@ -0,0 +1,56 @@ +"""Provider-specific exception mapping.""" + +import httpx +import openai + +from core.anthropic import get_user_facing_error_message +from providers.exceptions import ( + APIError, + AuthenticationError, + InvalidRequestError, + OverloadedError, + RateLimitError, +) +from providers.rate_limit import GlobalRateLimiter + + +def map_error( + e: Exception, *, rate_limiter: GlobalRateLimiter | None = None +) -> Exception: + """Map OpenAI or HTTPX exception to specific ProviderError.""" + message = get_user_facing_error_message(e) + limiter = rate_limiter or GlobalRateLimiter.get_instance() + + if isinstance(e, openai.AuthenticationError): + return AuthenticationError(message, raw_error=str(e)) + if isinstance(e, openai.RateLimitError): + limiter.set_blocked(60) + return RateLimitError(message, raw_error=str(e)) + if isinstance(e, openai.BadRequestError): + return InvalidRequestError(message, raw_error=str(e)) + if isinstance(e, openai.InternalServerError): + raw_message = str(e) + if "overloaded" in raw_message.lower() or "capacity" in raw_message.lower(): + return OverloadedError(message, raw_error=raw_message) + return APIError(message, status_code=500, raw_error=str(e)) + if isinstance(e, openai.APIError): + return APIError( + message, status_code=getattr(e, "status_code", 500), raw_error=str(e) + ) + + if isinstance(e, httpx.HTTPStatusError): + status = e.response.status_code + if status in (401, 403): + return AuthenticationError(message, raw_error=str(e)) + if status == 429: + limiter.set_blocked(60) + return RateLimitError(message, raw_error=str(e)) + if status == 400: + return InvalidRequestError(message, raw_error=str(e)) + if status >= 500: + if status in (502, 503, 504): + return OverloadedError(message, raw_error=str(e)) + return APIError(message, status_code=status, raw_error=str(e)) + return APIError(message, status_code=status, raw_error=str(e)) + + return e diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index 0c2f26d..c0a83f3 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -6,8 +6,7 @@ from typing import Any from loguru import logger from config.nim import NimSettings -from providers.common.message_converter import build_base_request_body -from providers.common.utils import set_if_not_none +from core.anthropic import build_base_request_body, set_if_not_none def _set_extra( diff --git a/providers/open_router/chat_request.py b/providers/open_router/chat_request.py index 3e9d2ba..41359be 100644 --- a/providers/open_router/chat_request.py +++ b/providers/open_router/chat_request.py @@ -4,7 +4,7 @@ from typing import Any from loguru import logger -from providers.common.message_converter import build_base_request_body +from core.anthropic import build_base_request_body OPENROUTER_DEFAULT_MAX_TOKENS = 81920 diff --git a/providers/open_router/client.py b/providers/open_router/client.py index d6e46a7..2572874 100644 --- a/providers/open_router/client.py +++ b/providers/open_router/client.py @@ -8,9 +8,9 @@ from collections.abc import Iterator from dataclasses import dataclass, field from typing import Any +from core.anthropic import SSEBuilder, append_request_id from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode from providers.base import ProviderConfig -from providers.common import SSEBuilder, append_request_id from providers.openai_compat import OpenAIChatTransport from .chat_request import build_chat_request_body diff 
--git a/providers/openai_compat.py b/providers/openai_compat.py index 4125d84..6ca1079 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -10,17 +10,17 @@ import httpx from loguru import logger from openai import AsyncOpenAI -from providers.base import BaseProvider, ProviderConfig -from providers.common import ( +from core.anthropic import ( ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser, append_request_id, get_user_facing_error_message, - map_error, map_stop_reason, ) +from providers.base import BaseProvider, ProviderConfig +from providers.error_mapping import map_error from providers.rate_limit import GlobalRateLimiter diff --git a/providers/rate_limit.py b/providers/rate_limit.py index 7a523f6..0e427e9 100644 --- a/providers/rate_limit.py +++ b/providers/rate_limit.py @@ -50,6 +50,7 @@ class GlobalRateLimiter: self._rate_limit = rate_limit self._rate_window = float(rate_window) + self._max_concurrency = max_concurrency # Monotonic timestamps of the last granted slots. self._request_times: deque[float] = deque() self._blocked_until: float = 0 @@ -95,10 +96,21 @@ class GlobalRateLimiter: """Get or create a provider-scoped limiter instance.""" if not scope: raise ValueError("scope must be non-empty") - if scope not in cls._scoped_instances: + desired_rate_limit = rate_limit or 40 + desired_rate_window = float(rate_window or 60.0) + existing = cls._scoped_instances.get(scope) + if existing and existing.matches_config( + desired_rate_limit, desired_rate_window, max_concurrency + ): + return existing + if existing: + logger.info( + "Rebuilding provider rate limiter for updated scope '{}'", scope + ) + if scope not in cls._scoped_instances or existing: cls._scoped_instances[scope] = cls( - rate_limit=rate_limit or 40, - rate_window=rate_window or 60.0, + rate_limit=desired_rate_limit, + rate_window=desired_rate_window, max_concurrency=max_concurrency, ) return cls._scoped_instances[scope] @@ -174,6 +186,16 @@ class GlobalRateLimiter: """Check if currently reactively blocked.""" return time.monotonic() < self._blocked_until + def matches_config( + self, rate_limit: int, rate_window: float, max_concurrency: int + ) -> bool: + """Return whether this limiter matches the requested runtime config.""" + return ( + self._rate_limit == rate_limit + and self._rate_window == float(rate_window) + and self._max_concurrency == max_concurrency + ) + def remaining_wait(self) -> float: """Get remaining reactive wait time in seconds.""" return max(0.0, self._blocked_until - time.monotonic()) diff --git a/providers/registry.py b/providers/registry.py index 8a63a3f..a0284c7 100644 --- a/providers/registry.py +++ b/providers/registry.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import MutableMapping +from collections.abc import Callable, MutableMapping from dataclasses import dataclass from typing import Literal @@ -20,6 +20,7 @@ from providers.open_router import ( ) TransportType = Literal["openai_chat", "anthropic_messages"] +ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider] @dataclass(frozen=True, slots=True) @@ -80,6 +81,37 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = { } +def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider: + return NvidiaNimProvider(config, nim_settings=settings.nim) + + +def _create_open_router(config: ProviderConfig, settings: Settings) -> BaseProvider: + if settings.openrouter_transport == "openai": + return OpenRouterChatProvider(config) + return 
OpenRouterProvider(config) + + +def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider: + return DeepSeekProvider(config) + + +def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider: + return LMStudioProvider(config) + + +def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider: + return LlamaCppProvider(config) + + +PROVIDER_FACTORIES: dict[str, ProviderFactory] = { + "nvidia_nim": _create_nvidia_nim, + "open_router": _create_open_router, + "deepseek": _create_deepseek, + "lmstudio": _create_lmstudio, + "llamacpp": _create_llamacpp, +} + + def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str: if attr_name is None: return default @@ -144,20 +176,10 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider: ) config = build_provider_config(descriptor, settings) - if provider_id == "nvidia_nim": - return NvidiaNimProvider(config, nim_settings=settings.nim) - if provider_id == "open_router": - if settings.openrouter_transport == "openai": - return OpenRouterChatProvider(config) - return OpenRouterProvider(config) - if provider_id == "deepseek": - return DeepSeekProvider(config) - if provider_id == "lmstudio": - return LMStudioProvider(config) - if provider_id == "llamacpp": - return LlamaCppProvider(config) - - raise AssertionError(f"Unhandled provider descriptor: {provider_id}") + factory = PROVIDER_FACTORIES.get(provider_id) + if factory is None: + raise AssertionError(f"Unhandled provider descriptor: {provider_id}") + return factory(config, settings) class ProviderRegistry: diff --git a/pyproject.toml b/pyproject.toml index 7b145bb..3dd2dd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ voice_local = [ ] [tool.hatch.build.targets.wheel] -packages = ["api", "cli", "config", "messaging", "providers"] +packages = ["api", "cli", "config", "core", "messaging", "providers"] [tool.uv.sources] torch = { index = "pytorch-cu130" } @@ -86,7 +86,7 @@ ignore = [ ] [tool.ruff.lint.isort] -known-first-party = ["api", "cli", "config", "messaging", "providers", "utils"] +known-first-party = ["api", "cli", "config", "core", "messaging", "providers", "smoke"] [tool.ruff.format] quote-style = "double" diff --git a/smoke/README.md b/smoke/README.md index 8e5dfaf..769a9c2 100644 --- a/smoke/README.md +++ b/smoke/README.md @@ -18,7 +18,7 @@ belong under `tests/` and must stay green with plain `uv run pytest`. ```powershell uv run pytest smoke --collect-only -q -uv run pytest smoke -n 0 -m live -s --tb=short +uv run pytest smoke -n 0 -s --tb=short ``` The second command skips everything unless `FCC_LIVE_SMOKE=1` is set, but still @@ -28,7 +28,7 @@ writes skip entries to `.smoke-results/`. 
```powershell $env:FCC_LIVE_SMOKE = "1" -uv run pytest smoke -n 0 -m live -s --tb=short +uv run pytest smoke -n 0 -s --tb=short ``` Provider product E2E runs for every configured provider model from `MODEL`, @@ -68,20 +68,20 @@ Side-effectful targets are opt-in: ```powershell $env:FCC_LIVE_SMOKE = "1" $env:FCC_SMOKE_PROVIDER_MATRIX = "open_router,nvidia_nim,deepseek,lmstudio,llamacpp" -uv run pytest smoke/product -n 0 -m live -s --tb=short +uv run pytest smoke/product -n 0 -s --tb=short ``` ```powershell $env:FCC_LIVE_SMOKE = "1" $env:FCC_SMOKE_TARGETS = "telegram,discord,voice" $env:FCC_SMOKE_RUN_VOICE = "1" -uv run pytest smoke/product -n 0 -m live -s --tb=short +uv run pytest smoke/product -n 0 -s --tb=short ``` ```powershell $env:FCC_LIVE_SMOKE = "1" $env:FCC_SMOKE_TARGETS = "messaging,config,extensibility" -uv run pytest smoke/product -n 0 -m live -s --tb=short +uv run pytest smoke/product -n 0 -s --tb=short ``` ## Environment diff --git a/smoke/capabilities.py b/smoke/capabilities.py index 86e1fd6..30772a7 100644 --- a/smoke/capabilities.py +++ b/smoke/capabilities.py @@ -149,7 +149,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = ( "streaming_conversion", "anthropic_sse_lifecycle", "streaming_error_mapping", - "providers.common.SSEBuilder", + "core.anthropic.SSEBuilder", "provider stream chunks or native SSE events", "Anthropic message/content/error SSE lifecycle", "Anthropic-compatible error event", @@ -162,7 +162,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = ( "streaming_conversion", "thinking_blocks", "thinking_token_support", - "providers.common.think_parser", + "core.anthropic.thinking", "reasoning_content, reasoning_details, text, native thinking", "Claude thinking blocks or suppression", "thinking hidden when disabled", @@ -175,7 +175,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = ( "streaming_conversion", "heuristic_tools", "heuristic_tool_parser", - "providers.common.heuristic_tool_parser", + "core.anthropic.tools", "textual tool-call output", "structured Anthropic tool_use blocks", "text fallback when malformed", @@ -186,7 +186,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] 
= ( "streaming_conversion", "subagent_task_control", "subagent_control", - "providers.common.SSEBuilder", + "core.anthropic.SSEBuilder", "Task tool call arguments", "run_in_background=false", "invalid JSON flushed as safe object", diff --git a/smoke/lib/report_summary.py b/smoke/lib/report_summary.py new file mode 100644 index 0000000..3c91644 --- /dev/null +++ b/smoke/lib/report_summary.py @@ -0,0 +1,53 @@ +"""Summarize smoke JSON reports for local and workflow triage.""" + +from __future__ import annotations + +import json +from collections import Counter +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True, slots=True) +class SmokeSummary: + reports: int + outcomes: int + classifications: dict[str, int] + + @property + def has_regression(self) -> bool: + return bool( + self.classifications.get("product_failure", 0) + or self.classifications.get("harness_bug", 0) + ) + + +def summarize_reports(results_dir: Path) -> SmokeSummary: + """Read all report JSON files and count outcome classifications.""" + counts: Counter[str] = Counter() + reports = 0 + outcomes = 0 + for path in sorted(results_dir.glob("report-*.json")): + reports += 1 + payload = json.loads(path.read_text(encoding="utf-8")) + for outcome in payload.get("outcomes", []): + if not isinstance(outcome, dict): + continue + outcomes += 1 + counts[str(outcome.get("classification") or "unknown")] += 1 + return SmokeSummary( + reports=reports, + outcomes=outcomes, + classifications=dict(sorted(counts.items())), + ) + + +def format_summary(summary: SmokeSummary) -> str: + """Return a compact human-readable summary.""" + parts = [ + f"reports={summary.reports}", + f"outcomes={summary.outcomes}", + ] + parts.extend(f"{name}={count}" for name, count in summary.classifications.items()) + status = "regression" if summary.has_regression else "ok" + return f"smoke_summary status={status} " + " ".join(parts) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 9f83a64..cc53294 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -96,7 +96,6 @@ def test_model_mapping(): assert len(_stream_response_calls) == 1 args = _stream_response_calls[0][0] assert args[0].model != "claude-3-haiku-20240307" - assert args[0].original_model == "claude-3-haiku-20240307" def test_error_fallbacks(): diff --git a/tests/api/test_model_router.py b/tests/api/test_model_router.py index 52d107b..e51324d 100644 --- a/tests/api/test_model_router.py +++ b/tests/api/test_model_router.py @@ -36,9 +36,9 @@ def test_model_router_applies_opus_override(settings): ) routed = ModelRouter(settings).resolve_messages_request(request) - assert routed.model == "deepseek/deepseek-r1" - assert routed.resolved_provider_model == "open_router/deepseek/deepseek-r1" - assert routed.original_model == "claude-opus-4-20250514" + assert routed.request.model == "deepseek/deepseek-r1" + assert routed.resolved.provider_model_ref == "open_router/deepseek/deepseek-r1" + assert routed.resolved.original_model == "claude-opus-4-20250514" assert request.model == "claude-opus-4-20250514" @@ -53,8 +53,8 @@ def test_model_router_applies_haiku_override(settings): ) ) - assert routed.model == "qwen2.5-7b" - assert routed.resolved_provider_model == "lmstudio/qwen2.5-7b" + assert routed.request.model == "qwen2.5-7b" + assert routed.resolved.provider_model_ref == "lmstudio/qwen2.5-7b" def test_model_router_applies_sonnet_override(settings): @@ -68,8 +68,10 @@ def test_model_router_applies_sonnet_override(settings): ) ) - assert routed.model == 
"meta/llama-3.3-70b-instruct" - assert routed.resolved_provider_model == "nvidia_nim/meta/llama-3.3-70b-instruct" + assert routed.request.model == "meta/llama-3.3-70b-instruct" + assert ( + routed.resolved.provider_model_ref == "nvidia_nim/meta/llama-3.3-70b-instruct" + ) def test_model_router_routes_token_count_request(settings): @@ -81,7 +83,7 @@ def test_model_router_routes_token_count_request(settings): ) routed = ModelRouter(settings).resolve_token_count_request(request) - assert routed.model == "qwen2.5-7b" + assert routed.request.model == "qwen2.5-7b" assert request.model == "claude-3-haiku-20240307" diff --git a/tests/api/test_models_validators.py b/tests/api/test_models_validators.py index d74da70..4e7740b 100644 --- a/tests/api/test_models_validators.py +++ b/tests/api/test_models_validators.py @@ -9,22 +9,22 @@ def test_messages_request_parses_without_model_mapping_side_effects(): ) assert request.model == "claude-3-opus" - assert request.original_model is None - assert request.resolved_provider_model is None -def test_messages_request_preserves_internal_routing_fields_when_supplied(): - request = MessagesRequest( - model="target-model", - original_model="claude-3-opus", - resolved_provider_model="nvidia_nim/target-model", - max_tokens=100, - messages=[Message(role="user", content="hello")], +def test_messages_request_ignores_internal_routing_fields_when_supplied(): + request = MessagesRequest.model_validate( + { + "model": "target-model", + "original_model": "claude-3-opus", + "resolved_provider_model": "nvidia_nim/target-model", + "max_tokens": 100, + "messages": [{"role": "user", "content": "hello"}], + } ) assert request.model == "target-model" - assert request.original_model == "claude-3-opus" - assert request.resolved_provider_model == "nvidia_nim/target-model" + assert "original_model" not in request.model_dump() + assert "resolved_provider_model" not in request.model_dump() def test_token_count_request_parses_without_model_mapping_side_effects(): diff --git a/tests/cli/test_cli_ownership.py b/tests/cli/test_cli_ownership.py new file mode 100644 index 0000000..7d16876 --- /dev/null +++ b/tests/cli/test_cli_ownership.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from pathlib import Path + +from cli.session import CLISession + + +def test_cli_session_owns_typed_runner_config(tmp_path: Path) -> None: + session = CLISession( + workspace_path=str(tmp_path), + api_url="http://127.0.0.1:8082/v1", + allowed_dirs=[str(tmp_path)], + plans_directory=".plans", + claude_bin="claude-test", + ) + + assert session.config.workspace_path == str(tmp_path) + assert session.config.api_url == "http://127.0.0.1:8082/v1" + assert session.config.allowed_dirs == [str(tmp_path)] + assert session.config.plans_directory == ".plans" + assert session.config.claude_bin == "claude-test" + + +def test_claude_pick_uses_project_python_runner() -> None: + script = Path(__file__).resolve().parents[2] / "claude-pick" + text = script.read_text(encoding="utf-8") + + assert "uv run python" in text + assert "python3 -c" not in text diff --git a/tests/contracts/test_architecture_contracts.py b/tests/contracts/test_architecture_contracts.py new file mode 100644 index 0000000..ce98264 --- /dev/null +++ b/tests/contracts/test_architecture_contracts.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import re +from pathlib import Path + + +def test_architecture_plan_exists() -> None: + repo_root = Path(__file__).resolve().parents[2] + plan = repo_root / "PLAN.md" + + assert plan.exists() + text = 
plan.read_text(encoding="utf-8")
+    assert "Intended Dependency Direction" in text
+    assert "Smoke Coverage Policy" in text
+
+
+def test_env_examples_are_in_sync() -> None:
+    repo_root = Path(__file__).resolve().parents[2]
+    root_example = repo_root / ".env.example"
+    packaged_example = repo_root / "config" / "env.example"
+
+    assert _env_keys(root_example) == _env_keys(packaged_example)
+
+
+def test_pyproject_first_party_packages_match_packaged_roots() -> None:
+    repo_root = Path(__file__).resolve().parents[2]
+    pyproject = (repo_root / "pyproject.toml").read_text(encoding="utf-8")
+    match = re.search(r"known-first-party = \[(?P<items>[^\]]+)\]", pyproject)
+
+    assert match is not None
+    configured = {
+        item.strip().strip('"')
+        for item in match.group("items").split(",")
+        if item.strip()
+    }
+    expected = {"api", "cli", "config", "core", "messaging", "providers", "smoke"}
+    assert configured == expected
+
+
+def _env_keys(path: Path) -> set[str]:
+    keys: set[str] = set()
+    for line in path.read_text(encoding="utf-8").splitlines():
+        stripped = line.strip()
+        if not stripped or stripped.startswith("#") or "=" not in stripped:
+            continue
+        key, _, _value = stripped.partition("=")
+        keys.add(key.strip())
+    return keys
diff --git a/tests/contracts/test_import_boundaries.py b/tests/contracts/test_import_boundaries.py
new file mode 100644
index 0000000..bf33eae
--- /dev/null
+++ b/tests/contracts/test_import_boundaries.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import ast
+from pathlib import Path
+
+
+def test_api_and_messaging_do_not_import_provider_common() -> None:
+    repo_root = Path(__file__).resolve().parents[2]
+    assert not (repo_root / "providers" / "common").exists()
+    offenders = _imports_matching(
+        [repo_root / "api", repo_root / "messaging"],
+        forbidden_prefixes=("providers.common",),
+    )
+
+    assert offenders == []
+
+
+def test_provider_adapters_do_not_import_runtime_layers() -> None:
+    repo_root = Path(__file__).resolve().parents[2]
+    offenders = _imports_matching(
+        [repo_root / "providers"],
+        forbidden_prefixes=("api.", "messaging.", "cli."),
+    )
+
+    assert offenders == []
+
+
+def test_architecture_doc_names_enforced_boundaries() -> None:
+    repo_root = Path(__file__).resolve().parents[2]
+    text = (repo_root / "PLAN.md").read_text(encoding="utf-8")
+
+    assert "core/anthropic/" in text
+    assert "api/runtime.py" in text
+    assert "import-boundary" in text or "Provider adapters may depend" in text
+
+
+def _imports_matching(
+    roots: list[Path], *, forbidden_prefixes: tuple[str, ...]
+) -> list[str]: + offenders: list[str] = [] + for root in roots: + for path in root.rglob("*.py"): + rel = path.relative_to(root.parent) + offenders.extend( + f"{rel}: {imported}" + for imported in _imports_from(path) + if imported in forbidden_prefixes + or imported.startswith(forbidden_prefixes) + ) + return sorted(offenders) + + +def _imports_from(path: Path) -> list[str]: + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + imports: list[str] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imports.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module: + imports.append(node.module) + return imports diff --git a/tests/contracts/test_smoke_tiers.py b/tests/contracts/test_smoke_tiers.py new file mode 100644 index 0000000..f343d88 --- /dev/null +++ b/tests/contracts/test_smoke_tiers.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from smoke.lib.report_summary import format_summary, summarize_reports + + +def test_smoke_readme_uses_env_gated_serial_commands() -> None: + repo_root = Path(__file__).resolve().parents[2] + text = (repo_root / "smoke" / "README.md").read_text(encoding="utf-8") + + assert "FCC_LIVE_SMOKE=1" in text + assert "-n 0" in text + assert "-m live" not in text + + +def test_smoke_report_summary_counts_regression_classes(tmp_path: Path) -> None: + report = { + "outcomes": [ + {"classification": "missing_env"}, + {"classification": "product_failure"}, + {"classification": "upstream_unavailable"}, + ] + } + (tmp_path / "report-one.json").write_text(json.dumps(report), encoding="utf-8") + + summary = summarize_reports(tmp_path) + + assert summary.reports == 1 + assert summary.outcomes == 3 + assert summary.classifications["product_failure"] == 1 + assert summary.has_regression + assert "status=regression" in format_summary(summary) diff --git a/tests/contracts/test_stream_contracts.py b/tests/contracts/test_stream_contracts.py index cff9eb3..f6e964f 100644 --- a/tests/contracts/test_stream_contracts.py +++ b/tests/contracts/test_stream_contracts.py @@ -2,14 +2,9 @@ from __future__ import annotations from collections.abc import Iterable +from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser from messaging.event_parser import parse_cli_event from messaging.transcript import RenderCtx, TranscriptBuffer -from providers.common import ( - ContentType, - HeuristicToolParser, - SSEBuilder, - ThinkTagParser, -) from smoke.lib.sse import ( assert_anthropic_stream_contract, event_names, diff --git a/tests/messaging/test_extract_text.py b/tests/messaging/test_extract_text.py index b223459..417f9ef 100644 --- a/tests/messaging/test_extract_text.py +++ b/tests/messaging/test_extract_text.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock import pytest -from providers.common.text import extract_text_from_content +from core.anthropic import extract_text_from_content class TestExtractTextFromContent: - """Tests for providers.common.text.extract_text_from_content.""" + """Tests for core.anthropic.extract_text_from_content.""" def test_string_content(self): """Return string content as-is.""" diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py index 95ca23f..f9a0faf 100644 --- a/tests/providers/test_converter.py +++ b/tests/providers/test_converter.py @@ -2,7 +2,7 @@ import json import pytest -from providers.common.message_converter import AnthropicToOpenAIConverter +from 
core.anthropic import AnthropicToOpenAIConverter # --- Mock Classes --- @@ -314,7 +314,7 @@ def test_convert_mixed_blocks_and_types_and_roles(): def test_get_block_attr_defaults(): # Test helper directly - from providers.common.message_converter import get_block_attr + from core.anthropic import get_block_attr assert get_block_attr({}, "missing", "default") == "default" assert get_block_attr(object(), "missing", "default") == "default" diff --git a/tests/providers/test_error_mapping.py b/tests/providers/test_error_mapping.py index 7e6cf99..162b05d 100644 --- a/tests/providers/test_error_mapping.py +++ b/tests/providers/test_error_mapping.py @@ -1,4 +1,4 @@ -"""Tests for providers/nvidia_nim/errors.py error mapping.""" +"""Tests for provider error mapping and core error formatting.""" from unittest.mock import MagicMock, patch @@ -6,7 +6,8 @@ import openai import pytest from httpx import ReadTimeout, Request, Response -from providers.common import append_request_id, get_user_facing_error_message, map_error +from core.anthropic import append_request_id, get_user_facing_error_message +from providers.error_mapping import map_error from providers.exceptions import ( APIError, AuthenticationError, @@ -41,7 +42,7 @@ class TestMapError: def test_rate_limit_error(self): """openai.RateLimitError -> RateLimitError and triggers global block.""" exc = _make_openai_error(openai.RateLimitError, status_code=429) - with patch("providers.common.error_mapping.GlobalRateLimiter") as mock_rl: + with patch("providers.error_mapping.GlobalRateLimiter") as mock_rl: mock_instance = MagicMock() mock_rl.get_instance.return_value = mock_instance result = map_error(exc) @@ -117,7 +118,7 @@ class TestMapError: openai.BadRequestError: 400, } exc = _make_openai_error(exc_cls, status_code=status_map[exc_cls]) - with patch("providers.common.error_mapping.GlobalRateLimiter"): + with patch("providers.error_mapping.GlobalRateLimiter"): result = map_error(exc) assert isinstance(result, expected_cls) diff --git a/tests/providers/test_nvidia_nim_request.py b/tests/providers/test_nvidia_nim_request.py index e4b9953..df38e24 100644 --- a/tests/providers/test_nvidia_nim_request.py +++ b/tests/providers/test_nvidia_nim_request.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import pytest from config.nim import NimSettings -from providers.common.utils import set_if_not_none +from core.anthropic import set_if_not_none from providers.nvidia_nim.request import ( _set_extra, build_request_body, diff --git a/tests/providers/test_parsers.py b/tests/providers/test_parsers.py index e84edaa..0b36dc7 100644 --- a/tests/providers/test_parsers.py +++ b/tests/providers/test_parsers.py @@ -1,6 +1,6 @@ import pytest -from providers.common import ContentType, HeuristicToolParser, ThinkTagParser +from core.anthropic import ContentType, HeuristicToolParser, ThinkTagParser def test_think_tag_parser_basic(): diff --git a/tests/providers/test_sse_builder.py b/tests/providers/test_sse_builder.py index 5c2400f..07df2f3 100644 --- a/tests/providers/test_sse_builder.py +++ b/tests/providers/test_sse_builder.py @@ -1,15 +1,11 @@ -"""Tests for providers/nvidia_nim/utils/sse_builder.py.""" +"""Tests for core.anthropic.sse.""" import json from unittest.mock import patch import pytest -from providers.common.sse_builder import ( - ContentBlockManager, - SSEBuilder, - map_stop_reason, -) +from core.anthropic import ContentBlockManager, SSEBuilder, map_stop_reason def _parse_sse(sse_str: str) -> dict: @@ -366,7 +362,7 @@ class TestSSEBuilderTokenEstimation: 
builder.start_text_block() builder.emit_text_delta("a" * 100) # 100 chars -> ~25 tokens - with patch("providers.common.sse_builder.ENCODER", None): + with patch("core.anthropic.sse.ENCODER", None): tokens = builder.estimate_output_tokens() assert tokens == 25 # 100 // 4 @@ -376,7 +372,7 @@ class TestSSEBuilderTokenEstimation: builder.start_tool_block(0, "t1", "Read") builder.emit_tool_delta(0, '{"path":"test.py"}') - with patch("providers.common.sse_builder.ENCODER", None): + with patch("core.anthropic.sse.ENCODER", None): tokens = builder.estimate_output_tokens() # 1 tool * 50 = 50 assert tokens == 50 diff --git a/tests/providers/test_streaming_errors.py b/tests/providers/test_streaming_errors.py index 04e9ee9..0b5f614 100644 --- a/tests/providers/test_streaming_errors.py +++ b/tests/providers/test_streaming_errors.py @@ -392,7 +392,7 @@ class TestProcessToolCall: def test_tool_call_with_id(self): """Tool call with id starts a tool block.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc = { @@ -409,7 +409,7 @@ class TestProcessToolCall: def test_tool_call_without_id_generates_uuid(self): """Tool call without id generates a uuid-based id.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc = { @@ -424,7 +424,7 @@ class TestProcessToolCall: def test_task_tool_forces_background_false(self): """Task tool with run_in_background=true is forced to false.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") args = json.dumps({"run_in_background": True, "prompt": "test"}) @@ -441,7 +441,7 @@ class TestProcessToolCall: def test_task_tool_chunked_args_forces_background_false(self): """Chunked Task args are buffered until valid JSON, then forced to false.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc1 = { @@ -466,7 +466,7 @@ class TestProcessToolCall: def test_task_tool_invalid_json_logs_warning_on_flush(self, caplog): """Invalid JSON args for Task tool emits {} on flush and logs a warning.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc = { @@ -486,7 +486,7 @@ class TestProcessToolCall: def test_negative_tool_index_fallback(self): """tc_index < 0 uses len(tool_indices) as fallback.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc = { @@ -501,7 +501,7 @@ class TestProcessToolCall: def test_tool_args_emitted_as_delta(self): """Arguments are emitted as input_json_delta events.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc = { @@ -621,7 +621,7 @@ class TestStreamChunkEdgeCases: def test_stream_malformed_tool_args_chunked(self): """Chunked tool args that never form valid JSON are flushed with {}.""" provider = _make_provider() - from providers.common import SSEBuilder + from core.anthropic import SSEBuilder sse = SSEBuilder("msg_test", "test-model") tc1 = { diff --git a/tests/providers/test_subagent_interception.py 
b/tests/providers/test_subagent_interception.py index a36897a..136c3f4 100644 --- a/tests/providers/test_subagent_interception.py +++ b/tests/providers/test_subagent_interception.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest from config.nim import NimSettings +from core.anthropic import ContentBlockManager from providers.base import ProviderConfig -from providers.common import ContentBlockManager from providers.nvidia_nim import NvidiaNimProvider