Backup/before cleanup 20260222 230402 (#58)

This commit is contained in:
Ali Khokhar 2026-02-27 19:50:21 -08:00 committed by GitHub
parent e2840095ce
commit c4d8681000
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 253 additions and 584 deletions

View file

@ -29,7 +29,7 @@ HTTP_CONNECT_TIMEOUT=2
# Messaging Platform: "telegram" | "discord" # Messaging Platform: "telegram" | "discord"
MESSAGING_PLATFORM=discord MESSAGING_PLATFORM="discord"
MESSAGING_RATE_LIMIT=1 MESSAGING_RATE_LIMIT=1
MESSAGING_RATE_WINDOW=1 MESSAGING_RATE_WINDOW=1

View file

@ -5,6 +5,7 @@
## CODING ENVIRONMENT ## CODING ENVIRONMENT
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed - Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
- Always use `uv run` to run files instead of the global `python` command. - Always use `uv run` to run files instead of the global `python` command.
- Current uv ruff linter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:)
- Read `.env.example` for environment variables. - Read `.env.example` for environment variables.
- All CI checks must pass; failing checks block merge. - All CI checks must pass; failing checks block merge.
- Add tests for new changes (including edge cases), then run `uv run pytest`. - Add tests for new changes (including edge cases), then run `uv run pytest`.
@ -38,7 +39,7 @@
## SUMMARY STANDARDS ## SUMMARY STANDARDS
- Summaries must be technical and granular. - Summaries must be technical and granular.
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks]. - Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
## TOOLS ## TOOLS
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use. - Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.

View file

@ -5,6 +5,7 @@
## CODING ENVIRONMENT ## CODING ENVIRONMENT
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed - Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
- Always use `uv run` to run files instead of the global `python` command. - Always use `uv run` to run files instead of the global `python` command.
- Current uv ruff linter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:)
- Read `.env.example` for environment variables. - Read `.env.example` for environment variables.
- All CI checks must pass; failing checks block merge. - All CI checks must pass; failing checks block merge.
- Add tests for new changes (including edge cases), then run `uv run pytest`. - Add tests for new changes (including edge cases), then run `uv run pytest`.
@ -38,7 +39,7 @@
## SUMMARY STANDARDS ## SUMMARY STANDARDS
- Summaries must be technical and granular. - Summaries must be technical and granular.
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks]. - Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
## TOOLS ## TOOLS
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use. - Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.

View file

