diff --git a/.env.example b/.env.example
index 5983188..3dd8593 100644
--- a/.env.example
+++ b/.env.example
@@ -4,6 +4,8 @@ NVIDIA_NIM_API_KEY=""
# OpenRouter Config
OPENROUTER_API_KEY=""
+# Diagnostic rollback only: "anthropic" (native /messages, default) | "openai" (chat completions)
+OPENROUTER_TRANSPORT="anthropic"
# DeepSeek Config
@@ -55,7 +57,7 @@ HTTP_CONNECT_TIMEOUT=2
ANTHROPIC_AUTH_TOKEN=
-# Messaging Platform: "telegram" | "discord"
+# Messaging Platform: "telegram" | "discord" | "none"
MESSAGING_PLATFORM="discord"
MESSAGING_RATE_LIMIT=1
MESSAGING_RATE_WINDOW=1
@@ -88,6 +90,7 @@ ALLOWED_DISCORD_CHANNELS=""
# Agent Config
CLAUDE_WORKSPACE="./agent_workspace"
ALLOWED_DIR=""
+CLAUDE_CLI_BIN="claude"
FAST_PREFIX_DETECTION=true
ENABLE_NETWORK_PROBE_MOCK=true
ENABLE_TITLE_GENERATION_SKIP=true
diff --git a/AGENTS.md b/AGENTS.md
index d43b17c..6b528d9 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -31,7 +31,7 @@
- **Performance**: Use list accumulation for strings (not `+=` in loops), cache env vars at init, prefer iterative over recursive when stack depth matters.
- **Platform-agnostic naming**: Use generic names (e.g. `PLATFORM_EDIT`) not platform-specific ones (e.g. `TELEGRAM_EDIT`) in shared code.
- **No type ignores**: Do not add `# type: ignore` or `# ty: ignore`. Fix the underlying type issue.
-- **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working.
+- **Complete migrations**: When moving modules, update imports to the new owner and remove old compatibility shims in the same change unless preserving a published interface is explicitly required.
## COGNITIVE WORKFLOW
diff --git a/PLAN.md b/PLAN.md
new file mode 100644
index 0000000..77aae0f
--- /dev/null
+++ b/PLAN.md
@@ -0,0 +1,88 @@
+# Architecture Plan
+
+This document is the baseline architecture guide referenced by `AGENTS.md` and
+`CLAUDE.md`. It records the intended dependency direction and the migration
+target for keeping the project modular as providers, clients, and smoke tests
+grow.
+
+## Current Product Shape
+
+`free-claude-code` is an Anthropic-compatible proxy with optional messaging
+workers:
+
+- `api/` owns the HTTP routes, request orchestration, model routing, auth, and
+ server lifecycle.
+- `providers/` owns upstream model adapters, request conversion, stream
+ conversion, provider rate limiting, and provider error mapping.
+- `messaging/` owns Discord and Telegram adapters, command handling, tree
+ threading, session persistence, transcript rendering, and voice intake.
+- `cli/` owns package entrypoints and managed Claude CLI subprocess sessions.
+- `config/` owns environment-backed settings and logging setup.
+- `smoke/` owns opt-in product smoke scenarios and the public coverage
+ inventory used by contract tests.
+
+## Intended Dependency Direction
+
+The repo should preserve this dependency order:
+
+```mermaid
+flowchart TD
+ config[config] --> api[api]
+ config --> providers[providers]
+ config --> messaging[messaging]
+ core[core.anthropic] --> api
+ core --> providers
+ core --> messaging
+ providers --> api
+ cli --> messaging
+ messaging --> api
+```
+
+The practical rule is simpler than the graph: shared protocol helpers belong in
+neutral core modules, not under a provider package. Provider adapters may depend
+on the neutral protocol layer, but API and messaging code should not import
+provider internals.
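+
+A minimal sketch of the rule at import level, using only names that exist in
+this change (nothing below is new API):
+
+```python
+# providers/*: adapters may depend on the neutral protocol layer.
+from core.anthropic import AnthropicToOpenAIConverter, build_base_request_body
+
+# api/* and messaging/*: depend on neutral helpers, never provider internals.
+from core.anthropic import get_user_facing_error_message
+
+# Not allowed from api/ or messaging/: importing a provider package's
+# internals, e.g. the retired providers.common import path.
+```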
+
+## Target Boundaries
+
+- `core/anthropic/`: Anthropic protocol helpers, stream primitives, content
+ extraction, token estimation, user-facing error strings, and request
+ conversion utilities shared across API, providers, messaging, and tests.
+- `api/runtime.py`: application composition, optional messaging startup,
+ session store restoration, and cleanup ownership.
+- `providers/`: provider descriptors, credential resolution, transport
+ factories, scoped rate limiters, upstream request builders, and stream
+ transformers.
+- `messaging/`: platform-neutral orchestration split from command dispatch,
+ rendering, voice handling, and persistence.
+- `cli/`: typed Claude CLI runner config (sketched below), subprocess
+  management, and packaged user-facing entrypoints.
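+
+For the `cli/` boundary, a minimal sketch of constructing the typed runner
+config from `cli/session.py`; the field values are illustrative, not defaults:
+
+```python
+from cli.session import ClaudeCliConfig
+
+config = ClaudeCliConfig(
+    workspace_path="/abs/path/to/project",    # normalized workspace root
+    api_url="http://127.0.0.1:8080/v1",       # proxy endpoint the CLI targets
+    allowed_dirs=["/abs/path/to/project"],    # optional sandbox directories
+    plans_directory="agent_workspace/plans",  # relative to the workspace
+    claude_bin="claude",                      # overridable via CLAUDE_CLI_BIN
+)
+```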
+
+## Smoke Coverage Policy
+
+Default CI stays deterministic and runs `uv run pytest` against `tests/`.
+Product smoke lives under `smoke/` and is enabled with `FCC_LIVE_SMOKE=1`.
+Smoke runs should use `-n 0` unless a scenario is explicitly known to be safe
+under xdist.
+
+Live smoke has two valid skip classes:
+
+- `missing_env`: credentials, local services, binaries, or explicit opt-in flags
+ are absent.
+- `upstream_unavailable`: real providers, bot APIs, or local model servers are
+ unreachable.
+
+`product_failure` and `harness_bug` are regressions. When a provider is
+explicitly selected by `FCC_SMOKE_PROVIDER_MATRIX`, missing configuration should
+fail instead of being silently skipped.
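+
+A minimal sketch of how a scenario can honor these rules with plain pytest
+gating; everything except the documented `FCC_*` flags and
+`OPENROUTER_API_KEY` is illustrative:
+
+```python
+import os
+
+import pytest
+
+if not os.environ.get("FCC_LIVE_SMOKE"):
+    # missing_env: live smoke is strictly opt-in.
+    pytest.skip("missing_env: FCC_LIVE_SMOKE not set", allow_module_level=True)
+
+
+def test_openrouter_roundtrip() -> None:
+    selected = os.environ.get("FCC_SMOKE_PROVIDER_MATRIX", "")
+    if not os.environ.get("OPENROUTER_API_KEY"):
+        if "open_router" in selected:
+            # Explicitly selected provider: missing config must fail, not skip.
+            pytest.fail("OPENROUTER_API_KEY missing for selected provider")
+        pytest.skip("missing_env: OPENROUTER_API_KEY not set")
+```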
+
+## Refactor Rules
+
+- Keep public request/response shapes stable while moving internals.
+- Complete module migrations in one change: update imports to the new owner and
+ remove old compatibility shims unless preserving a published interface is
+ explicitly required.
+- Lock behavior with focused tests before moving shared protocol or runtime
+ code.
+- Run checks in this order: `uv run ruff format`, `uv run ruff check`,
+ `uv run ty check`, `uv run pytest`.
diff --git a/api/app.py b/api/app.py
index cf5877c..f6677e3 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,6 +1,5 @@
"""FastAPI application factory and configuration."""
-import asyncio
import os
from contextlib import asynccontextmanager
@@ -14,6 +13,7 @@ from providers.exceptions import ProviderError
from .dependencies import cleanup_provider
from .routes import router
+from .runtime import AppRuntime, warn_if_process_auth_token
# Opt-in to future behavior for python-telegram-bot
os.environ["PTB_TIMEDELTA"] = "1"
@@ -23,174 +23,22 @@ _settings = get_settings()
configure_logging(_settings.log_file)
-_SHUTDOWN_TIMEOUT_S = 5.0
-
-
-async def _best_effort(
- name: str, awaitable, timeout_s: float = _SHUTDOWN_TIMEOUT_S
-) -> None:
- """Run a shutdown step with timeout; never raise to callers."""
- try:
- await asyncio.wait_for(awaitable, timeout=timeout_s)
- except TimeoutError:
- logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
- except Exception as e:
- logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
-
-
def _warn_if_process_auth_token(settings) -> None:
- """Warn when server auth was implicitly inherited from the shell."""
- uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None)
- if callable(uses_process_token) and uses_process_token():
- logger.warning(
- "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
- "a configured .env file. The proxy will require that token. Add "
- "ANTHROPIC_AUTH_TOKEN= to .env to disable proxy auth, or set the "
- "same token in .env to make server auth explicit."
- )
+ """Compatibility wrapper for tests importing the old app helper."""
+ warn_if_process_auth_token(settings)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager."""
- settings = get_settings()
- logger.info("Starting Claude Code Proxy...")
- _warn_if_process_auth_token(settings)
-
- # Initialize messaging platform if configured
- messaging_platform = None
- message_handler = None
- cli_manager = None
-
- try:
- # Use the messaging factory to create the right platform
- from messaging.platforms.factory import create_messaging_platform
-
- messaging_platform = create_messaging_platform(
- platform_type=settings.messaging_platform,
- bot_token=settings.telegram_bot_token,
- allowed_user_id=settings.allowed_telegram_user_id,
- discord_bot_token=settings.discord_bot_token,
- allowed_discord_channels=settings.allowed_discord_channels,
- )
-
- if messaging_platform:
- from cli.manager import CLISessionManager
- from messaging.handler import ClaudeMessageHandler
- from messaging.session import SessionStore
-
- # Setup workspace - CLI runs in allowed_dir if set (e.g. project root)
- workspace = (
- os.path.abspath(settings.allowed_dir)
- if settings.allowed_dir
- else os.getcwd()
- )
- os.makedirs(workspace, exist_ok=True)
-
- # Session data stored in agent_workspace
- data_path = os.path.abspath(settings.claude_workspace)
- os.makedirs(data_path, exist_ok=True)
-
- api_url = f"http://{settings.host}:{settings.port}/v1"
- allowed_dirs = [workspace] if settings.allowed_dir else []
- plans_dir_abs = os.path.abspath(
- os.path.join(settings.claude_workspace, "plans")
- )
- plans_directory = os.path.relpath(plans_dir_abs, workspace)
- cli_manager = CLISessionManager(
- workspace_path=workspace,
- api_url=api_url,
- allowed_dirs=allowed_dirs,
- plans_directory=plans_directory,
- )
-
- # Initialize session store
- session_store = SessionStore(
- storage_path=os.path.join(data_path, "sessions.json")
- )
-
- # Create and register message handler
- message_handler = ClaudeMessageHandler(
- platform=messaging_platform,
- cli_manager=cli_manager,
- session_store=session_store,
- )
-
- # Restore tree state if available
- saved_trees = session_store.get_all_trees()
- if saved_trees:
- logger.info(f"Restoring {len(saved_trees)} conversation trees...")
- from messaging.trees.queue_manager import TreeQueueManager
-
- message_handler.replace_tree_queue(
- TreeQueueManager.from_dict(
- {
- "trees": saved_trees,
- "node_to_tree": session_store.get_node_mapping(),
- },
- queue_update_callback=message_handler.update_queue_positions,
- node_started_callback=message_handler.mark_node_processing,
- )
- )
- # Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
- if message_handler.tree_queue.cleanup_stale_nodes() > 0:
- # Sync back and save
- tree_data = message_handler.tree_queue.to_dict()
- session_store.sync_from_tree_data(
- tree_data["trees"], tree_data["node_to_tree"]
- )
-
- # Wire up the handler
- messaging_platform.on_message(message_handler.handle_message)
-
- # Start the platform
- await messaging_platform.start()
- logger.info(
- f"{messaging_platform.name} platform started with message handler"
- )
-
- except ImportError as e:
- logger.warning(f"Messaging module import error: {e}")
- except Exception as e:
- logger.error(f"Failed to start messaging platform: {e}")
- import traceback
-
- logger.error(traceback.format_exc())
-
- # Store in app state for access in routes
- app.state.messaging_platform = messaging_platform
- app.state.message_handler = message_handler
- app.state.cli_manager = cli_manager
+ runtime = AppRuntime.for_app(
+ app, settings=get_settings(), provider_cleanup=cleanup_provider
+ )
+ await runtime.startup()
yield
- # Cleanup
- if message_handler and hasattr(message_handler, "session_store"):
- try:
- message_handler.session_store.flush_pending_save()
- except Exception as e:
- logger.warning(f"Session store flush on shutdown: {e}")
- logger.info("Shutdown requested, cleaning up...")
- if messaging_platform:
- await _best_effort("messaging_platform.stop", messaging_platform.stop())
- if cli_manager:
- await _best_effort("cli_manager.stop_all", cli_manager.stop_all())
- await _best_effort("cleanup_provider", cleanup_provider())
-
- # Ensure background limiter worker doesn't keep the loop alive.
- try:
- from messaging.limiter import MessagingRateLimiter
-
- await _best_effort(
- "MessagingRateLimiter.shutdown_instance",
- MessagingRateLimiter.shutdown_instance(),
- timeout_s=2.0,
- )
- except Exception:
- # Limiter may never have been imported/initialized.
- pass
-
- logger.info("Server shut down cleanly")
+ await runtime.shutdown()
def create_app() -> FastAPI:
diff --git a/api/dependencies.py b/api/dependencies.py
index bf54cd5..ede7588 100644
--- a/api/dependencies.py
+++ b/api/dependencies.py
@@ -5,8 +5,8 @@ from loguru import logger
from config.settings import Settings
from config.settings import get_settings as _get_settings
+from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
-from providers.common import get_user_facing_error_message
from providers.exceptions import AuthenticationError
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
@@ -24,6 +24,7 @@ def get_provider_for_type(provider_type: str) -> BaseProvider:
Providers are cached in the registry and reused across requests.
"""
+ should_log_init = provider_type not in _providers
try:
provider = ProviderRegistry(_providers).get(provider_type, get_settings())
except AuthenticationError as e:
@@ -37,7 +38,7 @@ def get_provider_for_type(provider_type: str) -> BaseProvider:
", ".join(f"'{key}'" for key in PROVIDER_DESCRIPTORS),
)
raise
- if provider_type in _providers:
+ if should_log_init:
logger.info("Provider initialized: {}", provider_type)
return provider
diff --git a/api/detection.py b/api/detection.py
index 7df048a..9f3f33f 100644
--- a/api/detection.py
+++ b/api/detection.py
@@ -4,7 +4,7 @@ Detects quota checks, title generation, prefix detection, suggestion mode,
and filepath extraction requests to enable fast-path responses.
"""
-from providers.common.text import extract_text_from_content
+from core.anthropic import extract_text_from_content
from .models.anthropic import MessagesRequest
diff --git a/api/model_router.py b/api/model_router.py
index 8e755e1..789d02e 100644
--- a/api/model_router.py
+++ b/api/model_router.py
@@ -19,6 +19,18 @@ class ResolvedModel:
provider_model_ref: str
+@dataclass(frozen=True, slots=True)
+class RoutedMessagesRequest:
+ request: MessagesRequest
+ resolved: ResolvedModel
+
+
+@dataclass(frozen=True, slots=True)
+class RoutedTokenCountRequest:
+ request: TokenCountRequest
+ resolved: ResolvedModel
+
+
class ModelRouter:
"""Resolve incoming Claude model names to configured provider/model pairs."""
@@ -40,19 +52,21 @@ class ModelRouter:
provider_model_ref=provider_model_ref,
)
- def resolve_messages_request(self, request: MessagesRequest) -> MessagesRequest:
- """Return a routed copy of a MessagesRequest."""
- original_model = request.original_model or request.model
- resolved = self.resolve(original_model)
+ def resolve_messages_request(
+ self, request: MessagesRequest
+ ) -> RoutedMessagesRequest:
+ """Return an internal routed request context."""
+ resolved = self.resolve(request.model)
routed = request.model_copy(deep=True)
- routed.original_model = resolved.original_model
- routed.resolved_provider_model = resolved.provider_model_ref
routed.model = resolved.provider_model
- return routed
+ return RoutedMessagesRequest(request=routed, resolved=resolved)
def resolve_token_count_request(
self, request: TokenCountRequest
- ) -> TokenCountRequest:
- """Return a token-count request copy with provider model name applied."""
+ ) -> RoutedTokenCountRequest:
+ """Return an internal token-count request context."""
resolved = self.resolve(request.model)
- return request.model_copy(update={"model": resolved.provider_model}, deep=True)
+ routed = request.model_copy(
+ update={"model": resolved.provider_model}, deep=True
+ )
+ return RoutedTokenCountRequest(request=routed, resolved=resolved)
diff --git a/api/models/anthropic.py b/api/models/anthropic.py
index 9758295..a8bbef3 100644
--- a/api/models/anthropic.py
+++ b/api/models/anthropic.py
@@ -103,8 +103,6 @@ class MessagesRequest(BaseModel):
tool_choice: dict[str, Any] | None = None
thinking: ThinkingConfig | None = None
extra_body: dict[str, Any] | None = None
- original_model: str | None = None
- resolved_provider_model: str | None = None
class TokenCountRequest(BaseModel):
diff --git a/api/request_utils.py b/api/request_utils.py
index 2e8c154..1a6f149 100644
--- a/api/request_utils.py
+++ b/api/request_utils.py
@@ -1,101 +1,5 @@
-"""Request utility functions for API route handlers.
+"""Backward-compatible token counting import for API route handlers."""
-Contains token counting for API requests.
-"""
-
-import json
-
-import tiktoken
-from loguru import logger
-
-from providers.common import get_block_attr
-
-ENCODER = tiktoken.get_encoding("cl100k_base")
+from core.anthropic import get_token_count
__all__ = ["get_token_count"]
-
-
-def get_token_count(
- messages: list,
- system: str | list | None = None,
- tools: list | None = None,
-) -> int:
- """Estimate token count for a request.
-
- Uses tiktoken cl100k_base encoding to estimate token usage.
- Includes system prompt, messages, tools, and per-message overhead.
- """
- total_tokens = 0
-
- if system:
- if isinstance(system, str):
- total_tokens += len(ENCODER.encode(system))
- elif isinstance(system, list):
- for block in system:
- text = get_block_attr(block, "text", "")
- if text:
- total_tokens += len(ENCODER.encode(str(text)))
- total_tokens += 4 # System block formatting overhead
-
- for msg in messages:
- if isinstance(msg.content, str):
- total_tokens += len(ENCODER.encode(msg.content))
- elif isinstance(msg.content, list):
- for block in msg.content:
- b_type = get_block_attr(block, "type") or None
-
- if b_type == "text":
- text = get_block_attr(block, "text", "")
- total_tokens += len(ENCODER.encode(str(text)))
- elif b_type == "thinking":
- thinking = get_block_attr(block, "thinking", "")
- total_tokens += len(ENCODER.encode(str(thinking)))
- elif b_type == "tool_use":
- name = get_block_attr(block, "name", "")
- inp = get_block_attr(block, "input", {})
- block_id = get_block_attr(block, "id", "")
- total_tokens += len(ENCODER.encode(str(name)))
- total_tokens += len(ENCODER.encode(json.dumps(inp)))
- total_tokens += len(ENCODER.encode(str(block_id)))
- total_tokens += 15
- elif b_type == "image":
- source = get_block_attr(block, "source")
- if isinstance(source, dict):
- data = source.get("data") or source.get("base64") or ""
- if data:
- total_tokens += max(85, len(data) // 3000)
- else:
- total_tokens += 765
- else:
- total_tokens += 765
- elif b_type == "tool_result":
- content = get_block_attr(block, "content", "")
- tool_use_id = get_block_attr(block, "tool_use_id", "")
- if isinstance(content, str):
- total_tokens += len(ENCODER.encode(content))
- else:
- total_tokens += len(ENCODER.encode(json.dumps(content)))
- total_tokens += len(ENCODER.encode(str(tool_use_id)))
- total_tokens += 8
- else:
- logger.debug(
- "Unexpected block type %r, falling back to json/str encoding",
- b_type,
- )
- try:
- total_tokens += len(ENCODER.encode(json.dumps(block)))
-            except (TypeError, ValueError):
-                total_tokens += len(ENCODER.encode(str(block)))
-
- if tools:
- for tool in tools:
- tool_str = (
- tool.name + (tool.description or "") + json.dumps(tool.input_schema)
- )
- total_tokens += len(ENCODER.encode(tool_str))
-
- total_tokens += len(messages) * 4
- if tools:
- total_tokens += len(tools) * 5
-
- return max(1, total_tokens)
diff --git a/api/runtime.py b/api/runtime.py
new file mode 100644
index 0000000..bb6d8c8
--- /dev/null
+++ b/api/runtime.py
@@ -0,0 +1,198 @@
+"""Application runtime composition and lifecycle ownership."""
+
+from __future__ import annotations
+
+import asyncio
+import os
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Any
+
+from fastapi import FastAPI
+from loguru import logger
+
+from config.settings import Settings, get_settings
+
+from .dependencies import cleanup_provider
+
+_SHUTDOWN_TIMEOUT_S = 5.0
+
+
+async def best_effort(
+ name: str, awaitable: Any, timeout_s: float = _SHUTDOWN_TIMEOUT_S
+) -> None:
+ """Run a shutdown step with timeout; never raise to callers."""
+ try:
+ await asyncio.wait_for(awaitable, timeout=timeout_s)
+ except TimeoutError:
+ logger.warning(f"Shutdown step timed out: {name} ({timeout_s}s)")
+ except Exception as e:
+ logger.warning(f"Shutdown step failed: {name}: {type(e).__name__}: {e}")
+
+
+def warn_if_process_auth_token(settings: Settings) -> None:
+ """Warn when server auth was implicitly inherited from the shell."""
+ uses_process_token = getattr(settings, "uses_process_anthropic_auth_token", None)
+ if callable(uses_process_token) and uses_process_token():
+ logger.warning(
+ "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
+ "a configured .env file. The proxy will require that token. Add "
+ "ANTHROPIC_AUTH_TOKEN= to .env to disable proxy auth, or set the "
+ "same token in .env to make server auth explicit."
+ )
+
+
+@dataclass(slots=True)
+class AppRuntime:
+ """Own optional messaging, CLI, session, and provider runtime resources."""
+
+ app: FastAPI
+ settings: Settings
+ provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider
+ messaging_platform: Any = None
+ message_handler: Any = None
+ cli_manager: Any = None
+
+ @classmethod
+ def for_app(
+ cls,
+ app: FastAPI,
+ settings: Settings | None = None,
+ provider_cleanup: Callable[[], Awaitable[None]] = cleanup_provider,
+ ) -> AppRuntime:
+ return cls(
+ app=app,
+ settings=settings or get_settings(),
+ provider_cleanup=provider_cleanup,
+ )
+
+ async def startup(self) -> None:
+ logger.info("Starting Claude Code Proxy...")
+ warn_if_process_auth_token(self.settings)
+ await self._start_messaging_if_configured()
+ self._publish_state()
+
+ async def shutdown(self) -> None:
+ if self.message_handler and hasattr(self.message_handler, "session_store"):
+ try:
+ self.message_handler.session_store.flush_pending_save()
+ except Exception as e:
+ logger.warning(f"Session store flush on shutdown: {e}")
+
+ logger.info("Shutdown requested, cleaning up...")
+ if self.messaging_platform:
+ await best_effort("messaging_platform.stop", self.messaging_platform.stop())
+ if self.cli_manager:
+ await best_effort("cli_manager.stop_all", self.cli_manager.stop_all())
+ await best_effort("cleanup_provider", self.provider_cleanup())
+ await self._shutdown_limiter()
+ logger.info("Server shut down cleanly")
+
+ async def _start_messaging_if_configured(self) -> None:
+ try:
+ from messaging.platforms.factory import create_messaging_platform
+
+ self.messaging_platform = create_messaging_platform(
+ platform_type=self.settings.messaging_platform,
+ bot_token=self.settings.telegram_bot_token,
+ allowed_user_id=self.settings.allowed_telegram_user_id,
+ discord_bot_token=self.settings.discord_bot_token,
+ allowed_discord_channels=self.settings.allowed_discord_channels,
+ )
+
+ if self.messaging_platform:
+ await self._start_message_handler()
+
+ except ImportError as e:
+ logger.warning(f"Messaging module import error: {e}")
+ except Exception as e:
+ logger.error(f"Failed to start messaging platform: {e}")
+ import traceback
+
+ logger.error(traceback.format_exc())
+
+ async def _start_message_handler(self) -> None:
+ from cli.manager import CLISessionManager
+ from messaging.handler import ClaudeMessageHandler
+ from messaging.session import SessionStore
+
+ workspace = (
+ os.path.abspath(self.settings.allowed_dir)
+ if self.settings.allowed_dir
+ else os.getcwd()
+ )
+ os.makedirs(workspace, exist_ok=True)
+
+ data_path = os.path.abspath(self.settings.claude_workspace)
+ os.makedirs(data_path, exist_ok=True)
+
+ api_url = f"http://{self.settings.host}:{self.settings.port}/v1"
+ allowed_dirs = [workspace] if self.settings.allowed_dir else []
+ plans_dir_abs = os.path.abspath(
+ os.path.join(self.settings.claude_workspace, "plans")
+ )
+ plans_directory = os.path.relpath(plans_dir_abs, workspace)
+ self.cli_manager = CLISessionManager(
+ workspace_path=workspace,
+ api_url=api_url,
+ allowed_dirs=allowed_dirs,
+ plans_directory=plans_directory,
+ claude_bin=getattr(self.settings, "claude_cli_bin", "claude"),
+ )
+
+ session_store = SessionStore(
+ storage_path=os.path.join(data_path, "sessions.json")
+ )
+ self.message_handler = ClaudeMessageHandler(
+ platform=self.messaging_platform,
+ cli_manager=self.cli_manager,
+ session_store=session_store,
+ )
+ self._restore_tree_state(session_store)
+
+ self.messaging_platform.on_message(self.message_handler.handle_message)
+ await self.messaging_platform.start()
+ logger.info(
+ f"{self.messaging_platform.name} platform started with message handler"
+ )
+
+ def _restore_tree_state(self, session_store: Any) -> None:
+ saved_trees = session_store.get_all_trees()
+ if not saved_trees:
+ return
+
+ logger.info(f"Restoring {len(saved_trees)} conversation trees...")
+ from messaging.trees.queue_manager import TreeQueueManager
+
+ self.message_handler.replace_tree_queue(
+ TreeQueueManager.from_dict(
+ {
+ "trees": saved_trees,
+ "node_to_tree": session_store.get_node_mapping(),
+ },
+ queue_update_callback=self.message_handler.update_queue_positions,
+ node_started_callback=self.message_handler.mark_node_processing,
+ )
+ )
+ if self.message_handler.tree_queue.cleanup_stale_nodes() > 0:
+ tree_data = self.message_handler.tree_queue.to_dict()
+ session_store.sync_from_tree_data(
+ tree_data["trees"], tree_data["node_to_tree"]
+ )
+
+ def _publish_state(self) -> None:
+ self.app.state.messaging_platform = self.messaging_platform
+ self.app.state.message_handler = self.message_handler
+ self.app.state.cli_manager = self.cli_manager
+
+ async def _shutdown_limiter(self) -> None:
+ try:
+ from messaging.limiter import MessagingRateLimiter
+
+ await best_effort(
+ "MessagingRateLimiter.shutdown_instance",
+ MessagingRateLimiter.shutdown_instance(),
+ timeout_s=2.0,
+ )
+ except Exception:
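+            # Limiter may never have been imported/initialized.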
+ pass
diff --git a/api/services.py b/api/services.py
index 6138577..8705bf5 100644
--- a/api/services.py
+++ b/api/services.py
@@ -12,8 +12,8 @@ from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings
+from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider
-from providers.common import get_user_facing_error_message
from providers.exceptions import InvalidRequestError, ProviderError
from .model_router import ModelRouter
@@ -48,35 +48,32 @@ class ClaudeProxyService:
if not request_data.messages:
raise InvalidRequestError("messages cannot be empty")
- routed_request = self._model_router.resolve_messages_request(request_data)
+ routed = self._model_router.resolve_messages_request(request_data)
- optimized = try_optimizations(routed_request, self._settings)
+ optimized = try_optimizations(routed.request, self._settings)
if optimized is not None:
return optimized
logger.debug("No optimization matched, routing to provider")
- provider_type = (
- routed_request.resolved_provider_model or self._settings.model
- ).split("/", 1)[0]
- provider = self._provider_getter(provider_type)
+ provider = self._provider_getter(routed.resolved.provider_id)
request_id = f"req_{uuid.uuid4().hex[:12]}"
logger.info(
"API_REQUEST: request_id={} model={} messages={}",
request_id,
- routed_request.model,
- len(routed_request.messages),
+ routed.request.model,
+ len(routed.request.messages),
)
logger.debug(
- "FULL_PAYLOAD [{}]: {}", request_id, routed_request.model_dump()
+ "FULL_PAYLOAD [{}]: {}", request_id, routed.request.model_dump()
)
input_tokens = self._token_counter(
- routed_request.messages, routed_request.system, routed_request.tools
+ routed.request.messages, routed.request.system, routed.request.tools
)
return StreamingResponse(
provider.stream_response(
- routed_request,
+ routed.request,
input_tokens=input_tokens,
request_id=request_id,
),
@@ -102,17 +99,15 @@ class ClaudeProxyService:
request_id = f"req_{uuid.uuid4().hex[:12]}"
with logger.contextualize(request_id=request_id):
try:
- routed_request = self._model_router.resolve_token_count_request(
- request_data
- )
+ routed = self._model_router.resolve_token_count_request(request_data)
tokens = self._token_counter(
- routed_request.messages, routed_request.system, routed_request.tools
+ routed.request.messages, routed.request.system, routed.request.tools
)
logger.info(
"COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
request_id,
- routed_request.model,
- len(routed_request.messages),
+ routed.request.model,
+ len(routed.request.messages),
tokens,
)
return TokenCountResponse(input_tokens=tokens)
diff --git a/claude-pick b/claude-pick
index 792c6bd..409188b 100755
--- a/claude-pick
+++ b/claude-pick
@@ -13,11 +13,15 @@ OPENROUTER_MODELS_URL="https://openrouter.ai/api/v1/models"
DEFAULT_LM_STUDIO_BASE_URL="http://localhost:1234/v1"
DEFAULT_LLAMACPP_BASE_URL="http://localhost:8080/v1"
-if ! command -v python3 >/dev/null 2>&1; then
- echo "Error: python3 is required." >&2
+if ! command -v uv >/dev/null 2>&1; then
+ echo "Error: uv is required." >&2
exit 1
fi
+run_python() {
+ uv run python "$@"
+}
+
read_env_value() {
local key="$1"
[[ -f "$ENV_FILE" ]] || return 0
@@ -41,7 +45,7 @@ if ! command -v fzf >/dev/null 2>&1; then
fi
parse_models_from_json() {
- python3 -c '
+ run_python -c '
import json, sys
try:
payload = json.load(sys.stdin)
@@ -61,7 +65,7 @@ get_nvidia_models() {
exit 1
fi
- python3 -c '
+ run_python -c '
import json, sys
with open(sys.argv[1], "r", encoding="utf-8") as f:
payload = json.load(f)
diff --git a/cli/manager.py b/cli/manager.py
index 48d3849..b267633 100644
--- a/cli/manager.py
+++ b/cli/manager.py
@@ -28,6 +28,7 @@ class CLISessionManager:
api_url: str,
allowed_dirs: list[str] | None = None,
plans_directory: str | None = None,
+ claude_bin: str = "claude",
):
"""
Initialize the session manager.
@@ -42,6 +43,7 @@ class CLISessionManager:
self.api_url = api_url
self.allowed_dirs = allowed_dirs or []
self.plans_directory = plans_directory
+ self.claude_bin = claude_bin
self._sessions: dict[str, CLISession] = {}
self._pending_sessions: dict[str, CLISession] = {}
@@ -76,6 +78,7 @@ class CLISessionManager:
api_url=self.api_url,
allowed_dirs=self.allowed_dirs,
plans_directory=self.plans_directory,
+ claude_bin=self.claude_bin,
)
self._pending_sessions[temp_id] = new_session
logger.info(f"Created new session: {temp_id}")
diff --git a/cli/session.py b/cli/session.py
index 8847414..db18f6c 100644
--- a/cli/session.py
+++ b/cli/session.py
@@ -4,6 +4,7 @@ import asyncio
import json
import os
from collections.abc import AsyncGenerator
+from dataclasses import dataclass, field
from typing import Any
from loguru import logger
@@ -11,6 +12,17 @@ from loguru import logger
from .process_registry import register_pid, unregister_pid
+@dataclass(frozen=True, slots=True)
+class ClaudeCliConfig:
+ """Configuration for a managed Claude CLI subprocess."""
+
+ workspace_path: str
+ api_url: str
+ allowed_dirs: list[str] = field(default_factory=list)
+ plans_directory: str | None = None
+ claude_bin: str = "claude"
+
+
class CLISession:
"""Manages a single persistent Claude Code CLI subprocess."""
@@ -20,11 +32,20 @@ class CLISession:
api_url: str,
allowed_dirs: list[str] | None = None,
plans_directory: str | None = None,
+ claude_bin: str = "claude",
):
- self.workspace = os.path.normpath(os.path.abspath(workspace_path))
- self.api_url = api_url
- self.allowed_dirs = [os.path.normpath(d) for d in (allowed_dirs or [])]
- self.plans_directory = plans_directory
+ self.config = ClaudeCliConfig(
+ workspace_path=os.path.normpath(os.path.abspath(workspace_path)),
+ api_url=api_url,
+ allowed_dirs=[os.path.normpath(d) for d in (allowed_dirs or [])],
+ plans_directory=plans_directory,
+ claude_bin=claude_bin,
+ )
+ self.workspace = self.config.workspace_path
+ self.api_url = self.config.api_url
+ self.allowed_dirs = self.config.allowed_dirs
+ self.plans_directory = self.config.plans_directory
+ self.claude_bin = self.config.claude_bin
self.process: asyncio.subprocess.Process | None = None
self.current_session_id: str | None = None
self._is_busy = False
@@ -67,7 +88,7 @@ class CLISession:
# Build command
if session_id and not session_id.startswith("pending_"):
cmd = [
- "claude",
+ self.claude_bin,
"--resume",
session_id,
]
@@ -84,7 +105,7 @@ class CLISession:
logger.info(f"Resuming Claude session {session_id}")
else:
cmd = [
- "claude",
+ self.claude_bin,
"-p",
prompt,
"--output-format",
diff --git a/config/env.example b/config/env.example
index 37d64f5..3dd8593 100644
--- a/config/env.example
+++ b/config/env.example
@@ -16,6 +16,10 @@ DEEPSEEK_API_KEY=""
LM_STUDIO_BASE_URL="http://localhost:1234/v1"
+# Llama.cpp Config (local provider, no API key required)
+LLAMACPP_BASE_URL="http://localhost:8080/v1"
+
+
# All Claude model requests are mapped to these models; plain MODEL is the fallback
# Format: provider_type/model_name (the model name may itself contain "/")
# Valid providers: "nvidia_nim" | "open_router" | "deepseek" | "lmstudio" | "llamacpp"
@@ -25,7 +29,19 @@ MODEL_HAIKU=
MODEL="nvidia_nim/z-ai/glm4.7"
+# Thinking output
+# Global switch for provider reasoning requests and Claude thinking blocks.
+# Set false to suppress thinking across NIM, OpenRouter, LM Studio, and llama.cpp.
+ENABLE_THINKING=true
+
+
# Provider config
+# Per-provider proxy support: http and socks5, example: "http://username:password@host:port"
+NVIDIA_NIM_PROXY=""
+OPENROUTER_PROXY=""
+LMSTUDIO_PROXY=""
+LLAMACPP_PROXY=""
+
PROVIDER_RATE_LIMIT=40
PROVIDER_RATE_WINDOW=60
PROVIDER_MAX_CONCURRENCY=5
@@ -37,7 +53,11 @@ HTTP_WRITE_TIMEOUT=10
HTTP_CONNECT_TIMEOUT=2
-# Messaging Platform: "telegram" | "discord"
+# Optional server API key (Anthropic-style)
+ANTHROPIC_AUTH_TOKEN=
+
+
+# Messaging Platform: "telegram" | "discord" | "none"
MESSAGING_PLATFORM="discord"
MESSAGING_RATE_LIMIT=1
MESSAGING_RATE_WINDOW=1
@@ -70,6 +90,7 @@ ALLOWED_DISCORD_CHANNELS=""
# Agent Config
CLAUDE_WORKSPACE="./agent_workspace"
ALLOWED_DIR=""
+CLAUDE_CLI_BIN="claude"
FAST_PREFIX_DETECTION=true
ENABLE_NETWORK_PROBE_MOCK=true
ENABLE_TITLE_GENERATION_SKIP=true
diff --git a/config/settings.py b/config/settings.py
index b346f41..fbfb08a 100644
--- a/config/settings.py
+++ b/config/settings.py
@@ -99,7 +99,7 @@ class Settings(BaseSettings):
deepseek_api_key: str = Field(default="", validation_alias="DEEPSEEK_API_KEY")
# ==================== Messaging Platform Selection ====================
- # Valid: "telegram" | "discord"
+ # Valid: "telegram" | "discord" | "none"
messaging_platform: str = Field(
default="discord", validation_alias="MESSAGING_PLATFORM"
)
@@ -195,6 +195,7 @@ class Settings(BaseSettings):
)
claude_workspace: str = "./agent_workspace"
allowed_dir: str = ""
+ claude_cli_bin: str = Field(default="claude", validation_alias="CLAUDE_CLI_BIN")
# ==================== Server ====================
host: str = "0.0.0.0"
@@ -249,6 +250,15 @@ class Settings(BaseSettings):
)
return v
+ @field_validator("messaging_platform")
+ @classmethod
+ def validate_messaging_platform(cls, v: str) -> str:
+ if v not in ("telegram", "discord", "none"):
+ raise ValueError(
+ f"messaging_platform must be 'telegram', 'discord', or 'none', got {v!r}"
+ )
+ return v
+
@field_validator("model", "model_opus", "model_sonnet", "model_haiku")
@classmethod
def validate_model_format(cls, v: str | None) -> str | None:
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000..716557d
--- /dev/null
+++ b/core/__init__.py
@@ -0,0 +1 @@
+"""Neutral shared application core."""
diff --git a/core/anthropic/__init__.py b/core/anthropic/__init__.py
new file mode 100644
index 0000000..9af5987
--- /dev/null
+++ b/core/anthropic/__init__.py
@@ -0,0 +1,29 @@
+"""Anthropic protocol helpers shared across API, providers, and integrations."""
+
+from .content import extract_text_from_content, get_block_attr, get_block_type
+from .conversion import AnthropicToOpenAIConverter, build_base_request_body
+from .errors import append_request_id, get_user_facing_error_message
+from .sse import ContentBlockManager, SSEBuilder, map_stop_reason
+from .thinking import ContentChunk, ContentType, ThinkTagParser
+from .tokens import get_token_count
+from .tools import HeuristicToolParser
+from .utils import set_if_not_none
+
+__all__ = [
+ "AnthropicToOpenAIConverter",
+ "ContentBlockManager",
+ "ContentChunk",
+ "ContentType",
+ "HeuristicToolParser",
+ "SSEBuilder",
+ "ThinkTagParser",
+ "append_request_id",
+ "build_base_request_body",
+ "extract_text_from_content",
+ "get_block_attr",
+ "get_block_type",
+ "get_token_count",
+ "get_user_facing_error_message",
+ "map_stop_reason",
+ "set_if_not_none",
+]
diff --git a/core/anthropic/content.py b/core/anthropic/content.py
new file mode 100644
index 0000000..e16aaaf
--- /dev/null
+++ b/core/anthropic/content.py
@@ -0,0 +1,31 @@
+"""Content block helpers for Anthropic-compatible payloads."""
+
+from typing import Any
+
+
+def get_block_attr(block: Any, attr: str, default: Any = None) -> Any:
+ """Get an attribute from a Pydantic model, lightweight object, or dict."""
+ if hasattr(block, attr):
+ return getattr(block, attr)
+ if isinstance(block, dict):
+ return block.get(attr, default)
+ return default
+
+
+def get_block_type(block: Any) -> str | None:
+ """Return a content block type when present."""
+ return get_block_attr(block, "type")
+
+
+def extract_text_from_content(content: Any) -> str:
+ """Extract concatenated text from message content."""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: list[str] = []
+ for block in content:
+ text = get_block_attr(block, "text", "")
+ if isinstance(text, str) and text:
+ parts.append(text)
+ return "".join(parts)
+ return ""
diff --git a/providers/common/message_converter.py b/core/anthropic/conversion.py
similarity index 82%
rename from providers/common/message_converter.py
rename to core/anthropic/conversion.py
index 5e2184f..444df77 100644
--- a/providers/common/message_converter.py
+++ b/core/anthropic/conversion.py
@@ -3,23 +3,12 @@
import json
from typing import Any
-
-def get_block_attr(block: Any, attr: str, default: Any = None) -> Any:
- """Get attribute from object or dict."""
- if hasattr(block, attr):
- return getattr(block, attr)
- if isinstance(block, dict):
- return block.get(attr, default)
- return default
-
-
-def get_block_type(block: Any) -> str | None:
- """Get block type from object or dict."""
- return get_block_attr(block, "type")
+from .content import get_block_attr, get_block_type
+from .utils import set_if_not_none
class AnthropicToOpenAIConverter:
- """Converts Anthropic message format to OpenAI format."""
+ """Convert Anthropic message format to OpenAI-compatible format."""
@staticmethod
def convert_messages(
@@ -29,12 +18,6 @@ class AnthropicToOpenAIConverter:
include_reasoning_for_openrouter: bool = False,
include_reasoning_content: bool = False,
) -> list[dict[str, Any]]:
- """Convert a list of Anthropic messages to OpenAI format.
-
- When reasoning_content preservation is enabled, assistant messages with
- thinking blocks get reasoning_content added for provider multi-turn
- reasoning continuation.
- """
result = []
for msg in messages:
@@ -70,7 +53,6 @@ class AnthropicToOpenAIConverter:
include_reasoning_for_openrouter: bool = False,
include_reasoning_content: bool = False,
) -> list[dict[str, Any]]:
- """Convert assistant message blocks, preserving interleaved thinking+text order."""
content_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
@@ -106,9 +88,6 @@ class AnthropicToOpenAIConverter:
)
content_str = "\n\n".join(content_parts)
-
- # Ensure content is never an empty string for assistant messages
- # NIM (especially Mistral models) requires non-empty content if there are no tool calls
if not content_str and not tool_calls:
content_str = " "
@@ -125,7 +104,6 @@ class AnthropicToOpenAIConverter:
@staticmethod
def _convert_user_message(content: list[Any]) -> list[dict[str, Any]]:
- """Convert user message blocks (including tool results), preserving order."""
result: list[dict[str, Any]] = []
text_parts: list[str] = []
@@ -162,7 +140,6 @@ class AnthropicToOpenAIConverter:
@staticmethod
def convert_tools(tools: list[Any]) -> list[dict[str, Any]]:
- """Convert Anthropic tools to OpenAI format."""
return [
{
"type": "function",
@@ -177,7 +154,6 @@ class AnthropicToOpenAIConverter:
@staticmethod
def convert_tool_choice(tool_choice: Any) -> Any:
- """Convert Anthropic tool_choice to OpenAI-compatible format."""
if not isinstance(tool_choice, dict):
return tool_choice
@@ -197,10 +173,9 @@ class AnthropicToOpenAIConverter:
@staticmethod
def convert_system_prompt(system: Any) -> dict[str, str] | None:
- """Convert Anthropic system prompt to OpenAI format."""
if isinstance(system, str):
return {"role": "system", "content": system}
- elif isinstance(system, list):
+ if isinstance(system, list):
text_parts = [
get_block_attr(block, "text", "")
for block in system
@@ -219,14 +194,7 @@ def build_base_request_body(
include_reasoning_for_openrouter: bool = False,
include_reasoning_content: bool = False,
) -> dict[str, Any]:
- """Build the common parts of an OpenAI-format request body.
-
- Handles message conversion, system prompt, max_tokens, temperature,
- top_p, stop sequences, tools, and tool_choice. Provider-specific
- parameters (extra_body, penalties, NIM settings) are added by callers.
- """
- from providers.common.utils import set_if_not_none
-
+ """Build the common parts of an OpenAI-format request body."""
messages = AnthropicToOpenAIConverter.convert_messages(
request_data.messages,
include_thinking=include_thinking,
diff --git a/core/anthropic/errors.py b/core/anthropic/errors.py
new file mode 100644
index 0000000..38e3d12
--- /dev/null
+++ b/core/anthropic/errors.py
@@ -0,0 +1,53 @@
+"""User-facing error formatting shared by API, providers, and integrations."""
+
+import httpx
+import openai
+
+
+def get_user_facing_error_message(
+ e: Exception,
+ *,
+ read_timeout_s: float | None = None,
+) -> str:
+ """Return a readable, non-empty error message for users."""
+ message = str(e).strip()
+ if message:
+ return message
+
+ if isinstance(e, httpx.ReadTimeout):
+ if read_timeout_s is not None:
+ return f"Provider request timed out after {read_timeout_s:g}s."
+ return "Provider request timed out."
+ if isinstance(e, httpx.ConnectTimeout):
+ return "Could not connect to provider."
+ if isinstance(e, TimeoutError):
+ if read_timeout_s is not None:
+ return f"Provider request timed out after {read_timeout_s:g}s."
+ return "Request timed out."
+
+ name = type(e).__name__
+ status_code = getattr(e, "status_code", None)
+ if isinstance(e, openai.RateLimitError) or name == "RateLimitError":
+ return "Provider rate limit reached. Please retry shortly."
+ if isinstance(e, openai.AuthenticationError) or name == "AuthenticationError":
+ return "Provider authentication failed. Check API key."
+ if isinstance(e, openai.BadRequestError) or name == "InvalidRequestError":
+ return "Invalid request sent to provider."
+ if name == "OverloadedError":
+ return "Provider is currently overloaded. Please retry."
+ if name == "APIError":
+ if status_code in (502, 503, 504):
+ return "Provider is temporarily unavailable. Please retry."
+ return "Provider API request failed."
+ if name.endswith("ProviderError") or name == "ProviderError":
+ return "Provider request failed."
+
+ return "Provider request failed unexpectedly."
+
+
+def append_request_id(message: str, request_id: str | None) -> str:
+ """Append request_id suffix when available."""
+ base = message.strip() or "Provider request failed unexpectedly."
+ if request_id:
+ return f"{base} (request_id={request_id})"
+ return base
diff --git a/providers/common/sse_builder.py b/core/anthropic/sse.py
similarity index 80%
rename from providers/common/sse_builder.py
rename to core/anthropic/sse.py
index 9dd2183..7998839 100644
--- a/providers/common/sse_builder.py
+++ b/core/anthropic/sse.py
@@ -3,7 +3,6 @@
import json
from collections.abc import Iterator
from dataclasses import dataclass, field
-from typing import Any
from loguru import logger
@@ -15,7 +14,6 @@ except Exception:
ENCODER = None
-# Map OpenAI finish_reason to Anthropic stop_reason
STOP_REASON_MAP = {
"stop": "end_turn",
"length": "max_tokens",
@@ -35,7 +33,7 @@ def map_stop_reason(openai_reason: str | None) -> str:
class ToolCallState:
"""State for a single streaming tool call."""
- block_index: int # -1 if not yet allocated
+ block_index: int
tool_id: str
name: str
contents: list[str] = field(default_factory=list)
@@ -46,7 +44,7 @@ class ToolCallState:
@dataclass
class ContentBlockManager:
- """Manages content block indices and state."""
+ """Manage content block indices and state."""
next_index: int = 0
thinking_index: int = -1
@@ -56,17 +54,11 @@ class ContentBlockManager:
tool_states: dict[int, ToolCallState] = field(default_factory=dict)
def allocate_index(self) -> int:
- """Allocate and return the next block index."""
idx = self.next_index
self.next_index += 1
return idx
def register_tool_name(self, index: int, name: str) -> None:
- """Register or merge a streaming tool name fragment.
-
- Handles providers that stream names as fragments and those that
- resend the full name on every chunk.
- """
if index not in self.tool_states:
self.tool_states[index] = ToolCallState(
block_index=-1, tool_id="", name=name
@@ -80,11 +72,6 @@ class ContentBlockManager:
state.name = prev + name
def buffer_task_args(self, index: int, args: str) -> dict | None:
- """Buffer Task tool args and return parsed JSON when complete.
-
- Returns the parsed (and patched) args dict once the buffer forms
- valid JSON, or None if still accumulating.
- """
state = self.tool_states.get(index)
if state is None or state.task_args_emitted:
return None
@@ -103,7 +90,6 @@ class ContentBlockManager:
return args_json
def flush_task_arg_buffers(self) -> list[tuple[int, str]]:
- """Flush any remaining Task arg buffers. Returns (tool_index, json_str) pairs."""
results: list[tuple[int, str]] = []
for tool_index, state in list(self.tool_states.items()):
if not state.task_arg_buffer or state.task_args_emitted:
@@ -142,15 +128,12 @@ class SSEBuilder:
self._accumulated_text_parts: list[str] = []
self._accumulated_reasoning_parts: list[str] = []
- def _format_event(self, event_type: str, data: dict[str, Any]) -> str:
- """Format as SSE string."""
+ def _format_event(self, event_type: str, data: dict) -> str:
event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip())
return event_str
- # Message lifecycle events
def message_start(self) -> str:
- """Generate message_start event."""
usage = {"input_tokens": self.input_tokens, "output_tokens": 1}
return self._format_event(
"message_start",
@@ -170,7 +153,6 @@ class SSEBuilder:
)
def message_delta(self, stop_reason: str, output_tokens: int) -> str:
- """Generate message_delta event with stop reason."""
return self._format_event(
"message_delta",
{
@@ -184,13 +166,10 @@ class SSEBuilder:
)
def message_stop(self) -> str:
- """Generate message_stop event."""
return self._format_event("message_stop", {"type": "message_stop"})
- # Content block events
def content_block_start(self, index: int, block_type: str, **kwargs) -> str:
- """Generate content_block_start event."""
- content_block: dict[str, Any] = {"type": block_type}
+ content_block: dict = {"type": block_type}
if block_type == "thinking":
content_block["thinking"] = kwargs.get("thinking", "")
elif block_type == "text":
@@ -210,8 +189,7 @@ class SSEBuilder:
)
def content_block_delta(self, index: int, delta_type: str, content: str) -> str:
- """Generate content_block_delta event."""
- delta: dict[str, Any] = {"type": delta_type}
+ delta: dict = {"type": delta_type}
if delta_type == "thinking_delta":
delta["thinking"] = content
elif delta_type == "text_delta":
@@ -229,7 +207,6 @@ class SSEBuilder:
)
def content_block_stop(self, index: int) -> str:
- """Generate content_block_stop event."""
return self._format_event(
"content_block_stop",
{
@@ -238,45 +215,35 @@ class SSEBuilder:
},
)
- # High-level helpers for thinking blocks
def start_thinking_block(self) -> str:
- """Start a thinking block, allocating index."""
self.blocks.thinking_index = self.blocks.allocate_index()
self.blocks.thinking_started = True
return self.content_block_start(self.blocks.thinking_index, "thinking")
def emit_thinking_delta(self, content: str) -> str:
- """Emit thinking content delta."""
self._accumulated_reasoning_parts.append(content)
return self.content_block_delta(
self.blocks.thinking_index, "thinking_delta", content
)
def stop_thinking_block(self) -> str:
- """Stop the current thinking block."""
self.blocks.thinking_started = False
return self.content_block_stop(self.blocks.thinking_index)
- # High-level helpers for text blocks
def start_text_block(self) -> str:
- """Start a text block, allocating index."""
self.blocks.text_index = self.blocks.allocate_index()
self.blocks.text_started = True
return self.content_block_start(self.blocks.text_index, "text")
def emit_text_delta(self, content: str) -> str:
- """Emit text content delta."""
self._accumulated_text_parts.append(content)
return self.content_block_delta(self.blocks.text_index, "text_delta", content)
def stop_text_block(self) -> str:
- """Stop the current text block."""
self.blocks.text_started = False
return self.content_block_stop(self.blocks.text_index)
- # High-level helpers for tool blocks
def start_tool_block(self, tool_index: int, tool_id: str, name: str) -> str:
- """Start a tool_use block."""
block_idx = self.blocks.allocate_index()
if tool_index in self.blocks.tool_states:
state = self.blocks.tool_states[tool_index]
@@ -293,7 +260,6 @@ class SSEBuilder:
return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name)
def emit_tool_delta(self, tool_index: int, partial_json: str) -> str:
- """Emit tool input delta."""
state = self.blocks.tool_states[tool_index]
state.contents.append(partial_json)
return self.content_block_delta(
@@ -301,34 +267,28 @@ class SSEBuilder:
)
def stop_tool_block(self, tool_index: int) -> str:
- """Stop a tool block."""
block_idx = self.blocks.tool_states[tool_index].block_index
return self.content_block_stop(block_idx)
- # State management helpers
def ensure_thinking_block(self) -> Iterator[str]:
- """Ensure a thinking block is started, closing text block if needed."""
if self.blocks.text_started:
yield self.stop_text_block()
if not self.blocks.thinking_started:
yield self.start_thinking_block()
def ensure_text_block(self) -> Iterator[str]:
- """Ensure a text block is started, closing thinking block if needed."""
if self.blocks.thinking_started:
yield self.stop_thinking_block()
if not self.blocks.text_started:
yield self.start_text_block()
def close_content_blocks(self) -> Iterator[str]:
- """Close thinking and text blocks (before tool calls)."""
if self.blocks.thinking_started:
yield self.stop_thinking_block()
if self.blocks.text_started:
yield self.stop_text_block()
def close_all_blocks(self) -> Iterator[str]:
- """Close all open blocks (thinking, text, tools)."""
if self.blocks.thinking_started:
yield self.stop_thinking_block()
if self.blocks.text_started:
@@ -337,54 +297,45 @@ class SSEBuilder:
if state.started:
yield self.stop_tool_block(tool_index)
- # Error handling
def emit_error(self, error_message: str) -> Iterator[str]:
- """Emit an error as a text block."""
error_index = self.blocks.allocate_index()
yield self.content_block_start(error_index, "text")
yield self.content_block_delta(error_index, "text_delta", error_message)
yield self.content_block_stop(error_index)
- # Accumulated content access
@property
def accumulated_text(self) -> str:
- """Get accumulated text content."""
return "".join(self._accumulated_text_parts)
@property
def accumulated_reasoning(self) -> str:
- """Get accumulated reasoning content."""
return "".join(self._accumulated_reasoning_parts)
def estimate_output_tokens(self) -> int:
- """Estimate output tokens from accumulated content."""
accumulated_text = self.accumulated_text
accumulated_reasoning = self.accumulated_reasoning
if ENCODER:
text_tokens = len(ENCODER.encode(accumulated_text))
reasoning_tokens = len(ENCODER.encode(accumulated_reasoning))
- # Tool calls are harder to tokenize exactly without reconstruction, but we can approximate
- # by tokenizing the json dumps of tool contents
tool_tokens = 0
started_tool_count = 0
for state in self.blocks.tool_states.values():
tool_tokens += len(ENCODER.encode(state.name))
tool_tokens += len(ENCODER.encode("".join(state.contents)))
- tool_tokens += 15 # Control tokens overhead per tool
+ tool_tokens += 15
if state.started:
started_tool_count += 1
- # Per-block overhead (~4 tokens per content block)
block_count = (
(1 if accumulated_reasoning else 0)
+ (1 if accumulated_text else 0)
+ started_tool_count
)
- block_overhead = block_count * 4
-
- return text_tokens + reasoning_tokens + tool_tokens + block_overhead
+ return text_tokens + reasoning_tokens + tool_tokens + (block_count * 4)
text_tokens = len(accumulated_text) // 4
reasoning_tokens = len(accumulated_reasoning) // 4
- tool_tokens = sum(1 for s in self.blocks.tool_states.values() if s.started) * 50
+ tool_tokens = (
+ sum(1 for state in self.blocks.tool_states.values() if state.started) * 50
+ )
return text_tokens + reasoning_tokens + tool_tokens
diff --git a/providers/common/think_parser.py b/core/anthropic/thinking.py
similarity index 60%
rename from providers/common/think_parser.py
rename to core/anthropic/thinking.py
index 2deed99..fa793d1 100644
--- a/providers/common/think_parser.py
+++ b/core/anthropic/thinking.py
@@ -1,4 +1,4 @@
-"""Think tag parser for extracting reasoning content from responses."""
+"""Streaming parser for provider-emitted thinking tags."""
from collections.abc import Iterator
from dataclasses import dataclass
@@ -22,15 +22,13 @@ class ContentChunk:
class ThinkTagParser:
"""
-    Streaming parser for <think>...</think> tags.
+    Streaming parser for ``<think>...</think>`` tags.
Handles partial tags at chunk boundaries by buffering.
"""
    OPEN_TAG = "<think>"
    CLOSE_TAG = "</think>"
- OPEN_TAG_LEN = 7
- CLOSE_TAG_LEN = 8
def __init__(self):
self._buffer: str = ""
@@ -42,13 +40,7 @@ class ThinkTagParser:
return self._in_think_tag
def feed(self, content: str) -> Iterator[ContentChunk]:
- """
- Feed content and yield parsed chunks.
-
- Handles partial tags by buffering content near potential tag boundaries.
- Uses an iterative loop instead of mutual recursion to avoid stack overflow
- on inputs with many consecutive think tags.
- """
+ """Feed content and yield parsed chunks."""
self._buffer += content
while self._buffer:
@@ -61,7 +53,6 @@ class ThinkTagParser:
if chunk:
yield chunk
elif len(self._buffer) == prev_len:
- # No progress: waiting for more data
break
def _parse_outside_think(self) -> ContentChunk | None:
@@ -69,29 +60,23 @@ class ThinkTagParser:
think_start = self._buffer.find(self.OPEN_TAG)
orphan_close = self._buffer.find(self.CLOSE_TAG)
-        # Handle orphan </think> - strip it (Step Fun AI sends reasoning via
- # reasoning_content but may leak closing tags in content)
if orphan_close != -1 and (think_start == -1 or orphan_close < think_start):
pre_orphan = self._buffer[:orphan_close]
- self._buffer = self._buffer[orphan_close + self.CLOSE_TAG_LEN :]
+ self._buffer = self._buffer[orphan_close + len(self.CLOSE_TAG) :]
if pre_orphan:
return ContentChunk(ContentType.TEXT, pre_orphan)
- # Buffer shrunk; the feed() loop will continue parsing
return None
if think_start == -1:
-            # No <think> tag found - check for partial tag at end
-            # We buffer any trailing '<' and subsequent characters that could be part of <think> or </think>
last_bracket = self._buffer.rfind("<")
if last_bracket != -1:
potential_tag = self._buffer[last_bracket:]
tag_len = len(potential_tag)
-                # Check if could be partial <think> or </think>
if (
- tag_len < self.OPEN_TAG_LEN
+ tag_len < len(self.OPEN_TAG)
and self.OPEN_TAG.startswith(potential_tag)
) or (
- tag_len < self.CLOSE_TAG_LEN
+ tag_len < len(self.CLOSE_TAG)
and self.CLOSE_TAG.startswith(potential_tag)
):
emit = self._buffer[:last_bracket]
@@ -100,35 +85,28 @@ class ThinkTagParser:
return ContentChunk(ContentType.TEXT, emit)
return None
- # No partial tag found or it's irrelevant
emit = self._buffer
self._buffer = ""
if emit:
return ContentChunk(ContentType.TEXT, emit)
return None
- else:
- # Found tag
- pre_think = self._buffer[:think_start]
- self._buffer = self._buffer[think_start + self.OPEN_TAG_LEN :]
- self._in_think_tag = True
- if pre_think:
- return ContentChunk(ContentType.TEXT, pre_think)
-            # Buffer shrunk (consumed <think>); the feed() loop will continue
- # parsing inside the think tag on the next iteration
- return None
+
+ pre_think = self._buffer[:think_start]
+ self._buffer = self._buffer[think_start + len(self.OPEN_TAG) :]
+ self._in_think_tag = True
+ if pre_think:
+ return ContentChunk(ContentType.TEXT, pre_think)
+ return None
def _parse_inside_think(self) -> ContentChunk | None:
"""Parse content inside think tags."""
think_end = self._buffer.find(self.CLOSE_TAG)
if think_end == -1:
-            # No closing tag - check for partial </think> at end
last_bracket = self._buffer.rfind("<")
- if (
- last_bracket != -1
- and len(self._buffer) - last_bracket < self.CLOSE_TAG_LEN
+ if last_bracket != -1 and len(self._buffer) - last_bracket < len(
+ self.CLOSE_TAG
):
-                # Check if the partial string could be the start of </think>
potential_tag = self._buffer[last_bracket:]
if self.CLOSE_TAG.startswith(potential_tag):
emit = self._buffer[:last_bracket]
@@ -142,16 +120,13 @@ class ThinkTagParser:
if emit:
return ContentChunk(ContentType.THINKING, emit)
return None
- else:
- # Found </think> tag
- thinking_content = self._buffer[:think_end]
- self._buffer = self._buffer[think_end + self.CLOSE_TAG_LEN :]
- self._in_think_tag = False
- if thinking_content:
- return ContentChunk(ContentType.THINKING, thinking_content)
- # Buffer shrunk (consumed </think>); the feed() loop will continue
- # parsing outside the think tag on the next iteration
- return None
+
+ thinking_content = self._buffer[:think_end]
+ self._buffer = self._buffer[think_end + len(self.CLOSE_TAG) :]
+ self._in_think_tag = False
+ if thinking_content:
+ return ContentChunk(ContentType.THINKING, thinking_content)
+ return None
def flush(self) -> ContentChunk | None:
"""Flush any remaining buffered content."""
diff --git a/core/anthropic/tokens.py b/core/anthropic/tokens.py
new file mode 100644
index 0000000..dfa1115
--- /dev/null
+++ b/core/anthropic/tokens.py
@@ -0,0 +1,92 @@
+"""Token estimation for Anthropic-compatible requests."""
+
+import json
+
+import tiktoken
+from loguru import logger
+
+from .content import get_block_attr
+
+ENCODER = tiktoken.get_encoding("cl100k_base")
+
+
+def get_token_count(
+ messages: list,
+ system: str | list | None = None,
+ tools: list | None = None,
+) -> int:
+ """Estimate token count for a request."""
+ total_tokens = 0
+
+ if system:
+ if isinstance(system, str):
+ total_tokens += len(ENCODER.encode(system))
+ elif isinstance(system, list):
+ for block in system:
+ text = get_block_attr(block, "text", "")
+ if text:
+ total_tokens += len(ENCODER.encode(str(text)))
+ total_tokens += 4
+
+ for msg in messages:
+ if isinstance(msg.content, str):
+ total_tokens += len(ENCODER.encode(msg.content))
+ elif isinstance(msg.content, list):
+ for block in msg.content:
+ b_type = get_block_attr(block, "type") or None
+
+ if b_type == "text":
+ text = get_block_attr(block, "text", "")
+ total_tokens += len(ENCODER.encode(str(text)))
+ elif b_type == "thinking":
+ thinking = get_block_attr(block, "thinking", "")
+ total_tokens += len(ENCODER.encode(str(thinking)))
+ elif b_type == "tool_use":
+ name = get_block_attr(block, "name", "")
+ inp = get_block_attr(block, "input", {})
+ block_id = get_block_attr(block, "id", "")
+ total_tokens += len(ENCODER.encode(str(name)))
+ total_tokens += len(ENCODER.encode(json.dumps(inp)))
+ total_tokens += len(ENCODER.encode(str(block_id)))
+ total_tokens += 15
+ elif b_type == "image":
+ source = get_block_attr(block, "source")
+ if isinstance(source, dict):
+ data = source.get("data") or source.get("base64") or ""
+ if data:
+ total_tokens += max(85, len(data) // 3000)
+ else:
+ total_tokens += 765
+ else:
+ total_tokens += 765
+ elif b_type == "tool_result":
+ content = get_block_attr(block, "content", "")
+ tool_use_id = get_block_attr(block, "tool_use_id", "")
+ if isinstance(content, str):
+ total_tokens += len(ENCODER.encode(content))
+ else:
+ total_tokens += len(ENCODER.encode(json.dumps(content)))
+ total_tokens += len(ENCODER.encode(str(tool_use_id)))
+ total_tokens += 8
+ else:
+ logger.debug(
+ "Unexpected block type {!r}, falling back to json/str encoding",
+ b_type,
+ )
+ try:
+ total_tokens += len(ENCODER.encode(json.dumps(block)))
+ except (TypeError, ValueError):
+ total_tokens += len(ENCODER.encode(str(block)))
+
+ if tools:
+ for tool in tools:
+ tool_str = (
+ tool.name + (tool.description or "") + json.dumps(tool.input_schema)
+ )
+ total_tokens += len(ENCODER.encode(tool_str))
+
+ total_tokens += len(messages) * 4
+ if tools:
+ total_tokens += len(tools) * 5
+
+ return max(1, total_tokens)
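For orientation, a hedged sketch of calling the estimator. The proxy passes pydantic message models; `SimpleNamespace` stands in here because only `.content` is read:

```python
from types import SimpleNamespace

from core.anthropic.tokens import get_token_count

messages = [
    SimpleNamespace(content="What does this repo do?"),
    SimpleNamespace(
        content=[{"type": "text", "text": "It proxies Anthropic requests."}]
    ),
]

# Rough cl100k_base-based estimate; clamped to at least 1.
print(get_token_count(messages, system="You are a concise assistant."))
```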
diff --git a/providers/common/heuristic_tool_parser.py b/core/anthropic/tools.py
similarity index 59%
rename from providers/common/heuristic_tool_parser.py
rename to core/anthropic/tools.py
index e338db5..0751d9d 100644
--- a/providers/common/heuristic_tool_parser.py
+++ b/core/anthropic/tools.py
@@ -1,3 +1,5 @@
+"""Heuristic parser for text-emitted tool calls."""
+
import re
import uuid
from enum import Enum
@@ -5,9 +7,6 @@ from typing import Any
from loguru import logger
-# Some OpenAI-compatible backends/models occasionally leak internal sentinel tokens
-# into `delta.content` (e.g. "<|tool_call_end|>"). These should never be shown to
-# end users, and they can disrupt downstream parsing if left in place.
_CONTROL_TOKEN_RE = re.compile(r"<\|[^|>]{1,80}\|>")
_CONTROL_TOKEN_START = "<|"
_CONTROL_TOKEN_END = "|>"
@@ -21,14 +20,13 @@ class ParserState(Enum):
class HeuristicToolParser:
"""
- Stateful parser that detects raw text tool calls in the format:
- ● <function=name><parameter=key>value</parameter>...
+ Stateful parser for raw text tool calls.
- This is used as a fallback for models that emit tool calls as text
- instead of using the structured API.
+ Some OpenAI-compatible models emit tool calls as text rather than structured
+ chunks. This parser converts the common ``● <function=...>`` form into
+ Anthropic-style ``tool_use`` blocks.
"""
- # Class-level compiled patterns (compiled once, not per instance)
_FUNC_START_PATTERN = re.compile(r"●\s*<function=([^>]+)>")
_PARAM_PATTERN = re.compile(
r"]+)>(.*?)(?:|$)", re.DOTALL
@@ -42,17 +40,9 @@ class HeuristicToolParser:
self._current_parameters = {}
def _strip_control_tokens(self, text: str) -> str:
- # Remove complete sentinel tokens. If a token is split across chunks it
- # will be removed once the buffer contains the full token.
return _CONTROL_TOKEN_RE.sub("", text)
def _split_incomplete_control_token_tail(self) -> str:
- """
- If the buffer ends with an incomplete "<|...|>" sentinel token, keep that
- fragment in the buffer and return the safe-to-emit prefix.
-
- This prevents leaking raw sentinel fragments to the user when streaming.
- """
start = self._buffer.rfind(_CONTROL_TOKEN_START)
if start == -1:
return ""
@@ -65,13 +55,7 @@ class HeuristicToolParser:
return prefix
def feed(self, text: str) -> tuple[str, list[dict[str, Any]]]:
- """
- Feed text into the parser.
- Returns a tuple of (filtered_text, detected_tool_calls).
-
- filtered_text: Text that should be passed through as normal message content.
- detected_tools: List of Anthropic-format tool_use blocks.
- """
+ """Feed text and return safe text plus detected tool calls."""
self._buffer += text
self._buffer = self._strip_control_tokens(self._buffer)
detected_tools = []
@@ -79,15 +63,12 @@ class HeuristicToolParser:
while True:
if self._state == ParserState.TEXT:
- # Look for the trigger character
if "●" in self._buffer:
idx = self._buffer.find("●")
filtered_output_parts.append(self._buffer[:idx])
self._buffer = self._buffer[idx:]
self._state = ParserState.MATCHING_FUNCTION
else:
- # Avoid emitting an incomplete "<|...|>" sentinel fragment if the
- # token got split across streaming chunks.
safe_prefix = self._split_incomplete_control_token_tail()
if safe_prefix:
filtered_output_parts.append(safe_prefix)
@@ -98,46 +79,30 @@ class HeuristicToolParser:
break
if self._state == ParserState.MATCHING_FUNCTION:
- # We need enough buffer to match the function tag
- # e.g. "● "
match = self._FUNC_START_PATTERN.search(self._buffer)
if match:
self._current_function_name = match.group(1).strip()
self._current_tool_id = f"toolu_heuristic_{uuid.uuid4().hex[:8]}"
self._current_parameters = {}
-
- # Consume the function start from buffer
self._buffer = self._buffer[match.end() :]
self._state = ParserState.PARSING_PARAMETERS
logger.debug(
"Heuristic bypass: Detected start of tool call '{}'",
self._current_function_name,
)
+ elif len(self._buffer) > 100:
+ filtered_output_parts.append(self._buffer[0])
+ self._buffer = self._buffer[1:]
+ self._state = ParserState.TEXT
else:
- # If we have "●" but not the full tag yet, wait for more data
- # Unless the buffer has grown too large without a match
- if len(self._buffer) > 100:
- # Probably not a tool call, treat as text
- filtered_output_parts.append(self._buffer[0])
- self._buffer = self._buffer[1:]
- self._state = ParserState.TEXT
- else:
- break
+ break
if self._state == ParserState.PARSING_PARAMETERS:
- # Look for parameters. We look for </parameter> to know a param is complete.
- # Or wait for another <parameter= to start.
param_match = self._PARAM_PATTERN.search(self._buffer)
if param_match:
if "</parameter>" in param_match.group(0):
- # Detect any content before the parameter match and preserve it
pre_match_text = self._buffer[: param_match.start()]
if pre_match_text:
filtered_output_parts.append(pre_match_text)
@@ -149,32 +114,19 @@ class HeuristicToolParser:
else:
break
- # Heuristic for completion:
- # 1. We have at least one param and we see a character that doesn't belong to the format
- # 2. Significant pause (not handled here, handled by caller via flush if needed)
- # 3. Another ● character (start of NEXT tool call)
-
if "●" in self._buffer:
- # Next tool call starting or something else, close current
- # But first, capture any text before the ●
idx = self._buffer.find("●")
if idx > 0:
filtered_output_parts.append(self._buffer[:idx])
self._buffer = self._buffer[idx:]
finished_tool_call = True
elif len(self._buffer) > 0 and not self._buffer.strip().startswith("<"):
- # We have text that doesn't look like a tag, and we already parsed some or are in param state
- # Let's see if we have trailing param starts
if " list[dict[str, Any]]:
- """
- Flush any remaining tool calls in the buffer.
- """
+ """Flush any remaining tool call in the buffer."""
self._buffer = self._strip_control_tokens(self._buffer)
detected_tools = []
if self._state == ParserState.PARSING_PARAMETERS:
- # Try to extract any partial parameters remaining in buffer
- # Even without </parameter>
partial_matches = re.finditer(
r"]+)>(.*)$", self._buffer, re.DOTALL
)
- for m in partial_matches:
- key = m.group(1).strip()
- val = m.group(2).strip()
+ for match in partial_matches:
+ key = match.group(1).strip()
+ val = match.group(2).strip()
self._current_parameters[key] = val
detected_tools.append(
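Taken together, the streaming fallback behaves roughly like this. A hedged sketch (the tool name is invented; whether the block arrives from `feed()` or only at `flush()` depends on trailing content):

```python
from core.anthropic import HeuristicToolParser

parser = HeuristicToolParser()
text, tools = parser.feed(
    "Let me check that file.\n"
    "● <function=read_file><parameter=path>api/app.py</parameter>"
)
tools += parser.flush()

print(text)   # "Let me check that file.\n"
print(tools)  # one Anthropic-style tool_use block with input {"path": "api/app.py"}
```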
diff --git a/providers/common/utils.py b/core/anthropic/utils.py
similarity index 55%
rename from providers/common/utils.py
rename to core/anthropic/utils.py
index 7f290ea..84d33de 100644
--- a/providers/common/utils.py
+++ b/core/anthropic/utils.py
@@ -1,9 +1,9 @@
-"""Shared utility helpers for provider request builders."""
+"""Small shared protocol utility helpers."""
from typing import Any
def set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
- """Set body[key] = value only when value is not None."""
+ """Set ``body[key]`` only when value is not None."""
if value is not None:
body[key] = value
diff --git a/messaging/command_dispatcher.py b/messaging/command_dispatcher.py
new file mode 100644
index 0000000..b1188b9
--- /dev/null
+++ b/messaging/command_dispatcher.py
@@ -0,0 +1,48 @@
+"""Command parsing and dispatch for messaging handlers."""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+from typing import Protocol
+
+from .models import IncomingMessage
+
+CommandHandler = Callable[[IncomingMessage], Awaitable[None]]
+
+
+class SupportsMessagingCommands(Protocol):
+ async def _handle_clear_command(self, incoming: IncomingMessage) -> None: ...
+
+ async def _handle_stop_command(self, incoming: IncomingMessage) -> None: ...
+
+ async def _handle_stats_command(self, incoming: IncomingMessage) -> None: ...
+
+
+def parse_command_base(text: str | None) -> str:
+ """Return the slash command without bot mention suffix."""
+ parts = (text or "").strip().split()
+ cmd = parts[0] if parts else ""
+ return cmd.split("@", 1)[0] if cmd else ""
+
+
+def message_kind_for_command(command_base: str) -> str:
+ """Return the persistence kind for an incoming message."""
+ return "command" if command_base.startswith("/") else "content"
+
+
+async def dispatch_command(
+ handler: SupportsMessagingCommands,
+ incoming: IncomingMessage,
+ command_base: str,
+) -> bool:
+ """Dispatch a known command and return whether it was handled."""
+ commands: dict[str, CommandHandler] = {
+ "/clear": handler._handle_clear_command,
+ "/stop": handler._handle_stop_command,
+ "/stats": handler._handle_stats_command,
+ }
+ command = commands.get(command_base)
+ if command is None:
+ return False
+ await command(incoming)
+ return True
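The parsing helpers are deliberately tiny and platform-agnostic; their behavior follows directly from the definitions above:

```python
from messaging.command_dispatcher import message_kind_for_command, parse_command_base

assert parse_command_base("/clear@my_bot please") == "/clear"
assert parse_command_base("  /stats  ") == "/stats"
assert parse_command_base(None) == ""

assert message_kind_for_command("/clear") == "command"
assert message_kind_for_command("hello") == "content"
```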
diff --git a/messaging/handler.py b/messaging/handler.py
index be617a8..52290e0 100644
--- a/messaging/handler.py
+++ b/messaging/handler.py
@@ -12,8 +12,13 @@ import time
from loguru import logger
-from providers.common import get_user_facing_error_message
+from core.anthropic import get_user_facing_error_message
+from .command_dispatcher import (
+ dispatch_command,
+ message_kind_for_command,
+ parse_command_base,
+)
from .commands import (
handle_clear_command,
handle_stats_command,
@@ -155,36 +160,23 @@ class ClaudeMessageHandler:
async def _handle_message_impl(self, incoming: IncomingMessage) -> None:
"""Implementation of handle_message with context bound."""
- # Check for commands
- parts = (incoming.text or "").strip().split()
- cmd = parts[0] if parts else ""
- cmd_base = cmd.split("@", 1)[0] if cmd else ""
+ cmd_base = parse_command_base(incoming.text)
# Record incoming message ID for best-effort UI clearing (/clear), even if
# we later ignore this message (status/command/etc).
try:
if incoming.message_id is not None:
- kind = "command" if cmd_base.startswith("/") else "content"
self.session_store.record_message_id(
incoming.platform,
incoming.chat_id,
str(incoming.message_id),
direction="in",
- kind=kind,
+ kind=message_kind_for_command(cmd_base),
)
except Exception as e:
logger.debug(f"Failed to record incoming message_id: {e}")
- if cmd_base == "/clear":
- await self._handle_clear_command(incoming)
- return
-
- if cmd_base == "/stop":
- await self._handle_stop_command(incoming)
- return
-
- if cmd_base == "/stats":
- await self._handle_stats_command(incoming)
+ if await dispatch_command(self, incoming, cmd_base):
return
# Filter out status messages (our own messages)
diff --git a/messaging/platforms/discord.py b/messaging/platforms/discord.py
index 97661c9..3222c20 100644
--- a/messaging/platforms/discord.py
+++ b/messaging/platforms/discord.py
@@ -14,11 +14,11 @@ from typing import Any, cast
from loguru import logger
-from providers.common import get_user_facing_error_message
+from core.anthropic import get_user_facing_error_message
from ..models import IncomingMessage
from ..rendering.discord_markdown import format_status_discord
-from ..voice import PendingVoiceRegistry, VoiceTranscriptionService
+from ..voice_pipeline import VoiceNotePipeline
from .base import MessagingPlatform
AUDIO_EXTENSIONS = (".ogg", ".mp4", ".mp3", ".wav", ".m4a")
@@ -116,8 +116,7 @@ class DiscordPlatform(MessagingPlatform):
self._connected = False
self._limiter: Any | None = None
self._start_task: asyncio.Task | None = None
- self._pending_voice = PendingVoiceRegistry()
- self._voice_transcription = VoiceTranscriptionService()
+ self._voice_pipeline = VoiceNotePipeline()
async def _handle_client_message(self, message: Any) -> None:
"""Adapter entry point used by the internal discord client."""
@@ -127,17 +126,19 @@ class DiscordPlatform(MessagingPlatform):
self, chat_id: str, voice_msg_id: str, status_msg_id: str
) -> None:
"""Register a voice note as pending transcription."""
- await self._pending_voice.register(chat_id, voice_msg_id, status_msg_id)
+ await self._voice_pipeline.register_pending(
+ chat_id, voice_msg_id, status_msg_id
+ )
async def cancel_pending_voice(
self, chat_id: str, reply_id: str
) -> tuple[str, str] | None:
"""Cancel a pending voice transcription. Returns (voice_msg_id, status_msg_id) if found."""
- return await self._pending_voice.cancel(chat_id, reply_id)
+ return await self._voice_pipeline.cancel_pending(chat_id, reply_id)
async def _is_voice_still_pending(self, chat_id: str, voice_msg_id: str) -> bool:
"""Check if a voice note is still pending (not cancelled)."""
- return await self._pending_voice.is_pending(chat_id, voice_msg_id)
+ return await self._voice_pipeline.is_pending(chat_id, voice_msg_id)
def _get_audio_attachment(self, message: Any) -> Any | None:
"""Return first audio attachment, or None."""
@@ -198,7 +199,7 @@ class DiscordPlatform(MessagingPlatform):
try:
await attachment.save(str(tmp_path))
- transcribed = await self._voice_transcription.transcribe(
+ transcribed = await self._voice_pipeline.transcribe(
tmp_path,
ct,
whisper_model=settings.whisper_model,
@@ -209,7 +210,7 @@ class DiscordPlatform(MessagingPlatform):
await self.queue_delete_message(channel_id, str(status_msg_id))
return True
- await self._pending_voice.complete(
+ await self._voice_pipeline.complete(
channel_id, message_id, str(status_msg_id)
)
diff --git a/messaging/platforms/factory.py b/messaging/platforms/factory.py
index f260c3b..4b4bf27 100644
--- a/messaging/platforms/factory.py
+++ b/messaging/platforms/factory.py
@@ -24,6 +24,10 @@ def create_messaging_platform(
Returns:
Configured MessagingPlatform instance, or None if not configured.
"""
+ if platform_type == "none":
+ logger.info("Messaging platform disabled by configuration")
+ return None
+
if platform_type == "telegram":
bot_token = kwargs.get("bot_token")
if not bot_token:
diff --git a/messaging/platforms/telegram.py b/messaging/platforms/telegram.py
index 9eadc4c..602b622 100644
--- a/messaging/platforms/telegram.py
+++ b/messaging/platforms/telegram.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any
from loguru import logger
-from providers.common import get_user_facing_error_message
+from core.anthropic import get_user_facing_error_message
if TYPE_CHECKING:
from telegram import Update
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from ..models import IncomingMessage
from ..rendering.telegram_markdown import escape_md_v2, format_status
-from ..voice import PendingVoiceRegistry, VoiceTranscriptionService
+from ..voice_pipeline import VoiceNotePipeline
from .base import MessagingPlatform
# Optional import - python-telegram-bot may not be installed
@@ -83,24 +83,25 @@ class TelegramPlatform(MessagingPlatform):
self._connected = False
self._limiter: Any | None = None # Will be MessagingRateLimiter
# Pending voice transcriptions: (chat_id, msg_id) -> (voice_msg_id, status_msg_id)
- self._pending_voice = PendingVoiceRegistry()
- self._voice_transcription = VoiceTranscriptionService()
+ self._voice_pipeline = VoiceNotePipeline()
async def _register_pending_voice(
self, chat_id: str, voice_msg_id: str, status_msg_id: str
) -> None:
"""Register a voice note as pending transcription (for /clear reply during transcription)."""
- await self._pending_voice.register(chat_id, voice_msg_id, status_msg_id)
+ await self._voice_pipeline.register_pending(
+ chat_id, voice_msg_id, status_msg_id
+ )
async def cancel_pending_voice(
self, chat_id: str, reply_id: str
) -> tuple[str, str] | None:
"""Cancel a pending voice transcription. Returns (voice_msg_id, status_msg_id) if found."""
- return await self._pending_voice.cancel(chat_id, reply_id)
+ return await self._voice_pipeline.cancel_pending(chat_id, reply_id)
async def _is_voice_still_pending(self, chat_id: str, voice_msg_id: str) -> bool:
"""Check if a voice note is still pending (not cancelled)."""
- return await self._pending_voice.is_pending(chat_id, voice_msg_id)
+ return await self._voice_pipeline.is_pending(chat_id, voice_msg_id)
async def start(self) -> None:
"""Initialize and connect to Telegram."""
@@ -597,7 +598,7 @@ class TelegramPlatform(MessagingPlatform):
tg_file = await context.bot.get_file(voice.file_id)
await tg_file.download_to_drive(custom_path=str(tmp_path))
- transcribed = await self._voice_transcription.transcribe(
+ transcribed = await self._voice_pipeline.transcribe(
tmp_path,
voice.mime_type or "audio/ogg",
whisper_model=settings.whisper_model,
@@ -608,7 +609,7 @@ class TelegramPlatform(MessagingPlatform):
await self.queue_delete_message(chat_id, str(status_msg_id))
return
- await self._pending_voice.complete(chat_id, message_id, str(status_msg_id))
+ await self._voice_pipeline.complete(chat_id, message_id, str(status_msg_id))
incoming = IncomingMessage(
text=transcribed,
diff --git a/messaging/trees/processor.py b/messaging/trees/processor.py
index d9255b4..5fdcca6 100644
--- a/messaging/trees/processor.py
+++ b/messaging/trees/processor.py
@@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
from loguru import logger
-from providers.common import get_user_facing_error_message
+from core.anthropic import get_user_facing_error_message
from .data import MessageNode, MessageState, MessageTree
diff --git a/messaging/trees/queue_manager.py b/messaging/trees/queue_manager.py
index 14b0f10..fbf7029 100644
--- a/messaging/trees/queue_manager.py
+++ b/messaging/trees/queue_manager.py
@@ -14,14 +14,6 @@ from .data import MessageNode, MessageState, MessageTree
from .processor import TreeQueueProcessor
from .repository import TreeRepository
-# Backward compatibility: re-export moved classes
-__all__ = [
- "MessageNode",
- "MessageState",
- "MessageTree",
- "TreeQueueManager",
-]
-
class TreeQueueManager:
"""
diff --git a/messaging/voice_pipeline.py b/messaging/voice_pipeline.py
new file mode 100644
index 0000000..4c4fc85
--- /dev/null
+++ b/messaging/voice_pipeline.py
@@ -0,0 +1,53 @@
+"""Platform-neutral voice note pipeline."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from .voice import PendingVoiceRegistry, VoiceTranscriptionService
+
+
+class VoiceNotePipeline:
+ """Coordinate pending voice state and transcription backend execution."""
+
+ def __init__(
+ self,
+ *,
+ registry: PendingVoiceRegistry | None = None,
+ transcription: VoiceTranscriptionService | None = None,
+ ) -> None:
+ self._registry = registry or PendingVoiceRegistry()
+ self._transcription = transcription or VoiceTranscriptionService()
+
+ async def register_pending(
+ self, chat_id: str, voice_msg_id: str, status_msg_id: str
+ ) -> None:
+ await self._registry.register(chat_id, voice_msg_id, status_msg_id)
+
+ async def cancel_pending(
+ self, chat_id: str, reply_id: str
+ ) -> tuple[str, str] | None:
+ return await self._registry.cancel(chat_id, reply_id)
+
+ async def is_pending(self, chat_id: str, voice_msg_id: str) -> bool:
+ return await self._registry.is_pending(chat_id, voice_msg_id)
+
+ async def complete(
+ self, chat_id: str, voice_msg_id: str, status_msg_id: str
+ ) -> None:
+ await self._registry.complete(chat_id, voice_msg_id, status_msg_id)
+
+ async def transcribe(
+ self,
+ file_path: Path,
+ mime_type: str,
+ *,
+ whisper_model: str,
+ whisper_device: str,
+ ) -> str:
+ return await self._transcription.transcribe(
+ file_path,
+ mime_type,
+ whisper_model=whisper_model,
+ whisper_device=whisper_device,
+ )
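Because both collaborators are injectable, platforms can be tested without Whisper installed. A sketch with a hypothetical duck-typed fake (the `VoiceTranscriptionService` annotation is not enforced at runtime):

```python
import asyncio
from pathlib import Path

from messaging.voice_pipeline import VoiceNotePipeline


class FakeTranscription:
    # Hypothetical stand-in exposing only the call the pipeline makes.
    async def transcribe(self, file_path, mime_type, *, whisper_model, whisper_device):
        return "transcript from fake backend"


async def main() -> None:
    pipeline = VoiceNotePipeline(transcription=FakeTranscription())
    text = await pipeline.transcribe(
        Path("note.ogg"), "audio/ogg", whisper_model="base", whisper_device="cpu"
    )
    assert text == "transcript from fake backend"


asyncio.run(main())
```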
diff --git a/providers/anthropic_messages.py b/providers/anthropic_messages.py
index acb0db5..a884a03 100644
--- a/providers/anthropic_messages.py
+++ b/providers/anthropic_messages.py
@@ -9,8 +9,9 @@ from typing import Any, Literal
import httpx
from loguru import logger
+from core.anthropic import get_user_facing_error_message
from providers.base import BaseProvider, ProviderConfig
-from providers.common import get_user_facing_error_message, map_error
+from providers.error_mapping import map_error
from providers.rate_limit import GlobalRateLimiter
ANTHROPIC_DEFAULT_MAX_TOKENS = 81920
diff --git a/providers/common/__init__.py b/providers/common/__init__.py
deleted file mode 100644
index 4e6c128..0000000
--- a/providers/common/__init__.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Shared provider utilities used by NIM, OpenRouter, and LM Studio."""
-
-from .error_mapping import append_request_id, get_user_facing_error_message, map_error
-from .heuristic_tool_parser import HeuristicToolParser
-from .message_converter import (
- AnthropicToOpenAIConverter,
- build_base_request_body,
- get_block_attr,
- get_block_type,
-)
-from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason
-from .think_parser import ContentChunk, ContentType, ThinkTagParser
-from .utils import set_if_not_none
-
-__all__ = [
- "AnthropicToOpenAIConverter",
- "ContentBlockManager",
- "ContentChunk",
- "ContentType",
- "HeuristicToolParser",
- "SSEBuilder",
- "ThinkTagParser",
- "append_request_id",
- "build_base_request_body",
- "get_block_attr",
- "get_block_type",
- "get_user_facing_error_message",
- "map_error",
- "map_stop_reason",
- "set_if_not_none",
-]
diff --git a/providers/common/error_mapping.py b/providers/common/error_mapping.py
deleted file mode 100644
index d0c0d06..0000000
--- a/providers/common/error_mapping.py
+++ /dev/null
@@ -1,106 +0,0 @@
-"""Error mapping for OpenAI-compatible providers (NIM, OpenRouter, LM Studio)."""
-
-import httpx
-import openai
-
-from providers.exceptions import (
- APIError,
- AuthenticationError,
- InvalidRequestError,
- OverloadedError,
- ProviderError,
- RateLimitError,
-)
-from providers.rate_limit import GlobalRateLimiter
-
-
-def get_user_facing_error_message(
- e: Exception,
- *,
- read_timeout_s: float | None = None,
-) -> str:
- """Return a readable, non-empty error message for users."""
- message = str(e).strip()
- if message:
- return message
-
- if isinstance(e, httpx.ReadTimeout):
- if read_timeout_s is not None:
- return f"Provider request timed out after {read_timeout_s:g}s."
- return "Provider request timed out."
- if isinstance(e, httpx.ConnectTimeout):
- return "Could not connect to provider."
- if isinstance(e, TimeoutError):
- if read_timeout_s is not None:
- return f"Provider request timed out after {read_timeout_s:g}s."
- return "Request timed out."
-
- if isinstance(e, (RateLimitError, openai.RateLimitError)):
- return "Provider rate limit reached. Please retry shortly."
- if isinstance(e, (AuthenticationError, openai.AuthenticationError)):
- return "Provider authentication failed. Check API key."
- if isinstance(e, (InvalidRequestError, openai.BadRequestError)):
- return "Invalid request sent to provider."
- if isinstance(e, OverloadedError):
- return "Provider is currently overloaded. Please retry."
- if isinstance(e, APIError):
- if e.status_code in (502, 503, 504):
- return "Provider is temporarily unavailable. Please retry."
- return "Provider API request failed."
- if isinstance(e, ProviderError):
- return "Provider request failed."
-
- return "Provider request failed unexpectedly."
-
-
-def append_request_id(message: str, request_id: str | None) -> str:
- """Append request_id suffix when available."""
- base = message.strip() or "Provider request failed unexpectedly."
- if request_id:
- return f"{base} (request_id={request_id})"
- return base
-
-
-def map_error(
- e: Exception, *, rate_limiter: GlobalRateLimiter | None = None
-) -> Exception:
- """Map OpenAI or HTTPX exception to specific ProviderError."""
- message = get_user_facing_error_message(e)
- limiter = rate_limiter or GlobalRateLimiter.get_instance()
-
- # Map OpenAI Specific Errors
- if isinstance(e, openai.AuthenticationError):
- return AuthenticationError(message, raw_error=str(e))
- if isinstance(e, openai.RateLimitError):
- # Trigger global rate limit block
- limiter.set_blocked(60) # Default 60s cooldown
- return RateLimitError(message, raw_error=str(e))
- if isinstance(e, openai.BadRequestError):
- return InvalidRequestError(message, raw_error=str(e))
- if isinstance(e, openai.InternalServerError):
- raw_message = str(e)
- if "overloaded" in raw_message.lower() or "capacity" in raw_message.lower():
- return OverloadedError(message, raw_error=raw_message)
- return APIError(message, status_code=500, raw_error=str(e))
- if isinstance(e, openai.APIError):
- return APIError(
- message, status_code=getattr(e, "status_code", 500), raw_error=str(e)
- )
-
- # Map raw HTTPX Errors
- if isinstance(e, httpx.HTTPStatusError):
- status = e.response.status_code
- if status in (401, 403):
- return AuthenticationError(message, raw_error=str(e))
- if status == 429:
- limiter.set_blocked(60)
- return RateLimitError(message, raw_error=str(e))
- if status == 400:
- return InvalidRequestError(message, raw_error=str(e))
- if status >= 500:
- if status in (502, 503, 504):
- return OverloadedError(message, raw_error=str(e))
- return APIError(message, status_code=status, raw_error=str(e))
- return APIError(message, status_code=status, raw_error=str(e))
-
- return e
diff --git a/providers/common/text.py b/providers/common/text.py
deleted file mode 100644
index ce89627..0000000
--- a/providers/common/text.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Shared text extraction utilities."""
-
-from typing import Any
-
-
-def extract_text_from_content(content: Any) -> str:
- """Extract concatenated text from message content (str or list of content blocks)."""
- if isinstance(content, str):
- return content
- if isinstance(content, list):
- parts = []
- for block in content:
- text = getattr(block, "text", "")
- if text and isinstance(text, str):
- parts.append(text)
- return "".join(parts)
- return ""
diff --git a/providers/deepseek/request.py b/providers/deepseek/request.py
index e274be7..67cabf9 100644
--- a/providers/deepseek/request.py
+++ b/providers/deepseek/request.py
@@ -4,7 +4,7 @@ from typing import Any
from loguru import logger
-from providers.common.message_converter import build_base_request_body
+from core.anthropic import build_base_request_body
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
diff --git a/providers/error_mapping.py b/providers/error_mapping.py
new file mode 100644
index 0000000..3826a84
--- /dev/null
+++ b/providers/error_mapping.py
@@ -0,0 +1,56 @@
+"""Provider-specific exception mapping."""
+
+import httpx
+import openai
+
+from core.anthropic import get_user_facing_error_message
+from providers.exceptions import (
+ APIError,
+ AuthenticationError,
+ InvalidRequestError,
+ OverloadedError,
+ RateLimitError,
+)
+from providers.rate_limit import GlobalRateLimiter
+
+
+def map_error(
+ e: Exception, *, rate_limiter: GlobalRateLimiter | None = None
+) -> Exception:
+ """Map OpenAI or HTTPX exception to specific ProviderError."""
+ message = get_user_facing_error_message(e)
+ limiter = rate_limiter or GlobalRateLimiter.get_instance()
+
+ if isinstance(e, openai.AuthenticationError):
+ return AuthenticationError(message, raw_error=str(e))
+ if isinstance(e, openai.RateLimitError):
+ limiter.set_blocked(60)
+ return RateLimitError(message, raw_error=str(e))
+ if isinstance(e, openai.BadRequestError):
+ return InvalidRequestError(message, raw_error=str(e))
+ if isinstance(e, openai.InternalServerError):
+ raw_message = str(e)
+ if "overloaded" in raw_message.lower() or "capacity" in raw_message.lower():
+ return OverloadedError(message, raw_error=raw_message)
+ return APIError(message, status_code=500, raw_error=str(e))
+ if isinstance(e, openai.APIError):
+ return APIError(
+ message, status_code=getattr(e, "status_code", 500), raw_error=str(e)
+ )
+
+ if isinstance(e, httpx.HTTPStatusError):
+ status = e.response.status_code
+ if status in (401, 403):
+ return AuthenticationError(message, raw_error=str(e))
+ if status == 429:
+ limiter.set_blocked(60)
+ return RateLimitError(message, raw_error=str(e))
+ if status == 400:
+ return InvalidRequestError(message, raw_error=str(e))
+ if status >= 500:
+ if status in (502, 503, 504):
+ return OverloadedError(message, raw_error=str(e))
+ return APIError(message, status_code=status, raw_error=str(e))
+ return APIError(message, status_code=status, raw_error=str(e))
+
+ return e
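A sketch of the httpx branch, using only the classes imported above (assumes `GlobalRateLimiter.get_instance()` works without prior setup):

```python
import httpx

from providers.error_mapping import map_error
from providers.exceptions import AuthenticationError

request = httpx.Request("POST", "https://example.invalid/v1/messages")
response = httpx.Response(401, request=request)
mapped = map_error(
    httpx.HTTPStatusError("unauthorized", request=request, response=response)
)
assert isinstance(mapped, AuthenticationError)
```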
diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py
index 0c2f26d..c0a83f3 100644
--- a/providers/nvidia_nim/request.py
+++ b/providers/nvidia_nim/request.py
@@ -6,8 +6,7 @@ from typing import Any
from loguru import logger
from config.nim import NimSettings
-from providers.common.message_converter import build_base_request_body
-from providers.common.utils import set_if_not_none
+from core.anthropic import build_base_request_body, set_if_not_none
def _set_extra(
diff --git a/providers/open_router/chat_request.py b/providers/open_router/chat_request.py
index 3e9d2ba..41359be 100644
--- a/providers/open_router/chat_request.py
+++ b/providers/open_router/chat_request.py
@@ -4,7 +4,7 @@ from typing import Any
from loguru import logger
-from providers.common.message_converter import build_base_request_body
+from core.anthropic import build_base_request_body
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
diff --git a/providers/open_router/client.py b/providers/open_router/client.py
index d6e46a7..2572874 100644
--- a/providers/open_router/client.py
+++ b/providers/open_router/client.py
@@ -8,9 +8,9 @@ from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any
+from core.anthropic import SSEBuilder, append_request_id
from providers.anthropic_messages import AnthropicMessagesTransport, StreamChunkMode
from providers.base import ProviderConfig
-from providers.common import SSEBuilder, append_request_id
from providers.openai_compat import OpenAIChatTransport
from .chat_request import build_chat_request_body
diff --git a/providers/openai_compat.py b/providers/openai_compat.py
index 4125d84..6ca1079 100644
--- a/providers/openai_compat.py
+++ b/providers/openai_compat.py
@@ -10,17 +10,17 @@ import httpx
from loguru import logger
from openai import AsyncOpenAI
-from providers.base import BaseProvider, ProviderConfig
-from providers.common import (
+from core.anthropic import (
ContentType,
HeuristicToolParser,
SSEBuilder,
ThinkTagParser,
append_request_id,
get_user_facing_error_message,
- map_error,
map_stop_reason,
)
+from providers.base import BaseProvider, ProviderConfig
+from providers.error_mapping import map_error
from providers.rate_limit import GlobalRateLimiter
diff --git a/providers/rate_limit.py b/providers/rate_limit.py
index 7a523f6..0e427e9 100644
--- a/providers/rate_limit.py
+++ b/providers/rate_limit.py
@@ -50,6 +50,7 @@ class GlobalRateLimiter:
self._rate_limit = rate_limit
self._rate_window = float(rate_window)
+ self._max_concurrency = max_concurrency
# Monotonic timestamps of the last granted slots.
self._request_times: deque[float] = deque()
self._blocked_until: float = 0
@@ -95,10 +96,21 @@ class GlobalRateLimiter:
"""Get or create a provider-scoped limiter instance."""
if not scope:
raise ValueError("scope must be non-empty")
- if scope not in cls._scoped_instances:
+ desired_rate_limit = rate_limit or 40
+ desired_rate_window = float(rate_window or 60.0)
+ existing = cls._scoped_instances.get(scope)
+ if existing and existing.matches_config(
+ desired_rate_limit, desired_rate_window, max_concurrency
+ ):
+ return existing
+ if existing:
+ logger.info(
+ "Rebuilding provider rate limiter for updated scope '{}'", scope
+ )
+ if scope not in cls._scoped_instances or existing:
cls._scoped_instances[scope] = cls(
- rate_limit=rate_limit or 40,
- rate_window=rate_window or 60.0,
+ rate_limit=desired_rate_limit,
+ rate_window=desired_rate_window,
max_concurrency=max_concurrency,
)
return cls._scoped_instances[scope]
@@ -174,6 +186,16 @@ class GlobalRateLimiter:
"""Check if currently reactively blocked."""
return time.monotonic() < self._blocked_until
+ def matches_config(
+ self, rate_limit: int, rate_window: float, max_concurrency: int
+ ) -> bool:
+ """Return whether this limiter matches the requested runtime config."""
+ return (
+ self._rate_limit == rate_limit
+ and self._rate_window == float(rate_window)
+ and self._max_concurrency == max_concurrency
+ )
+
def remaining_wait(self) -> float:
"""Get remaining reactive wait time in seconds."""
return max(0.0, self._blocked_until - time.monotonic())
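`matches_config` is what lets the scoped cache above choose between reuse and rebuild. A sketch assuming the constructor keywords visible in this hunk:

```python
from providers.rate_limit import GlobalRateLimiter

limiter = GlobalRateLimiter(rate_limit=40, rate_window=60.0, max_concurrency=8)

assert limiter.matches_config(40, 60.0, 8)      # unchanged config: reuse
assert not limiter.matches_config(10, 60.0, 8)  # changed limit: rebuild
```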
diff --git a/providers/registry.py b/providers/registry.py
index 8a63a3f..a0284c7 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from collections.abc import MutableMapping
+from collections.abc import Callable, MutableMapping
from dataclasses import dataclass
from typing import Literal
@@ -20,6 +20,7 @@ from providers.open_router import (
)
TransportType = Literal["openai_chat", "anthropic_messages"]
+ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
@dataclass(frozen=True, slots=True)
@@ -80,6 +81,37 @@ PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = {
}
+def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
+ return NvidiaNimProvider(config, nim_settings=settings.nim)
+
+
+def _create_open_router(config: ProviderConfig, settings: Settings) -> BaseProvider:
+ if settings.openrouter_transport == "openai":
+ return OpenRouterChatProvider(config)
+ return OpenRouterProvider(config)
+
+
+def _create_deepseek(config: ProviderConfig, settings: Settings) -> BaseProvider:
+ return DeepSeekProvider(config)
+
+
+def _create_lmstudio(config: ProviderConfig, settings: Settings) -> BaseProvider:
+ return LMStudioProvider(config)
+
+
+def _create_llamacpp(config: ProviderConfig, settings: Settings) -> BaseProvider:
+ return LlamaCppProvider(config)
+
+
+PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
+ "nvidia_nim": _create_nvidia_nim,
+ "open_router": _create_open_router,
+ "deepseek": _create_deepseek,
+ "lmstudio": _create_lmstudio,
+ "llamacpp": _create_llamacpp,
+}
+
+
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
if attr_name is None:
return default
@@ -144,20 +176,10 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
)
config = build_provider_config(descriptor, settings)
- if provider_id == "nvidia_nim":
- return NvidiaNimProvider(config, nim_settings=settings.nim)
- if provider_id == "open_router":
- if settings.openrouter_transport == "openai":
- return OpenRouterChatProvider(config)
- return OpenRouterProvider(config)
- if provider_id == "deepseek":
- return DeepSeekProvider(config)
- if provider_id == "lmstudio":
- return LMStudioProvider(config)
- if provider_id == "llamacpp":
- return LlamaCppProvider(config)
-
- raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
+ factory = PROVIDER_FACTORIES.get(provider_id)
+ if factory is None:
+ raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
+ return factory(config, settings)
class ProviderRegistry:
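The factory table doubles as the extension point: adding a provider becomes a descriptor entry plus one registration, with no edits to `create_provider`. A hypothetical sketch (package, class, and id invented; a matching `PROVIDER_DESCRIPTORS` entry is still required):

```python
from providers.base import BaseProvider, ProviderConfig
from providers.registry import PROVIDER_FACTORIES


def _create_acme(config: ProviderConfig, settings) -> BaseProvider:
    from acme_models.provider import AcmeProvider  # hypothetical package
    return AcmeProvider(config)


PROVIDER_FACTORIES["acme"] = _create_acme
```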
diff --git a/pyproject.toml b/pyproject.toml
index 7b145bb..3dd2dd9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ voice_local = [
]
[tool.hatch.build.targets.wheel]
-packages = ["api", "cli", "config", "messaging", "providers"]
+packages = ["api", "cli", "config", "core", "messaging", "providers"]
[tool.uv.sources]
torch = { index = "pytorch-cu130" }
@@ -86,7 +86,7 @@ ignore = [
]
[tool.ruff.lint.isort]
-known-first-party = ["api", "cli", "config", "messaging", "providers", "utils"]
+known-first-party = ["api", "cli", "config", "core", "messaging", "providers", "smoke"]
[tool.ruff.format]
quote-style = "double"
diff --git a/smoke/README.md b/smoke/README.md
index 8e5dfaf..769a9c2 100644
--- a/smoke/README.md
+++ b/smoke/README.md
@@ -18,7 +18,7 @@ belong under `tests/` and must stay green with plain `uv run pytest`.
```powershell
uv run pytest smoke --collect-only -q
-uv run pytest smoke -n 0 -m live -s --tb=short
+uv run pytest smoke -n 0 -s --tb=short
```
The second command skips everything unless `FCC_LIVE_SMOKE=1` is set, but still
@@ -28,7 +28,7 @@ writes skip entries to `.smoke-results/`.
```powershell
$env:FCC_LIVE_SMOKE = "1"
-uv run pytest smoke -n 0 -m live -s --tb=short
+uv run pytest smoke -n 0 -s --tb=short
```
Provider product E2E runs for every configured provider model from `MODEL`,
@@ -68,20 +68,20 @@ Side-effectful targets are opt-in:
```powershell
$env:FCC_LIVE_SMOKE = "1"
$env:FCC_SMOKE_PROVIDER_MATRIX = "open_router,nvidia_nim,deepseek,lmstudio,llamacpp"
-uv run pytest smoke/product -n 0 -m live -s --tb=short
+uv run pytest smoke/product -n 0 -s --tb=short
```
```powershell
$env:FCC_LIVE_SMOKE = "1"
$env:FCC_SMOKE_TARGETS = "telegram,discord,voice"
$env:FCC_SMOKE_RUN_VOICE = "1"
-uv run pytest smoke/product -n 0 -m live -s --tb=short
+uv run pytest smoke/product -n 0 -s --tb=short
```
```powershell
$env:FCC_LIVE_SMOKE = "1"
$env:FCC_SMOKE_TARGETS = "messaging,config,extensibility"
-uv run pytest smoke/product -n 0 -m live -s --tb=short
+uv run pytest smoke/product -n 0 -s --tb=short
```
## Environment
diff --git a/smoke/capabilities.py b/smoke/capabilities.py
index 86e1fd6..30772a7 100644
--- a/smoke/capabilities.py
+++ b/smoke/capabilities.py
@@ -149,7 +149,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
"streaming_conversion",
"anthropic_sse_lifecycle",
"streaming_error_mapping",
- "providers.common.SSEBuilder",
+ "core.anthropic.SSEBuilder",
"provider stream chunks or native SSE events",
"Anthropic message/content/error SSE lifecycle",
"Anthropic-compatible error event",
@@ -162,7 +162,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
"streaming_conversion",
"thinking_blocks",
"thinking_token_support",
- "providers.common.think_parser",
+ "core.anthropic.thinking",
"reasoning_content, reasoning_details, text, native thinking",
"Claude thinking blocks or suppression",
"thinking hidden when disabled",
@@ -175,7 +175,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
"streaming_conversion",
"heuristic_tools",
"heuristic_tool_parser",
- "providers.common.heuristic_tool_parser",
+ "core.anthropic.tools",
"textual tool-call output",
"structured Anthropic tool_use blocks",
"text fallback when malformed",
@@ -186,7 +186,7 @@ CAPABILITY_CONTRACTS: tuple[CapabilityContract, ...] = (
"streaming_conversion",
"subagent_task_control",
"subagent_control",
- "providers.common.SSEBuilder",
+ "core.anthropic.SSEBuilder",
"Task tool call arguments",
"run_in_background=false",
"invalid JSON flushed as safe object",
diff --git a/smoke/lib/report_summary.py b/smoke/lib/report_summary.py
new file mode 100644
index 0000000..3c91644
--- /dev/null
+++ b/smoke/lib/report_summary.py
@@ -0,0 +1,53 @@
+"""Summarize smoke JSON reports for local and workflow triage."""
+
+from __future__ import annotations
+
+import json
+from collections import Counter
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass(frozen=True, slots=True)
+class SmokeSummary:
+ reports: int
+ outcomes: int
+ classifications: dict[str, int]
+
+ @property
+ def has_regression(self) -> bool:
+ return bool(
+ self.classifications.get("product_failure", 0)
+ or self.classifications.get("harness_bug", 0)
+ )
+
+
+def summarize_reports(results_dir: Path) -> SmokeSummary:
+ """Read all report JSON files and count outcome classifications."""
+ counts: Counter[str] = Counter()
+ reports = 0
+ outcomes = 0
+ for path in sorted(results_dir.glob("report-*.json")):
+ reports += 1
+ payload = json.loads(path.read_text(encoding="utf-8"))
+ for outcome in payload.get("outcomes", []):
+ if not isinstance(outcome, dict):
+ continue
+ outcomes += 1
+ counts[str(outcome.get("classification") or "unknown")] += 1
+ return SmokeSummary(
+ reports=reports,
+ outcomes=outcomes,
+ classifications=dict(sorted(counts.items())),
+ )
+
+
+def format_summary(summary: SmokeSummary) -> str:
+ """Return a compact human-readable summary."""
+ parts = [
+ f"reports={summary.reports}",
+ f"outcomes={summary.outcomes}",
+ ]
+ parts.extend(f"{name}={count}" for name, count in summary.classifications.items())
+ status = "regression" if summary.has_regression else "ok"
+ return f"smoke_summary status={status} " + " ".join(parts)
diff --git a/tests/api/test_api.py b/tests/api/test_api.py
index 9f83a64..cc53294 100644
--- a/tests/api/test_api.py
+++ b/tests/api/test_api.py
@@ -96,7 +96,6 @@ def test_model_mapping():
assert len(_stream_response_calls) == 1
args = _stream_response_calls[0][0]
assert args[0].model != "claude-3-haiku-20240307"
- assert args[0].original_model == "claude-3-haiku-20240307"
def test_error_fallbacks():
diff --git a/tests/api/test_model_router.py b/tests/api/test_model_router.py
index 52d107b..e51324d 100644
--- a/tests/api/test_model_router.py
+++ b/tests/api/test_model_router.py
@@ -36,9 +36,9 @@ def test_model_router_applies_opus_override(settings):
)
routed = ModelRouter(settings).resolve_messages_request(request)
- assert routed.model == "deepseek/deepseek-r1"
- assert routed.resolved_provider_model == "open_router/deepseek/deepseek-r1"
- assert routed.original_model == "claude-opus-4-20250514"
+ assert routed.request.model == "deepseek/deepseek-r1"
+ assert routed.resolved.provider_model_ref == "open_router/deepseek/deepseek-r1"
+ assert routed.resolved.original_model == "claude-opus-4-20250514"
assert request.model == "claude-opus-4-20250514"
@@ -53,8 +53,8 @@ def test_model_router_applies_haiku_override(settings):
)
)
- assert routed.model == "qwen2.5-7b"
- assert routed.resolved_provider_model == "lmstudio/qwen2.5-7b"
+ assert routed.request.model == "qwen2.5-7b"
+ assert routed.resolved.provider_model_ref == "lmstudio/qwen2.5-7b"
def test_model_router_applies_sonnet_override(settings):
@@ -68,8 +68,10 @@ def test_model_router_applies_sonnet_override(settings):
)
)
- assert routed.model == "meta/llama-3.3-70b-instruct"
- assert routed.resolved_provider_model == "nvidia_nim/meta/llama-3.3-70b-instruct"
+ assert routed.request.model == "meta/llama-3.3-70b-instruct"
+ assert (
+ routed.resolved.provider_model_ref == "nvidia_nim/meta/llama-3.3-70b-instruct"
+ )
def test_model_router_routes_token_count_request(settings):
@@ -81,7 +83,7 @@ def test_model_router_routes_token_count_request(settings):
)
routed = ModelRouter(settings).resolve_token_count_request(request)
- assert routed.model == "qwen2.5-7b"
+ assert routed.request.model == "qwen2.5-7b"
assert request.model == "claude-3-haiku-20240307"
diff --git a/tests/api/test_models_validators.py b/tests/api/test_models_validators.py
index d74da70..4e7740b 100644
--- a/tests/api/test_models_validators.py
+++ b/tests/api/test_models_validators.py
@@ -9,22 +9,22 @@ def test_messages_request_parses_without_model_mapping_side_effects():
)
assert request.model == "claude-3-opus"
- assert request.original_model is None
- assert request.resolved_provider_model is None
-def test_messages_request_preserves_internal_routing_fields_when_supplied():
- request = MessagesRequest(
- model="target-model",
- original_model="claude-3-opus",
- resolved_provider_model="nvidia_nim/target-model",
- max_tokens=100,
- messages=[Message(role="user", content="hello")],
+def test_messages_request_ignores_internal_routing_fields_when_supplied():
+ request = MessagesRequest.model_validate(
+ {
+ "model": "target-model",
+ "original_model": "claude-3-opus",
+ "resolved_provider_model": "nvidia_nim/target-model",
+ "max_tokens": 100,
+ "messages": [{"role": "user", "content": "hello"}],
+ }
)
assert request.model == "target-model"
- assert request.original_model == "claude-3-opus"
- assert request.resolved_provider_model == "nvidia_nim/target-model"
+ assert "original_model" not in request.model_dump()
+ assert "resolved_provider_model" not in request.model_dump()
def test_token_count_request_parses_without_model_mapping_side_effects():
diff --git a/tests/cli/test_cli_ownership.py b/tests/cli/test_cli_ownership.py
new file mode 100644
index 0000000..7d16876
--- /dev/null
+++ b/tests/cli/test_cli_ownership.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from cli.session import CLISession
+
+
+def test_cli_session_owns_typed_runner_config(tmp_path: Path) -> None:
+ session = CLISession(
+ workspace_path=str(tmp_path),
+ api_url="http://127.0.0.1:8082/v1",
+ allowed_dirs=[str(tmp_path)],
+ plans_directory=".plans",
+ claude_bin="claude-test",
+ )
+
+ assert session.config.workspace_path == str(tmp_path)
+ assert session.config.api_url == "http://127.0.0.1:8082/v1"
+ assert session.config.allowed_dirs == [str(tmp_path)]
+ assert session.config.plans_directory == ".plans"
+ assert session.config.claude_bin == "claude-test"
+
+
+def test_claude_pick_uses_project_python_runner() -> None:
+ script = Path(__file__).resolve().parents[2] / "claude-pick"
+ text = script.read_text(encoding="utf-8")
+
+ assert "uv run python" in text
+ assert "python3 -c" not in text
diff --git a/tests/contracts/test_architecture_contracts.py b/tests/contracts/test_architecture_contracts.py
new file mode 100644
index 0000000..ce98264
--- /dev/null
+++ b/tests/contracts/test_architecture_contracts.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+import re
+from pathlib import Path
+
+
+def test_architecture_plan_exists() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ plan = repo_root / "PLAN.md"
+
+ assert plan.exists()
+ text = plan.read_text(encoding="utf-8")
+ assert "Intended Dependency Direction" in text
+ assert "Smoke Coverage Policy" in text
+
+
+def test_env_examples_are_in_sync() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ root_example = repo_root / ".env.example"
+ packaged_example = repo_root / "config" / "env.example"
+
+ assert _env_keys(root_example) == _env_keys(packaged_example)
+
+
+def test_pyproject_first_party_packages_match_packaged_roots() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ pyproject = (repo_root / "pyproject.toml").read_text(encoding="utf-8")
+ match = re.search(r"known-first-party = \[(?P[^\]]+)\]", pyproject)
+
+ assert match is not None
+ configured = {
+ item.strip().strip('"')
+ for item in match.group("items").split(",")
+ if item.strip()
+ }
+ expected = {"api", "cli", "config", "core", "messaging", "providers", "smoke"}
+ assert configured == expected
+
+
+def _env_keys(path: Path) -> set[str]:
+ keys: set[str] = set()
+ for line in path.read_text(encoding="utf-8").splitlines():
+ stripped = line.strip()
+ if not stripped or stripped.startswith("#") or "=" not in stripped:
+ continue
+ key, _, _value = stripped.partition("=")
+ keys.add(key.strip())
+ return keys
diff --git a/tests/contracts/test_import_boundaries.py b/tests/contracts/test_import_boundaries.py
new file mode 100644
index 0000000..bf33eae
--- /dev/null
+++ b/tests/contracts/test_import_boundaries.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import ast
+from pathlib import Path
+
+
+def test_api_and_messaging_do_not_import_provider_common() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ assert not (repo_root / "providers" / "common").exists()
+ offenders = _imports_matching(
+ [repo_root / "api", repo_root / "messaging"],
+ forbidden_prefixes=("providers.common",),
+ )
+
+ assert offenders == []
+
+
+def test_provider_adapters_do_not_import_runtime_layers() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ offenders = _imports_matching(
+ [repo_root / "providers"],
+ forbidden_prefixes=("api.", "messaging.", "cli."),
+ )
+
+ assert offenders == []
+
+
+def test_architecture_doc_names_enforced_boundaries() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ text = (repo_root / "PLAN.md").read_text(encoding="utf-8")
+
+ assert "core/anthropic/" in text
+ assert "api/runtime.py" in text
+ assert "import-boundary" in text or "Provider adapters may depend" in text
+
+
+def _imports_matching(
+ roots: list[Path], *, forbidden_prefixes: tuple[str, ...]
+) -> list[str]:
+ offenders: list[str] = []
+ for root in roots:
+ for path in root.rglob("*.py"):
+ rel = path.relative_to(root.parent)
+ offenders.extend(
+ f"{rel}: {imported}"
+ for imported in _imports_from(path)
+ if imported in forbidden_prefixes
+ or imported.startswith(forbidden_prefixes)
+ )
+ return sorted(offenders)
+
+
+def _imports_from(path: Path) -> list[str]:
+ tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
+ imports: list[str] = []
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Import):
+ imports.extend(alias.name for alias in node.names)
+ elif isinstance(node, ast.ImportFrom) and node.module:
+ imports.append(node.module)
+ return imports
diff --git a/tests/contracts/test_smoke_tiers.py b/tests/contracts/test_smoke_tiers.py
new file mode 100644
index 0000000..f343d88
--- /dev/null
+++ b/tests/contracts/test_smoke_tiers.py
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from smoke.lib.report_summary import format_summary, summarize_reports
+
+
+def test_smoke_readme_uses_env_gated_serial_commands() -> None:
+ repo_root = Path(__file__).resolve().parents[2]
+ text = (repo_root / "smoke" / "README.md").read_text(encoding="utf-8")
+
+ assert "FCC_LIVE_SMOKE=1" in text
+ assert "-n 0" in text
+ assert "-m live" not in text
+
+
+def test_smoke_report_summary_counts_regression_classes(tmp_path: Path) -> None:
+ report = {
+ "outcomes": [
+ {"classification": "missing_env"},
+ {"classification": "product_failure"},
+ {"classification": "upstream_unavailable"},
+ ]
+ }
+ (tmp_path / "report-one.json").write_text(json.dumps(report), encoding="utf-8")
+
+ summary = summarize_reports(tmp_path)
+
+ assert summary.reports == 1
+ assert summary.outcomes == 3
+ assert summary.classifications["product_failure"] == 1
+ assert summary.has_regression
+ assert "status=regression" in format_summary(summary)
diff --git a/tests/contracts/test_stream_contracts.py b/tests/contracts/test_stream_contracts.py
index cff9eb3..f6e964f 100644
--- a/tests/contracts/test_stream_contracts.py
+++ b/tests/contracts/test_stream_contracts.py
@@ -2,14 +2,9 @@ from __future__ import annotations
from collections.abc import Iterable
+from core.anthropic import ContentType, HeuristicToolParser, SSEBuilder, ThinkTagParser
from messaging.event_parser import parse_cli_event
from messaging.transcript import RenderCtx, TranscriptBuffer
-from providers.common import (
- ContentType,
- HeuristicToolParser,
- SSEBuilder,
- ThinkTagParser,
-)
from smoke.lib.sse import (
assert_anthropic_stream_contract,
event_names,
diff --git a/tests/messaging/test_extract_text.py b/tests/messaging/test_extract_text.py
index b223459..417f9ef 100644
--- a/tests/messaging/test_extract_text.py
+++ b/tests/messaging/test_extract_text.py
@@ -4,11 +4,11 @@ from unittest.mock import MagicMock
import pytest
-from providers.common.text import extract_text_from_content
+from core.anthropic import extract_text_from_content
class TestExtractTextFromContent:
- """Tests for providers.common.text.extract_text_from_content."""
+ """Tests for core.anthropic.extract_text_from_content."""
def test_string_content(self):
"""Return string content as-is."""
diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py
index 95ca23f..f9a0faf 100644
--- a/tests/providers/test_converter.py
+++ b/tests/providers/test_converter.py
@@ -2,7 +2,7 @@ import json
import pytest
-from providers.common.message_converter import AnthropicToOpenAIConverter
+from core.anthropic import AnthropicToOpenAIConverter
# --- Mock Classes ---
@@ -314,7 +314,7 @@ def test_convert_mixed_blocks_and_types_and_roles():
def test_get_block_attr_defaults():
# Test helper directly
- from providers.common.message_converter import get_block_attr
+ from core.anthropic import get_block_attr
assert get_block_attr({}, "missing", "default") == "default"
assert get_block_attr(object(), "missing", "default") == "default"
diff --git a/tests/providers/test_error_mapping.py b/tests/providers/test_error_mapping.py
index 7e6cf99..162b05d 100644
--- a/tests/providers/test_error_mapping.py
+++ b/tests/providers/test_error_mapping.py
@@ -1,4 +1,4 @@
-"""Tests for providers/nvidia_nim/errors.py error mapping."""
+"""Tests for provider error mapping and core error formatting."""
from unittest.mock import MagicMock, patch
@@ -6,7 +6,8 @@ import openai
import pytest
from httpx import ReadTimeout, Request, Response
-from providers.common import append_request_id, get_user_facing_error_message, map_error
+from core.anthropic import append_request_id, get_user_facing_error_message
+from providers.error_mapping import map_error
from providers.exceptions import (
APIError,
AuthenticationError,
@@ -41,7 +42,7 @@ class TestMapError:
def test_rate_limit_error(self):
"""openai.RateLimitError -> RateLimitError and triggers global block."""
exc = _make_openai_error(openai.RateLimitError, status_code=429)
- with patch("providers.common.error_mapping.GlobalRateLimiter") as mock_rl:
+ with patch("providers.error_mapping.GlobalRateLimiter") as mock_rl:
mock_instance = MagicMock()
mock_rl.get_instance.return_value = mock_instance
result = map_error(exc)
@@ -117,7 +118,7 @@ class TestMapError:
openai.BadRequestError: 400,
}
exc = _make_openai_error(exc_cls, status_code=status_map[exc_cls])
- with patch("providers.common.error_mapping.GlobalRateLimiter"):
+ with patch("providers.error_mapping.GlobalRateLimiter"):
result = map_error(exc)
assert isinstance(result, expected_cls)
diff --git a/tests/providers/test_nvidia_nim_request.py b/tests/providers/test_nvidia_nim_request.py
index e4b9953..df38e24 100644
--- a/tests/providers/test_nvidia_nim_request.py
+++ b/tests/providers/test_nvidia_nim_request.py
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest
from config.nim import NimSettings
-from providers.common.utils import set_if_not_none
+from core.anthropic import set_if_not_none
from providers.nvidia_nim.request import (
_set_extra,
build_request_body,
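
Only the import path changes here. Judging by its name and its role in request building, `set_if_not_none` is a small guard along the lines of the sketch below (an assumption, since the diff never shows its body):

```python
# Assumed shape of set_if_not_none: write a key only when the value is
# present, so optional request parameters never serialize as nulls.
from typing import Any

def set_if_not_none(target: dict[str, Any], key: str, value: Any) -> None:
    if value is not None:
        target[key] = value

body: dict[str, Any] = {}
set_if_not_none(body, "temperature", None)  # skipped
set_if_not_none(body, "max_tokens", 1024)   # written
assert body == {"max_tokens": 1024}
```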
diff --git a/tests/providers/test_parsers.py b/tests/providers/test_parsers.py
index e84edaa..0b36dc7 100644
--- a/tests/providers/test_parsers.py
+++ b/tests/providers/test_parsers.py
@@ -1,6 +1,6 @@
import pytest
-from providers.common import ContentType, HeuristicToolParser, ThinkTagParser
+from core.anthropic import ContentType, HeuristicToolParser, ThinkTagParser
def test_think_tag_parser_basic():
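
`ThinkTagParser`, `HeuristicToolParser`, and `ContentType` now sit beside the other protocol primitives in `core.anthropic`. Purely as an illustration of the think-tag idea, here is a one-shot splitter that routes `<think>…</think>` spans to a reasoning channel; the real parser is presumably stateful and chunk-safe, and this sketch is not the repo's code:

```python
# Illustrative only: non-streaming separation of <think> spans from text.
import re

def split_think(text: str) -> list[tuple[str, str]]:
    segments: list[tuple[str, str]] = []
    for part in re.split(r"(<think>.*?</think>)", text, flags=re.S):
        if not part:
            continue
        if part.startswith("<think>") and part.endswith("</think>"):
            segments.append(("thinking", part[7:-8]))  # strip the tags
        else:
            segments.append(("text", part))
    return segments

assert split_think("<think>plan</think>answer") == [
    ("thinking", "plan"),
    ("text", "answer"),
]
```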
diff --git a/tests/providers/test_sse_builder.py b/tests/providers/test_sse_builder.py
index 5c2400f..07df2f3 100644
--- a/tests/providers/test_sse_builder.py
+++ b/tests/providers/test_sse_builder.py
@@ -1,15 +1,11 @@
-"""Tests for providers/nvidia_nim/utils/sse_builder.py."""
+"""Tests for core.anthropic.sse."""
import json
from unittest.mock import patch
import pytest
-from providers.common.sse_builder import (
- ContentBlockManager,
- SSEBuilder,
- map_stop_reason,
-)
+from core.anthropic import ContentBlockManager, SSEBuilder, map_stop_reason
def _parse_sse(sse_str: str) -> dict:
@@ -366,7 +362,7 @@ class TestSSEBuilderTokenEstimation:
builder.start_text_block()
builder.emit_text_delta("a" * 100) # 100 chars -> ~25 tokens
- with patch("providers.common.sse_builder.ENCODER", None):
+ with patch("core.anthropic.sse.ENCODER", None):
tokens = builder.estimate_output_tokens()
assert tokens == 25 # 100 // 4
@@ -376,7 +372,7 @@ class TestSSEBuilderTokenEstimation:
builder.start_tool_block(0, "t1", "Read")
builder.emit_tool_delta(0, '{"path":"test.py"}')
- with patch("providers.common.sse_builder.ENCODER", None):
+ with patch("core.anthropic.sse.ENCODER", None):
tokens = builder.estimate_output_tokens()
# 1 tool * 50 = 50
assert tokens == 50
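
The two `ENCODER = None` tests spell out the fallback arithmetic exactly: about four characters per token for text, plus a flat fifty-token charge per tool block. A sketch of just that fallback (the function name and signature are assumptions; the real logic is a method on `SSEBuilder`):

```python
# Fallback estimate asserted above when no tokenizer (ENCODER) is loaded:
# text at ~4 chars/token, 50 tokens per emitted tool block.
def estimate_output_tokens(text_chars: int, tool_blocks: int) -> int:
    return text_chars // 4 + tool_blocks * 50

assert estimate_output_tokens(100, 0) == 25  # "a" * 100 -> 100 // 4
assert estimate_output_tokens(0, 1) == 50    # one tool block -> 1 * 50
```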
diff --git a/tests/providers/test_streaming_errors.py b/tests/providers/test_streaming_errors.py
index 04e9ee9..0b5f614 100644
--- a/tests/providers/test_streaming_errors.py
+++ b/tests/providers/test_streaming_errors.py
@@ -392,7 +392,7 @@ class TestProcessToolCall:
def test_tool_call_with_id(self):
"""Tool call with id starts a tool block."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
@@ -409,7 +409,7 @@ class TestProcessToolCall:
def test_tool_call_without_id_generates_uuid(self):
"""Tool call without id generates a uuid-based id."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
@@ -424,7 +424,7 @@ class TestProcessToolCall:
def test_task_tool_forces_background_false(self):
"""Task tool with run_in_background=true is forced to false."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
args = json.dumps({"run_in_background": True, "prompt": "test"})
@@ -441,7 +441,7 @@ class TestProcessToolCall:
def test_task_tool_chunked_args_forces_background_false(self):
"""Chunked Task args are buffered until valid JSON, then forced to false."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc1 = {
@@ -466,7 +466,7 @@ class TestProcessToolCall:
def test_task_tool_invalid_json_logs_warning_on_flush(self, caplog):
"""Invalid JSON args for Task tool emits {} on flush and logs a warning."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
@@ -486,7 +486,7 @@ class TestProcessToolCall:
def test_negative_tool_index_fallback(self):
"""tc_index < 0 uses len(tool_indices) as fallback."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
@@ -501,7 +501,7 @@ class TestProcessToolCall:
def test_tool_args_emitted_as_delta(self):
"""Arguments are emitted as input_json_delta events."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
@@ -621,7 +621,7 @@ class TestStreamChunkEdgeCases:
def test_stream_malformed_tool_args_chunked(self):
"""Chunked tool args that never form valid JSON are flushed with {}."""
provider = _make_provider()
- from providers.common import SSEBuilder
+ from core.anthropic import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc1 = {
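
Beyond the eight mechanical `SSEBuilder` import swaps, the docstrings in this file record a real invariant: Task tool calls must never run in the background, chunked arguments are buffered until they parse as JSON, and a buffer that never becomes valid JSON is flushed as `{}` with a warning. A sketch of that finalization step, with names that are illustrative rather than the provider's API:

```python
# Sketch of the Task-tool finalization the tests describe: force
# run_in_background to false once buffered args parse, else flush "{}".
import json
import logging

logger = logging.getLogger(__name__)

def finalize_task_args(buffered: str) -> str:
    try:
        args = json.loads(buffered)
    except json.JSONDecodeError:
        logger.warning("Task args never formed valid JSON; flushing {}")
        return "{}"
    args["run_in_background"] = False  # background Task runs are disallowed
    return json.dumps(args)

assert json.loads(
    finalize_task_args('{"run_in_background": true, "prompt": "test"}')
)["run_in_background"] is False
assert finalize_task_args("{not json") == "{}"
```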
diff --git a/tests/providers/test_subagent_interception.py b/tests/providers/test_subagent_interception.py
index a36897a..136c3f4 100644
--- a/tests/providers/test_subagent_interception.py
+++ b/tests/providers/test_subagent_interception.py
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
import pytest
from config.nim import NimSettings
+from core.anthropic import ContentBlockManager
from providers.base import ProviderConfig
-from providers.common import ContentBlockManager
from providers.nvidia_nim import NvidiaNimProvider