@ -289,11 +289,11 @@ uv sync --extra voice
Full list in [`nvidia_nim_models.json`](nvidia_nim_models.json). Full list in [`nvidia_nim_models.json`](nvidia_nim_models.json).
Popular models: Popular models:
- `qwen/qwen3.5-397b-a17b` - `nvidia_nim/minimaxai/minimax-m2.5`
- `z-ai/glm5` - `nvidia_nim/qwen/qwen3.5-397b-a17b`
- `stepfun-ai/step-3.5-flash` - `nvidia_nim/z-ai/glm5`
- `moonshotai/kimi-k2.5` - `nvidia_nim/stepfun-ai/step-3.5-flash`
- `minimaxai/minimax-m2.1` - `nvidia_nim/moonshotai/kimi-k2.5`
Browse: [build.nvidia.com](https://build.nvidia.com/explore/discover) Browse: [build.nvidia.com](https://build.nvidia.com/explore/discover)
@ -310,9 +310,9 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
Hundreds of models from StepFun, OpenAI, Anthropic, Google, and more. Hundreds of models from StepFun, OpenAI, Anthropic, Google, and more.
Popular models: Popular models:
- `stepfun/step-3.5-flash:free` - `open_router/stepfun/step-3.5-flash:free`
- `deepseek/deepseek-r1-0528:free` - `open_router/deepseek/deepseek-r1-0528:free`
- `openai/gpt-oss-120b:free` - `open_router/openai/gpt-oss-120b:free`
Browse: [openrouter.ai/models](https://openrouter.ai/models) Browse: [openrouter.ai/models](https://openrouter.ai/models)
@ -385,7 +385,6 @@ free-claude-code/
├── messaging/ # MessagingPlatform ABC + Discord/Telegram bots, session management ├── messaging/ # MessagingPlatform ABC + Discord/Telegram bots, session management
├── config/ # Settings, NIM config, logging ├── config/ # Settings, NIM config, logging
├── cli/ # CLI session and process management ├── cli/ # CLI session and process management
├── utils/ # Text utilities
└── tests/ # Pytest test suite └── tests/ # Pytest test suite
``` ```

View file

@ -51,7 +51,7 @@ async def lifespan(app: FastAPI):
try: try:
# Use the messaging factory to create the right platform # Use the messaging factory to create the right platform
from messaging.factory import create_messaging_platform from messaging.platforms.factory import create_messaging_platform
messaging_platform = create_messaging_platform( messaging_platform = create_messaging_platform(
platform_type=settings.messaging_platform, platform_type=settings.messaging_platform,

View file

@ -16,88 +16,84 @@ def get_settings() -> Settings:
return _get_settings() return _get_settings()
def _create_provider(settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance from settings."""
if settings.provider_type == "nvidia_nim":
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
raise HTTPException(
status_code=503,
detail=(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
"Get a key at https://build.nvidia.com/settings/api-keys"
),
)
from providers.nvidia_nim import NvidiaNimProvider
config = ProviderConfig(
api_key=settings.nvidia_nim_api_key,
base_url=NVIDIA_NIM_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
elif settings.provider_type == "open_router":
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
raise HTTPException(
status_code=503,
detail=(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
"Get a key at https://openrouter.ai/keys"
),
)
from providers.open_router import OpenRouterProvider
config = ProviderConfig(
api_key=settings.open_router_api_key,
base_url="https://openrouter.ai/api/v1",
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = OpenRouterProvider(config)
elif settings.provider_type == "lmstudio":
from providers.lmstudio import LMStudioProvider
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = LMStudioProvider(config)
else:
logger.error(
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
logger.info("Provider initialized: %s", settings.provider_type)
return provider
def get_provider() -> BaseProvider: def get_provider() -> BaseProvider:
"""Get or create the provider instance based on settings.provider_type.""" """Get or create the provider instance based on settings.provider_type."""
global _provider global _provider
if _provider is None: if _provider is None:
settings = get_settings() _provider = _create_provider(get_settings())
if settings.provider_type == "nvidia_nim":
if (
not settings.nvidia_nim_api_key
or not settings.nvidia_nim_api_key.strip()
):
raise HTTPException(
status_code=503,
detail=(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
"Get a key at https://build.nvidia.com/settings/api-keys"
),
)
from providers.nvidia_nim import NvidiaNimProvider
config = ProviderConfig(
api_key=settings.nvidia_nim_api_key,
base_url=NVIDIA_NIM_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
_provider = NvidiaNimProvider(config, nim_settings=settings.nim)
logger.info("Provider initialized: %s", settings.provider_type)
elif settings.provider_type == "open_router":
if (
not settings.open_router_api_key
or not settings.open_router_api_key.strip()
):
raise HTTPException(
status_code=503,
detail=(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
"Get a key at https://openrouter.ai/keys"
),
)
from providers.open_router import OpenRouterProvider
config = ProviderConfig(
api_key=settings.open_router_api_key,
base_url="https://openrouter.ai/api/v1",
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
_provider = OpenRouterProvider(config)
logger.info("Provider initialized: %s", settings.provider_type)
elif settings.provider_type == "lmstudio":
from providers.lmstudio import LMStudioProvider
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
_provider = LMStudioProvider(config)
logger.info("Provider initialized: %s", settings.provider_type)
else:
logger.error(
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
return _provider return _provider
@ -108,5 +104,10 @@ async def cleanup_provider():
client = getattr(_provider, "_client", None) client = getattr(_provider, "_client", None)
if client and hasattr(client, "aclose"): if client and hasattr(client, "aclose"):
await client.aclose() await client.aclose()
elif client:
logger.warning(
"Provider client %r has no aclose(); skipping async cleanup",
type(client).__name__,
)
_provider = None _provider = None
logger.debug("Provider cleanup completed") logger.debug("Provider cleanup completed")

View file

@ -29,14 +29,13 @@ def is_quota_check_request(request_data: MessagesRequest) -> bool:
def is_title_generation_request(request_data: MessagesRequest) -> bool: def is_title_generation_request(request_data: MessagesRequest) -> bool:
"""Check if this is a conversation title generation request. """Check if this is a conversation title generation request.
Title generation requests typically contain the phrase Title generation requests are detected by a system prompt containing
"write a 5-10 word title" in the user's message. title extraction instructions, no tools, and a single user message.
""" """
if len(request_data.messages) > 0 and request_data.messages[-1].role == "user": if not request_data.system or request_data.tools:
text = extract_text_from_content(request_data.messages[-1].content) return False
if "write a 5-10 word title" in text.lower(): system_text = extract_text_from_content(request_data.system).lower()
return True return "new conversation topic" in system_text and "title" in system_text
return False
def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]: def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:

View file

@ -7,7 +7,6 @@ from loguru import logger
from pydantic import BaseModel, field_validator, model_validator from pydantic import BaseModel, field_validator, model_validator
from config.settings import get_settings from config.settings import get_settings
from providers.model_utils import normalize_model_name
# ============================================================================= # =============================================================================
# Content Block Types # Content Block Types
@ -112,10 +111,7 @@ class MessagesRequest(BaseModel):
if self.original_model is None: if self.original_model is None:
self.original_model = self.model self.original_model = self.model
# Use centralized model normalization self.model = settings.model_name
normalized = normalize_model_name(self.model, settings.model_name)
if normalized != self.model:
self.model = normalized
if self.model != self.original_model: if self.model != self.original_model:
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'") logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
@ -136,5 +132,4 @@ class TokenCountRequest(BaseModel):
def validate_model_field(cls, v, info): def validate_model_field(cls, v, info):
"""Map any Claude model name to the configured model.""" """Map any Claude model name to the configured model."""
settings = get_settings() settings = get_settings()
# Use centralized model normalization return settings.model_name
return normalize_model_name(v, settings.model_name)

View file

@ -80,6 +80,7 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
format=_serialize_with_context, format=_serialize_with_context,
encoding="utf-8", encoding="utf-8",
mode="a", mode="a",
rotation="50 MB",
) )
# Intercept stdlib logging: route all root logger output to loguru # Intercept stdlib logging: route all root logger output to loguru

View file

@ -1,9 +1,9 @@
"""Platform-agnostic messaging layer.""" """Platform-agnostic messaging layer."""
from .base import CLISession, MessagingPlatform, SessionManagerInterface
from .event_parser import parse_cli_event from .event_parser import parse_cli_event
from .handler import ClaudeMessageHandler from .handler import ClaudeMessageHandler
from .models import IncomingMessage from .models import IncomingMessage
from .platforms.base import CLISession, MessagingPlatform, SessionManagerInterface
from .session import SessionStore from .session import SessionStore
from .trees.data import MessageNode, MessageState, MessageTree from .trees.data import MessageNode, MessageState, MessageTree
from .trees.queue_manager import TreeQueueManager from .trees.queue_manager import TreeQueueManager

View file

@ -1,9 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.base for new code."""
from .platforms.base import (
CLISession,
MessagingPlatform,
SessionManagerInterface,
)
__all__ = ["CLISession", "MessagingPlatform", "SessionManagerInterface"]

View file

@ -1,17 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.discord for new code."""
from .platforms.discord import (
DISCORD_AVAILABLE,
DISCORD_MESSAGE_LIMIT,
DiscordPlatform,
_get_discord,
_parse_allowed_channels,
)
__all__ = [
"DISCORD_AVAILABLE",
"DISCORD_MESSAGE_LIMIT",
"DiscordPlatform",
"_get_discord",
"_parse_allowed_channels",
]

View file

@ -1,25 +0,0 @@
"""Backward-compatible re-export. Use messaging.rendering.discord_markdown for new code."""
from .rendering.discord_markdown import (
_is_gfm_table_header_line,
_normalize_gfm_tables,
discord_bold,
discord_code_inline,
escape_discord,
escape_discord_code,
format_status,
format_status_discord,
render_markdown_to_discord,
)
__all__ = [
"_is_gfm_table_header_line",
"_normalize_gfm_tables",
"discord_bold",
"discord_code_inline",
"escape_discord",
"escape_discord_code",
"format_status",
"format_status_discord",
"render_markdown_to_discord",
]

View file

@ -1,5 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.factory for new code."""
from .platforms.factory import create_messaging_platform
__all__ = ["create_messaging_platform"]

View file

@ -95,18 +95,18 @@ class SessionStore:
except Exception as e: except Exception as e:
logger.error(f"Failed to load sessions: {e}") logger.error(f"Failed to load sessions: {e}")
def _save(self) -> None: def _snapshot(self) -> dict:
"""Persist sessions and trees to disk. Caller must hold self._lock.""" """Snapshot current state for serialization. Caller must hold self._lock."""
try: return {
data = { "trees": dict(self._trees),
"trees": self._trees, "node_to_tree": dict(self._node_to_tree),
"node_to_tree": self._node_to_tree, "message_log": {k: list(v) for k, v in self._message_log.items()},
"message_log": self._message_log, }
}
with open(self.storage_path, "w", encoding="utf-8") as f: def _write_data(self, data: dict) -> None:
json.dump(data, f, indent=2) """Write data dict to disk. Must be called WITHOUT holding self._lock."""
except Exception as e: with open(self.storage_path, "w", encoding="utf-8") as f:
logger.error(f"Failed to save sessions: {e}") json.dump(data, f, indent=2)
def _schedule_save(self) -> None: def _schedule_save(self) -> None:
"""Schedule a debounced save. Caller must hold self._lock.""" """Schedule a debounced save. Caller must hold self._lock."""
@ -126,22 +126,35 @@ class SessionStore:
if not self._dirty: if not self._dirty:
self._save_timer = None self._save_timer = None
return return
self._save() snapshot = self._snapshot()
self._dirty = False self._dirty = False
self._save_timer = None self._save_timer = None
try:
self._write_data(snapshot)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
with self._lock:
self._dirty = True
def _flush_save(self) -> None: def _flush_save(self) -> dict:
"""Immediate save, cancel any pending debounced save. Caller must hold self._lock.""" """Cancel pending timer and snapshot current state. Caller must hold self._lock.
Returns snapshot dict; caller must call _write_data(snapshot) outside the lock."""
if self._save_timer is not None: if self._save_timer is not None:
self._save_timer.cancel() self._save_timer.cancel()
self._save_timer = None self._save_timer = None
self._dirty = False self._dirty = False
self._save() return self._snapshot()
def flush_pending_save(self) -> None: def flush_pending_save(self) -> None:
"""Flush any pending debounced save. Call on shutdown to avoid losing data.""" """Flush any pending debounced save. Call on shutdown to avoid losing data."""
with self._lock: with self._lock:
self._flush_save() snapshot = self._flush_save()
try:
self._write_data(snapshot)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
with self._lock:
self._dirty = True
def record_message_id( def record_message_id(
self, self,
@ -201,7 +214,13 @@ class SessionStore:
self._node_to_tree.clear() self._node_to_tree.clear()
self._message_log.clear() self._message_log.clear()
self._message_log_ids.clear() self._message_log_ids.clear()
self._flush_save() snapshot = self._flush_save()
try:
self._write_data(snapshot)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
with self._lock:
self._dirty = True
# ==================== Tree Methods ==================== # ==================== Tree Methods ====================

View file

@ -1,15 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.telegram for new code."""
from .platforms.telegram import (
TELEGRAM_AVAILABLE,
TelegramPlatform,
)
# Re-export telegram.error types when python-telegram-bot is installed
__all__ = ["TELEGRAM_AVAILABLE", "TelegramPlatform"]
try:
from telegram.error import NetworkError, RetryAfter, TelegramError
__all__ += ["NetworkError", "RetryAfter", "TelegramError"]
except ImportError:
pass

View file

@ -1,21 +0,0 @@
"""Backward-compatible re-export. Use messaging.rendering.telegram_markdown for new code."""
from .rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
escape_md_v2_link_url,
format_status,
mdv2_bold,
mdv2_code_inline,
render_markdown_to_mdv2,
)
__all__ = [
"escape_md_v2",
"escape_md_v2_code",
"escape_md_v2_link_url",
"format_status",
"mdv2_bold",
"mdv2_code_inline",
"render_markdown_to_mdv2",
]

View file

@ -105,17 +105,19 @@ class TreeRepository:
if not tree: if not tree:
return [] return []
pending = [] pending: list[MessageNode] = []
node = tree.get_node(node_id) stack = [node_id]
if not node:
return []
for child_id in node.children_ids: while stack:
child = tree.get_node(child_id) current_id = stack.pop()
if child and child.state == MessageState.PENDING: node = tree.get_node(current_id)
pending.append(child) if not node:
# Recursively get children of pending children continue
pending.extend(self.get_pending_children(child_id)) for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
stack.append(child_id)
return pending return pending

View file

@ -476,13 +476,13 @@
"owned_by": "microsoft" "owned_by": "microsoft"
}, },
{ {
"id": "minimaxai/minimax-m2", "id": "minimaxai/minimax-m2.1",
"object": "model", "object": "model",
"created": 735790403, "created": 735790403,
"owned_by": "minimaxai" "owned_by": "minimaxai"
}, },
{ {
"id": "minimaxai/minimax-m2.1", "id": "minimaxai/minimax-m2.5",
"object": "model", "object": "model",
"created": 735790403, "created": 735790403,
"owned_by": "minimaxai" "owned_by": "minimaxai"
@ -709,12 +709,6 @@
"created": 735790403, "created": 735790403,
"owned_by": "nvidia" "owned_by": "nvidia"
}, },
{
"id": "nvidia/llama-3.2-nemoretriever-300m-embed-v2",
"object": "model",
"created": 735790403,
"owned_by": "nvidia"
},
{ {
"id": "nvidia/llama-3.2-nv-embedqa-1b-v1", "id": "nvidia/llama-3.2-nv-embedqa-1b-v1",
"object": "model", "object": "model",

View file

@ -9,6 +9,7 @@ from .message_converter import (
) )
from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason
from .think_parser import ContentChunk, ContentType, ThinkTagParser from .think_parser import ContentChunk, ContentType, ThinkTagParser
from .utils import set_if_not_none
__all__ = [ __all__ = [
"AnthropicToOpenAIConverter", "AnthropicToOpenAIConverter",
@ -22,4 +23,5 @@ __all__ = [
"get_block_type", "get_block_type",
"map_error", "map_error",
"map_stop_reason", "map_stop_reason",
"set_if_not_none",
] ]

View file

@ -0,0 +1,9 @@
"""Shared utility helpers for provider request builders."""
from typing import Any
def set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
"""Set body[key] = value only when value is not None."""
if value is not None:
body[key] = value

View file

@ -5,15 +5,11 @@ from typing import Any
from loguru import logger from loguru import logger
from providers.common.message_converter import AnthropicToOpenAIConverter from providers.common.message_converter import AnthropicToOpenAIConverter
from providers.common.utils import set_if_not_none
LMSTUDIO_DEFAULT_MAX_TOKENS = 81920 LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
def build_request_body(request_data: Any) -> dict: def build_request_body(request_data: Any) -> dict:
"""Build OpenAI-format request body from Anthropic request for LM Studio.""" """Build OpenAI-format request body from Anthropic request for LM Studio."""
logger.debug( logger.debug(
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
} }
max_tokens = getattr(request_data, "max_tokens", None) max_tokens = getattr(request_data, "max_tokens", None)
_set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS) set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None)) set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None)) set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
stop_sequences = getattr(request_data, "stop_sequences", None) stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences: if stop_sequences:

View file

@ -1,86 +0,0 @@
"""Model name normalization utilities.
Centralizes model name mapping logic to avoid duplication across the codebase.
"""
import os
# Provider prefixes to strip from model names
_PROVIDER_PREFIXES = ["anthropic/", "openai/", "gemini/"]
# Claude model identifiers
_CLAUDE_IDENTIFIERS = ["haiku", "sonnet", "opus", "claude"]
def strip_provider_prefixes(model: str) -> str:
"""
Strip provider prefixes from model name.
Args:
model: The model name, possibly with prefix
Returns:
Model name without provider prefix
"""
for prefix in _PROVIDER_PREFIXES:
if model.startswith(prefix):
return model[len(prefix) :]
return model
def is_claude_model(model: str) -> bool:
"""
Check if a model name identifies as a Claude model.
Args:
model: The (prefix-stripped) model name
Returns:
True if this is a Claude model
"""
model_lower = model.lower()
return any(name in model_lower for name in _CLAUDE_IDENTIFIERS)
def normalize_model_name(model: str, default_model: str | None = None) -> str:
"""
Normalize a model name by stripping prefixes and mapping to default if needed.
This is the central function for model name normalization across the API.
It strips provider prefixes and maps Claude model names to the configured model.
Args:
model: The model name (may include provider prefix)
default_model: The default model to use for Claude models.
If None, uses settings.model from config.
Returns:
Normalized model name (original if not a Claude model, mapped if Claude)
"""
# Strip provider prefixes
clean = strip_provider_prefixes(model)
# Map Claude models to default
if is_claude_model(clean):
if default_model is None:
# Use environment/config default
default_model = os.getenv("MODEL", "moonshotai/kimi-k2-thinking")
return default_model
return model
def get_original_model(model: str) -> str:
"""
Get the original model name, storing it before normalization.
Convenience function that returns the input unchanged, intended to be
called alongside normalize_model_name to capture the original.
Args:
model: The model name
Returns:
The model name unchanged (for documentation purposes)
"""
return model

View file

@ -6,11 +6,7 @@ from loguru import logger
from config.nim import NimSettings from config.nim import NimSettings
from providers.common.message_converter import AnthropicToOpenAIConverter from providers.common.message_converter import AnthropicToOpenAIConverter
from providers.common.utils import set_if_not_none
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
def _set_extra( def _set_extra(
@ -52,15 +48,15 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
max_tokens = nim.max_tokens max_tokens = nim.max_tokens
elif nim.max_tokens: elif nim.max_tokens:
max_tokens = min(max_tokens, nim.max_tokens) max_tokens = min(max_tokens, nim.max_tokens)
_set_if_not_none(body, "max_tokens", max_tokens) set_if_not_none(body, "max_tokens", max_tokens)
req_temperature = getattr(request_data, "temperature", None) req_temperature = getattr(request_data, "temperature", None)
temperature = req_temperature if req_temperature is not None else nim.temperature temperature = req_temperature if req_temperature is not None else nim.temperature
_set_if_not_none(body, "temperature", temperature) set_if_not_none(body, "temperature", temperature)
req_top_p = getattr(request_data, "top_p", None) req_top_p = getattr(request_data, "top_p", None)
top_p = req_top_p if req_top_p is not None else nim.top_p top_p = req_top_p if req_top_p is not None else nim.top_p
_set_if_not_none(body, "top_p", top_p) set_if_not_none(body, "top_p", top_p)
stop_sequences = getattr(request_data, "stop_sequences", None) stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences: if stop_sequences:

View file

@ -5,15 +5,11 @@ from typing import Any
from loguru import logger from loguru import logger
from providers.common.message_converter import AnthropicToOpenAIConverter from providers.common.message_converter import AnthropicToOpenAIConverter
from providers.common.utils import set_if_not_none
OPENROUTER_DEFAULT_MAX_TOKENS = 81920 OPENROUTER_DEFAULT_MAX_TOKENS = 81920
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
def build_request_body(request_data: Any) -> dict: def build_request_body(request_data: Any) -> dict:
"""Build OpenAI-format request body from Anthropic request for OpenRouter.""" """Build OpenAI-format request body from Anthropic request for OpenRouter."""
logger.debug( logger.debug(
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
} }
max_tokens = getattr(request_data, "max_tokens", None) max_tokens = getattr(request_data, "max_tokens", None)
_set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS) set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None)) set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None)) set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
stop_sequences = getattr(request_data, "stop_sequences", None) stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences: if stop_sequences:

View file

@ -124,7 +124,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
patch.object(api_app_mod, "get_settings", return_value=settings), patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch( patch(
"messaging.factory.create_messaging_platform", "messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform if messaging_enabled else None, return_value=fake_platform if messaging_enabled else None,
) as create_platform, ) as create_platform,
patch("messaging.session.SessionStore", return_value=session_store), patch("messaging.session.SessionStore", return_value=session_store),
@ -195,7 +195,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
patch.object(api_app_mod, "get_settings", return_value=settings), patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch( patch(
"messaging.factory.create_messaging_platform", "messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform, return_value=fake_platform,
), ),
patch("messaging.session.SessionStore", return_value=session_store), patch("messaging.session.SessionStore", return_value=session_store),
@ -234,7 +234,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
patch.object(api_app_mod, "get_settings", return_value=settings), patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch( patch(
"messaging.factory.create_messaging_platform", "messaging.platforms.factory.create_messaging_platform",
side_effect=ImportError("discord not installed"), side_effect=ImportError("discord not installed"),
), ),
TestClient(app), TestClient(app),
@ -284,7 +284,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
patch.object(api_app_mod, "get_settings", return_value=settings), patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch( patch(
"messaging.factory.create_messaging_platform", "messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform, return_value=fake_platform,
), ),
patch("messaging.session.SessionStore", return_value=session_store), patch("messaging.session.SessionStore", return_value=session_store),
@ -336,7 +336,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
patch.object(api_app_mod, "get_settings", return_value=settings), patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider), patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch( patch(
"messaging.factory.create_messaging_platform", "messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform, return_value=fake_platform,
), ),
patch("messaging.session.SessionStore", return_value=session_store), patch("messaging.session.SessionStore", return_value=session_store),

View file

@ -25,18 +25,6 @@ def test_messages_request_map_model_claude_to_default(mock_settings):
assert request.original_model == "claude-3-opus" assert request.original_model == "claude-3-opus"
def test_messages_request_map_model_non_claude_unchanged(mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest(
model="gpt-4",
max_tokens=100,
messages=[Message(role="user", content="hello")],
)
# normalize_model_name returns original if not Claude
assert request.model == "gpt-4"
def test_messages_request_map_model_with_provider_prefix(mock_settings): def test_messages_request_map_model_with_provider_prefix(mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings): with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest( request = MessagesRequest(

View file

@ -112,68 +112,42 @@ class TestQuotaCheckRequest:
class TestTitleGenerationRequest: class TestTitleGenerationRequest:
"""Tests for is_title_generation_request function.""" """Tests for is_title_generation_request function."""
def test_title_generation_simple(self): def _title_gen_system(self) -> list[MagicMock]:
"""Test title generation detection with target phrase."""
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = "Please write a 5-10 word title for this conversation"
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
assert is_title_generation_request(req) is True
def test_title_generation_case_insensitive(self):
"""Test title generation is case insensitive."""
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = "Write a 5-10 Word Title please"
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
assert is_title_generation_request(req) is True
def test_title_generation_list_content(self):
"""Test title generation with list content blocks."""
block = MagicMock() block = MagicMock()
block.text = "Write a 5-10 word title" block.text = "Analyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title."
return [block]
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = [block]
def test_title_generation_detected_via_system(self):
"""Title gen detected by system prompt containing topic/title keywords."""
req = MagicMock(spec=MessagesRequest) req = MagicMock(spec=MessagesRequest)
req.messages = [msg] req.system = self._title_gen_system()
req.tools = None
assert is_title_generation_request(req) is True assert is_title_generation_request(req) is True
def test_not_title_generation_no_phrase(self): def test_title_generation_not_detected_with_tools(self):
"""Test not title generation without target phrase.""" """Not detected when tools are present (main conversation, not title gen)."""
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = "Hello world, how are you?"
req = MagicMock(spec=MessagesRequest) req = MagicMock(spec=MessagesRequest)
req.messages = [msg] req.system = self._title_gen_system()
req.tools = [MagicMock()]
assert is_title_generation_request(req) is False assert is_title_generation_request(req) is False
def test_not_title_generation_wrong_role(self): def test_title_generation_not_detected_no_system(self):
"""Test not title generation when last message is not from user.""" """Not detected when system is absent."""
msg = MagicMock(spec=Message)
msg.role = "assistant"
msg.content = "Write a 5-10 word title"
req = MagicMock(spec=MessagesRequest) req = MagicMock(spec=MessagesRequest)
req.messages = [msg] req.system = None
req.tools = None
assert is_title_generation_request(req) is False assert is_title_generation_request(req) is False
def test_not_title_generation_empty_messages(self): def test_title_generation_not_detected_unrelated_system(self):
"""Test not title generation when no messages.""" """Not detected when system prompt has no topic/title keywords."""
block = MagicMock()
block.text = "You are a helpful assistant."
req = MagicMock(spec=MessagesRequest) req = MagicMock(spec=MessagesRequest)
req.messages = [] req.system = [block]
req.tools = None
assert is_title_generation_request(req) is False assert is_title_generation_request(req) is False

View file

@ -18,8 +18,12 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from config.nim import NimSettings from config.nim import NimSettings
from messaging.base import CLISession, MessagingPlatform, SessionManagerInterface
from messaging.models import IncomingMessage from messaging.models import IncomingMessage
from messaging.platforms.base import (
CLISession,
MessagingPlatform,
SessionManagerInterface,
)
from messaging.session import SessionStore from messaging.session import SessionStore
from providers.base import ProviderConfig from providers.base import ProviderConfig
from providers.nvidia_nim import NvidiaNimProvider from providers.nvidia_nim import NvidiaNimProvider

View file

@ -1,6 +1,6 @@
"""Tests for messaging/discord_markdown.py.""" """Tests for messaging/rendering/discord_markdown.py."""
from messaging.discord_markdown import ( from messaging.rendering.discord_markdown import (
_is_gfm_table_header_line, _is_gfm_table_header_line,
_normalize_gfm_tables, _normalize_gfm_tables,
discord_bold, discord_bold,

View file

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from messaging.discord import ( from messaging.platforms.discord import (
DISCORD_AVAILABLE, DISCORD_AVAILABLE,
DiscordPlatform, DiscordPlatform,
_get_discord, _get_discord,

View file

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from messaging.telegram_markdown import ( from messaging.rendering.telegram_markdown import (
escape_md_v2, escape_md_v2,
escape_md_v2_code, escape_md_v2_code,
mdv2_bold, mdv2_bold,

View file

@ -4,7 +4,7 @@ import pytest
from messaging.handler import ClaudeMessageHandler from messaging.handler import ClaudeMessageHandler
from messaging.models import IncomingMessage from messaging.models import IncomingMessage
from messaging.telegram_markdown import render_markdown_to_mdv2 from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
from messaging.trees.data import MessageNode, MessageState from messaging.trees.data import MessageNode, MessageState

View file

@ -49,7 +49,7 @@ class TestMessagingBase:
def test_platform_is_abstract(self): def test_platform_is_abstract(self):
"""Verify MessagingPlatform cannot be instantiated.""" """Verify MessagingPlatform cannot be instantiated."""
from messaging.base import MessagingPlatform from messaging.platforms.base import MessagingPlatform
with pytest.raises(TypeError): with pytest.raises(TypeError):
MessagingPlatform() MessagingPlatform()

View file

@ -2,7 +2,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from messaging.factory import create_messaging_platform from messaging.platforms.factory import create_messaging_platform
class TestCreateMessagingPlatform: class TestCreateMessagingPlatform:

View file

@ -3,12 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from telegram.error import NetworkError, RetryAfter, TelegramError from telegram.error import NetworkError, RetryAfter, TelegramError
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
@pytest.fixture @pytest.fixture
def telegram_platform(): def telegram_platform():
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345") platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
return platform return platform
@ -76,7 +76,7 @@ async def test_telegram_no_retry_on_bad_request(telegram_platform):
def test_handler_build_message_hardening(): def test_handler_build_message_hardening():
# Formatting hardening now lives in TranscriptBuffer rendering. # Formatting hardening now lives in TranscriptBuffer rendering.
from messaging.telegram_markdown import ( from messaging.rendering.telegram_markdown import (
escape_md_v2, escape_md_v2,
escape_md_v2_code, escape_md_v2_code,
mdv2_bold, mdv2_bold,
@ -112,7 +112,7 @@ def test_handler_build_message_hardening():
def test_render_output_never_exceeds_4096(): def test_render_output_never_exceeds_4096():
"""Transcript render with various status lengths never exceeds Telegram 4096 limit.""" """Transcript render with various status lengths never exceeds Telegram 4096 limit."""
from messaging.telegram_markdown import ( from messaging.rendering.telegram_markdown import (
escape_md_v2, escape_md_v2,
escape_md_v2_code, escape_md_v2_code,
mdv2_bold, mdv2_bold,

View file

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from messaging.telegram_markdown import ( from messaging.rendering.telegram_markdown import (
escape_md_v2, escape_md_v2,
escape_md_v2_code, escape_md_v2_code,
mdv2_bold, mdv2_bold,
@ -74,7 +74,7 @@ def test_empty_components_with_status(handler):
def test_render_markdown_unclosed_markdown(): def test_render_markdown_unclosed_markdown():
"""Malformed markdown (e.g. unclosed *) does not crash and produces acceptable output.""" """Malformed markdown (e.g. unclosed *) does not crash and produces acceptable output."""
from messaging.telegram_markdown import render_markdown_to_mdv2 from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
md = "*bold without close" md = "*bold without close"
out = render_markdown_to_mdv2(md) out = render_markdown_to_mdv2(md)
@ -84,7 +84,7 @@ def test_render_markdown_unclosed_markdown():
def test_escape_md_v2_unicode_emoji(): def test_escape_md_v2_unicode_emoji():
"""Unicode and emoji pass through correctly (no special char escaping needed).""" """Unicode and emoji pass through correctly (no special char escaping needed)."""
from messaging.telegram_markdown import escape_md_v2, escape_md_v2_code from messaging.rendering.telegram_markdown import escape_md_v2, escape_md_v2_code
text = "Hello 世界 🎉 café" text = "Hello 世界 🎉 café"
assert escape_md_v2(text) == text assert escape_md_v2(text) == text

View file

@ -81,11 +81,13 @@ class TestSessionStoreSaveEdgeCases:
"""Tests for save failure handling.""" """Tests for save failure handling."""
def test_save_io_error_handled(self, tmp_store): def test_save_io_error_handled(self, tmp_store):
"""Write failure in _save() is logged but doesn't raise.""" """Write failure in _write_data() raises (callers handle the error)."""
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
with patch("builtins.open", side_effect=OSError("disk full")): with (
tmp_store._save() patch("builtins.open", side_effect=OSError("disk full")),
# Should not raise pytest.raises(OSError),
):
tmp_store._write_data(tmp_store._snapshot())
class TestSessionStoreClearAll: class TestSessionStoreClearAll:

View file

@ -2,12 +2,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
@pytest.fixture @pytest.fixture
def telegram_platform(): def telegram_platform():
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345") platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
return platform return platform

View file

@ -3,11 +3,12 @@ from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from telegram.error import NetworkError, RetryAfter, TelegramError
def test_telegram_platform_init_raises_when_dependency_missing(): def test_telegram_platform_init_raises_when_dependency_missing():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", False): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", False):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
with pytest.raises(ImportError): with pytest.raises(ImportError):
TelegramPlatform(bot_token="x") TelegramPlatform(bot_token="x")
@ -19,7 +20,7 @@ async def test_telegram_platform_start_requires_token():
patch.dict("os.environ", {}, clear=True), patch.dict("os.environ", {}, clear=True),
patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True), patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True),
): ):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token=None) platform = TelegramPlatform(bot_token=None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -29,7 +30,7 @@ async def test_telegram_platform_start_requires_token():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_telegram_platform_stop_no_application_is_noop(): async def test_telegram_platform_stop_no_application_is_noop():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
platform._application = None platform._application = None
@ -41,7 +42,7 @@ async def test_telegram_platform_stop_no_application_is_noop():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_retry_returns_none_when_message_not_modified_network_error(): async def test_with_retry_returns_none_when_message_not_modified_network_error():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import NetworkError, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
@ -54,7 +55,7 @@ async def test_with_retry_returns_none_when_message_not_modified_network_error()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_retry_retries_network_error_then_succeeds(monkeypatch): async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import NetworkError, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
@ -75,7 +76,7 @@ async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_retry_honors_retry_after_timedelta(monkeypatch): async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import RetryAfter, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
@ -96,7 +97,7 @@ async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_retry_drops_parse_mode_on_markdown_entity_error(): async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramError, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
@ -115,7 +116,7 @@ async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queue_send_message_without_limiter_calls_send_message(): async def test_queue_send_message_without_limiter_calls_send_message():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
platform._limiter = None platform._limiter = None
@ -130,7 +131,7 @@ async def test_queue_send_message_without_limiter_calls_send_message():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queue_edit_message_without_limiter_calls_edit_message(): async def test_queue_edit_message_without_limiter_calls_edit_message():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
platform._limiter = None platform._limiter = None
@ -143,7 +144,7 @@ async def test_queue_edit_message_without_limiter_calls_edit_message():
def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch): def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
@ -157,7 +158,7 @@ def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_start_command_replies_and_forwards(): async def test_on_start_command_replies_and_forwards():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
with patch.object( with patch.object(
@ -174,7 +175,7 @@ async def test_on_start_command_replies_and_forwards():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_telegram_message_handler_error_sends_error_message(): async def test_on_telegram_message_handler_error_sends_error_message():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t", allowed_user_id="123") platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
with patch.object( with patch.object(
@ -200,7 +201,7 @@ async def test_on_telegram_message_handler_error_sends_error_message():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_telegram_start_retries_on_network_error(monkeypatch): async def test_telegram_start_retries_on_network_error(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import NetworkError, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="token", allowed_user_id=None) platform = TelegramPlatform(bot_token="token", allowed_user_id=None)
@ -225,7 +226,7 @@ async def test_telegram_start_retries_on_network_error(monkeypatch):
async def test_edit_message_with_text_exceeding_4096_raises(): async def test_edit_message_with_text_exceeding_4096_raises():
"""edit_message with text > 4096 raises TelegramError (BadRequest).""" """edit_message with text > 4096 raises TelegramError (BadRequest)."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramError, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
platform._application = MagicMock() platform._application = MagicMock()
@ -242,7 +243,7 @@ async def test_edit_message_with_text_exceeding_4096_raises():
async def test_edit_message_empty_string(): async def test_edit_message_empty_string():
"""edit_message with empty string - Telegram accepts (no-op edit).""" """edit_message with empty string - Telegram accepts (no-op edit)."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
platform._application = MagicMock() platform._application = MagicMock()
@ -259,7 +260,7 @@ async def test_edit_message_empty_string():
async def test_send_message_empty_string(): async def test_send_message_empty_string():
"""send_message with empty string - Telegram may reject; we pass through.""" """send_message with empty string - Telegram may reject; we pass through."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")
platform._application = MagicMock() platform._application = MagicMock()
@ -277,7 +278,7 @@ async def test_send_message_empty_string():
async def test_on_telegram_message_non_text_update_ignored(): async def test_on_telegram_message_non_text_update_ignored():
"""Update with message.photo but no text returns early without calling handler.""" """Update with message.photo but no text returns early without calling handler."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t", allowed_user_id="123") platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
handler = AsyncMock() handler = AsyncMock()
@ -299,7 +300,7 @@ async def test_on_telegram_message_non_text_update_ignored():
async def test_with_retry_message_not_found_returns_none(): async def test_with_retry_message_not_found_returns_none():
"""'message to edit not found' returns None without retry.""" """'message to edit not found' returns None without retry."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True): with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramError, TelegramPlatform from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t") platform = TelegramPlatform(bot_token="t")

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
from messaging.telegram_markdown import ( from messaging.rendering.telegram_markdown import (
escape_md_v2, escape_md_v2,
escape_md_v2_code, escape_md_v2_code,
mdv2_bold, mdv2_bold,

View file

@ -1,133 +0,0 @@
import pytest
from providers.model_utils import (
get_original_model,
is_claude_model,
normalize_model_name,
strip_provider_prefixes,
)
def test_strip_provider_prefixes():
assert strip_provider_prefixes("anthropic/claude-3") == "claude-3"
assert strip_provider_prefixes("openai/gpt-4") == "gpt-4"
assert strip_provider_prefixes("gemini/gemini-pro") == "gemini-pro"
assert strip_provider_prefixes("no-prefix") == "no-prefix"
def test_is_claude_model():
assert is_claude_model("claude-3-sonnet") is True
assert is_claude_model("claude-3-opus") is True
assert is_claude_model("claude-3-haiku") is True
assert is_claude_model("claude-2.1") is True
assert is_claude_model("gpt-4") is False
assert is_claude_model("gemini-pro") is False
def test_normalize_model_name_claude_maps_to_default():
default = "target-model"
# Strips prefix AND maps to default
assert normalize_model_name("anthropic/claude-3-sonnet", default) == default
assert normalize_model_name("claude-3-opus", default) == default
def test_normalize_model_name_non_claude_unchanged():
default = "target-model"
assert normalize_model_name("gpt-4", default) == "gpt-4"
assert (
normalize_model_name("openai/gpt-3.5-turbo", default) == "openai/gpt-3.5-turbo"
)
def test_get_original_model():
assert get_original_model("any-model") == "any-model"
def test_normalize_model_name_without_default(monkeypatch):
monkeypatch.setenv("MODEL", "env-default-model")
assert normalize_model_name("claude-3") == "env-default-model"
# --- Parametrized Edge Case Tests ---
@pytest.mark.parametrize(
"model,expected",
[
("anthropic/claude-3", "claude-3"),
("openai/gpt-4", "gpt-4"),
("gemini/gemini-pro", "gemini-pro"),
("no-prefix", "no-prefix"),
("", ""),
("anthropic/", ""),
("anthropic/openai/nested", "openai/nested"),
],
ids=[
"anthropic",
"openai",
"gemini",
"no_prefix",
"empty_string",
"prefix_only",
"nested_prefix",
],
)
def test_strip_provider_prefixes_parametrized(model, expected):
"""Parametrized prefix stripping with edge cases."""
assert strip_provider_prefixes(model) == expected
@pytest.mark.parametrize(
"model,expected",
[
("claude-3-sonnet", True),
("claude-3-opus", True),
("claude-3-haiku", True),
("claude-2.1", True),
("gpt-4", False),
("gemini-pro", False),
("", False),
("my-claude-wrapper", True), # "claude" as substring
("CLAUDE-3-SONNET", True), # case insensitive
("sonnet-v2", True), # "sonnet" identifier without "claude"
("haiku-model", True), # "haiku" identifier
],
ids=[
"sonnet",
"opus",
"haiku",
"claude2",
"gpt4",
"gemini",
"empty",
"claude_substring",
"uppercase",
"sonnet_standalone",
"haiku_standalone",
],
)
def test_is_claude_model_parametrized(model, expected):
"""Parametrized Claude model detection with edge cases."""
assert is_claude_model(model) is expected
@pytest.mark.parametrize(
"model,default,expected",
[
("claude-3-sonnet", "target", "target"),
("anthropic/claude-3-opus", "target", "target"),
("gpt-4", "target", "gpt-4"),
("openai/gpt-3.5-turbo", "target", "openai/gpt-3.5-turbo"),
("", "target", ""), # empty string is not a claude model
],
ids=[
"claude_mapped",
"prefixed_claude",
"non_claude",
"prefixed_non_claude",
"empty",
],
)
def test_normalize_model_name_parametrized(model, default, expected):
"""Parametrized model normalization."""
assert normalize_model_name(model, default) == expected

View file

@ -3,9 +3,9 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from config.nim import NimSettings from config.nim import NimSettings
from providers.common.utils import set_if_not_none
from providers.nvidia_nim.request import ( from providers.nvidia_nim.request import (
_set_extra, _set_extra,
_set_if_not_none,
build_request_body, build_request_body,
) )
@ -13,12 +13,12 @@ from providers.nvidia_nim.request import (
class TestSetIfNotNone: class TestSetIfNotNone:
def test_value_not_none_sets(self): def test_value_not_none_sets(self):
body = {} body = {}
_set_if_not_none(body, "key", "value") set_if_not_none(body, "key", "value")
assert body["key"] == "value" assert body["key"] == "value"
def test_value_none_skips(self): def test_value_none_skips(self):
body = {} body = {}
_set_if_not_none(body, "key", None) set_if_not_none(body, "key", None)
assert "key" not in body assert "key" not in body