mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-26 10:31:07 +00:00
Backup/before cleanup 20260222 230402 (#58)
This commit is contained in:
parent
e2840095ce
commit
c4d8681000
43 changed files with 253 additions and 584 deletions
|
|
@ -29,7 +29,7 @@ HTTP_CONNECT_TIMEOUT=2
|
|||
|
||||
|
||||
# Messaging Platform: "telegram" | "discord"
|
||||
MESSAGING_PLATFORM=discord
|
||||
MESSAGING_PLATFORM="discord"
|
||||
MESSAGING_RATE_LIMIT=1
|
||||
MESSAGING_RATE_WINDOW=1
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
## CODING ENVIRONMENT
|
||||
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
||||
- Always use `uv run` to run files instead of the global `python` command.
|
||||
- Current uv ruff linter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:)
|
||||
- Read `.env.example` for environment variables.
|
||||
- All CI checks must pass; failing checks block merge.
|
||||
- Add tests for new changes (including edge cases), then run `uv run pytest`.
|
||||
|
|
@ -38,7 +39,7 @@
|
|||
|
||||
## SUMMARY STANDARDS
|
||||
- Summaries must be technical and granular.
|
||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks].
|
||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
|
||||
|
||||
## TOOLS
|
||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
## CODING ENVIRONMENT
|
||||
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
||||
- Always use `uv run` to run files instead of the global `python` command.
|
||||
- Current uv ruff linter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:)
|
||||
- Read `.env.example` for environment variables.
|
||||
- All CI checks must pass; failing checks block merge.
|
||||
- Add tests for new changes (including edge cases), then run `uv run pytest`.
|
||||
|
|
@ -38,7 +39,7 @@
|
|||
|
||||
## SUMMARY STANDARDS
|
||||
- Summaries must be technical and granular.
|
||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks].
|
||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
|
||||
|
||||
## TOOLS
|
||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||
17
README.md
17
README.md
|
|
@ -289,11 +289,11 @@ uv sync --extra voice
|
|||
Full list in [`nvidia_nim_models.json`](nvidia_nim_models.json).
|
||||
|
||||
Popular models:
|
||||
- `qwen/qwen3.5-397b-a17b`
|
||||
- `z-ai/glm5`
|
||||
- `stepfun-ai/step-3.5-flash`
|
||||
- `moonshotai/kimi-k2.5`
|
||||
- `minimaxai/minimax-m2.1`
|
||||
- `nvidia_nim/minimaxai/minimax-m2.5`
|
||||
- `nvidia_nim/qwen/qwen3.5-397b-a17b`
|
||||
- `nvidia_nim/z-ai/glm5`
|
||||
- `nvidia_nim/stepfun-ai/step-3.5-flash`
|
||||
- `nvidia_nim/moonshotai/kimi-k2.5`
|
||||
|
||||
Browse: [build.nvidia.com](https://build.nvidia.com/explore/discover)
|
||||
|
||||
|
|
@ -310,9 +310,9 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
|
|||
Hundreds of models from StepFun, OpenAI, Anthropic, Google, and more.
|
||||
|
||||
Popular models:
|
||||
- `stepfun/step-3.5-flash:free`
|
||||
- `deepseek/deepseek-r1-0528:free`
|
||||
- `openai/gpt-oss-120b:free`
|
||||
- `open_router/stepfun/step-3.5-flash:free`
|
||||
- `open_router/deepseek/deepseek-r1-0528:free`
|
||||
- `open_router/openai/gpt-oss-120b:free`
|
||||
|
||||
Browse: [openrouter.ai/models](https://openrouter.ai/models)
|
||||
|
||||
|
|
@ -385,7 +385,6 @@ free-claude-code/
|
|||
├── messaging/ # MessagingPlatform ABC + Discord/Telegram bots, session management
|
||||
├── config/ # Settings, NIM config, logging
|
||||
├── cli/ # CLI session and process management
|
||||
├── utils/ # Text utilities
|
||||
└── tests/ # Pytest test suite
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
try:
|
||||
# Use the messaging factory to create the right platform
|
||||
from messaging.factory import create_messaging_platform
|
||||
from messaging.platforms.factory import create_messaging_platform
|
||||
|
||||
messaging_platform = create_messaging_platform(
|
||||
platform_type=settings.messaging_platform,
|
||||
|
|
|
|||
|
|
@ -16,88 +16,84 @@ def get_settings() -> Settings:
|
|||
return _get_settings()
|
||||
|
||||
|
||||
def _create_provider(settings: Settings) -> BaseProvider:
|
||||
"""Construct and return a new provider instance from settings."""
|
||||
if settings.provider_type == "nvidia_nim":
|
||||
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://build.nvidia.com/settings/api-keys"
|
||||
),
|
||||
)
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
config = ProviderConfig(
|
||||
api_key=settings.nvidia_nim_api_key,
|
||||
base_url=NVIDIA_NIM_BASE_URL,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
elif settings.provider_type == "open_router":
|
||||
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://openrouter.ai/keys"
|
||||
),
|
||||
)
|
||||
from providers.open_router import OpenRouterProvider
|
||||
|
||||
config = ProviderConfig(
|
||||
api_key=settings.open_router_api_key,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = OpenRouterProvider(config)
|
||||
elif settings.provider_type == "lmstudio":
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
|
||||
config = ProviderConfig(
|
||||
api_key="lm-studio",
|
||||
base_url=settings.lm_studio_base_url,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = LMStudioProvider(config)
|
||||
else:
|
||||
logger.error(
|
||||
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
settings.provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{settings.provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
logger.info("Provider initialized: %s", settings.provider_type)
|
||||
return provider
|
||||
|
||||
|
||||
def get_provider() -> BaseProvider:
|
||||
"""Get or create the provider instance based on settings.provider_type."""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
settings = get_settings()
|
||||
|
||||
if settings.provider_type == "nvidia_nim":
|
||||
if (
|
||||
not settings.nvidia_nim_api_key
|
||||
or not settings.nvidia_nim_api_key.strip()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://build.nvidia.com/settings/api-keys"
|
||||
),
|
||||
)
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
config = ProviderConfig(
|
||||
api_key=settings.nvidia_nim_api_key,
|
||||
base_url=NVIDIA_NIM_BASE_URL,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
_provider = NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
logger.info("Provider initialized: %s", settings.provider_type)
|
||||
elif settings.provider_type == "open_router":
|
||||
if (
|
||||
not settings.open_router_api_key
|
||||
or not settings.open_router_api_key.strip()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://openrouter.ai/keys"
|
||||
),
|
||||
)
|
||||
from providers.open_router import OpenRouterProvider
|
||||
|
||||
config = ProviderConfig(
|
||||
api_key=settings.open_router_api_key,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
_provider = OpenRouterProvider(config)
|
||||
logger.info("Provider initialized: %s", settings.provider_type)
|
||||
elif settings.provider_type == "lmstudio":
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
|
||||
config = ProviderConfig(
|
||||
api_key="lm-studio",
|
||||
base_url=settings.lm_studio_base_url,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
_provider = LMStudioProvider(config)
|
||||
logger.info("Provider initialized: %s", settings.provider_type)
|
||||
else:
|
||||
logger.error(
|
||||
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
settings.provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{settings.provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
_provider = _create_provider(get_settings())
|
||||
return _provider
|
||||
|
||||
|
||||
|
|
@ -108,5 +104,10 @@ async def cleanup_provider():
|
|||
client = getattr(_provider, "_client", None)
|
||||
if client and hasattr(client, "aclose"):
|
||||
await client.aclose()
|
||||
elif client:
|
||||
logger.warning(
|
||||
"Provider client %r has no aclose(); skipping async cleanup",
|
||||
type(client).__name__,
|
||||
)
|
||||
_provider = None
|
||||
logger.debug("Provider cleanup completed")
|
||||
|
|
|
|||
|
|
@ -29,14 +29,13 @@ def is_quota_check_request(request_data: MessagesRequest) -> bool:
|
|||
def is_title_generation_request(request_data: MessagesRequest) -> bool:
|
||||
"""Check if this is a conversation title generation request.
|
||||
|
||||
Title generation requests typically contain the phrase
|
||||
"write a 5-10 word title" in the user's message.
|
||||
Title generation requests are detected by a system prompt containing
|
||||
title extraction instructions, no tools, and a single user message.
|
||||
"""
|
||||
if len(request_data.messages) > 0 and request_data.messages[-1].role == "user":
|
||||
text = extract_text_from_content(request_data.messages[-1].content)
|
||||
if "write a 5-10 word title" in text.lower():
|
||||
return True
|
||||
return False
|
||||
if not request_data.system or request_data.tools:
|
||||
return False
|
||||
system_text = extract_text_from_content(request_data.system).lower()
|
||||
return "new conversation topic" in system_text and "title" in system_text
|
||||
|
||||
|
||||
def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from loguru import logger
|
|||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from config.settings import get_settings
|
||||
from providers.model_utils import normalize_model_name
|
||||
|
||||
# =============================================================================
|
||||
# Content Block Types
|
||||
|
|
@ -112,10 +111,7 @@ class MessagesRequest(BaseModel):
|
|||
if self.original_model is None:
|
||||
self.original_model = self.model
|
||||
|
||||
# Use centralized model normalization
|
||||
normalized = normalize_model_name(self.model, settings.model_name)
|
||||
if normalized != self.model:
|
||||
self.model = normalized
|
||||
self.model = settings.model_name
|
||||
|
||||
if self.model != self.original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||
|
|
@ -136,5 +132,4 @@ class TokenCountRequest(BaseModel):
|
|||
def validate_model_field(cls, v, info):
|
||||
"""Map any Claude model name to the configured model."""
|
||||
settings = get_settings()
|
||||
# Use centralized model normalization
|
||||
return normalize_model_name(v, settings.model_name)
|
||||
return settings.model_name
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
|
|||
format=_serialize_with_context,
|
||||
encoding="utf-8",
|
||||
mode="a",
|
||||
rotation="50 MB",
|
||||
)
|
||||
|
||||
# Intercept stdlib logging: route all root logger output to loguru
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""Platform-agnostic messaging layer."""
|
||||
|
||||
from .base import CLISession, MessagingPlatform, SessionManagerInterface
|
||||
from .event_parser import parse_cli_event
|
||||
from .handler import ClaudeMessageHandler
|
||||
from .models import IncomingMessage
|
||||
from .platforms.base import CLISession, MessagingPlatform, SessionManagerInterface
|
||||
from .session import SessionStore
|
||||
from .trees.data import MessageNode, MessageState, MessageTree
|
||||
from .trees.queue_manager import TreeQueueManager
|
||||
|
|
|
|||
|
|
@ -1,9 +0,0 @@
|
|||
"""Backward-compatible re-export. Use messaging.platforms.base for new code."""
|
||||
|
||||
from .platforms.base import (
|
||||
CLISession,
|
||||
MessagingPlatform,
|
||||
SessionManagerInterface,
|
||||
)
|
||||
|
||||
__all__ = ["CLISession", "MessagingPlatform", "SessionManagerInterface"]
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
"""Backward-compatible re-export. Use messaging.platforms.discord for new code."""
|
||||
|
||||
from .platforms.discord import (
|
||||
DISCORD_AVAILABLE,
|
||||
DISCORD_MESSAGE_LIMIT,
|
||||
DiscordPlatform,
|
||||
_get_discord,
|
||||
_parse_allowed_channels,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DISCORD_AVAILABLE",
|
||||
"DISCORD_MESSAGE_LIMIT",
|
||||
"DiscordPlatform",
|
||||
"_get_discord",
|
||||
"_parse_allowed_channels",
|
||||
]
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
"""Backward-compatible re-export. Use messaging.rendering.discord_markdown for new code."""
|
||||
|
||||
from .rendering.discord_markdown import (
|
||||
_is_gfm_table_header_line,
|
||||
_normalize_gfm_tables,
|
||||
discord_bold,
|
||||
discord_code_inline,
|
||||
escape_discord,
|
||||
escape_discord_code,
|
||||
format_status,
|
||||
format_status_discord,
|
||||
render_markdown_to_discord,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"_is_gfm_table_header_line",
|
||||
"_normalize_gfm_tables",
|
||||
"discord_bold",
|
||||
"discord_code_inline",
|
||||
"escape_discord",
|
||||
"escape_discord_code",
|
||||
"format_status",
|
||||
"format_status_discord",
|
||||
"render_markdown_to_discord",
|
||||
]
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
"""Backward-compatible re-export. Use messaging.platforms.factory for new code."""
|
||||
|
||||
from .platforms.factory import create_messaging_platform
|
||||
|
||||
__all__ = ["create_messaging_platform"]
|
||||
|
|
@ -95,18 +95,18 @@ class SessionStore:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to load sessions: {e}")
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Persist sessions and trees to disk. Caller must hold self._lock."""
|
||||
try:
|
||||
data = {
|
||||
"trees": self._trees,
|
||||
"node_to_tree": self._node_to_tree,
|
||||
"message_log": self._message_log,
|
||||
}
|
||||
with open(self.storage_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save sessions: {e}")
|
||||
def _snapshot(self) -> dict:
|
||||
"""Snapshot current state for serialization. Caller must hold self._lock."""
|
||||
return {
|
||||
"trees": dict(self._trees),
|
||||
"node_to_tree": dict(self._node_to_tree),
|
||||
"message_log": {k: list(v) for k, v in self._message_log.items()},
|
||||
}
|
||||
|
||||
def _write_data(self, data: dict) -> None:
|
||||
"""Write data dict to disk. Must be called WITHOUT holding self._lock."""
|
||||
with open(self.storage_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def _schedule_save(self) -> None:
|
||||
"""Schedule a debounced save. Caller must hold self._lock."""
|
||||
|
|
@ -126,22 +126,35 @@ class SessionStore:
|
|||
if not self._dirty:
|
||||
self._save_timer = None
|
||||
return
|
||||
self._save()
|
||||
snapshot = self._snapshot()
|
||||
self._dirty = False
|
||||
self._save_timer = None
|
||||
try:
|
||||
self._write_data(snapshot)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save sessions: {e}")
|
||||
with self._lock:
|
||||
self._dirty = True
|
||||
|
||||
def _flush_save(self) -> None:
|
||||
"""Immediate save, cancel any pending debounced save. Caller must hold self._lock."""
|
||||
def _flush_save(self) -> dict:
|
||||
"""Cancel pending timer and snapshot current state. Caller must hold self._lock.
|
||||
Returns snapshot dict; caller must call _write_data(snapshot) outside the lock."""
|
||||
if self._save_timer is not None:
|
||||
self._save_timer.cancel()
|
||||
self._save_timer = None
|
||||
self._dirty = False
|
||||
self._save()
|
||||
return self._snapshot()
|
||||
|
||||
def flush_pending_save(self) -> None:
|
||||
"""Flush any pending debounced save. Call on shutdown to avoid losing data."""
|
||||
with self._lock:
|
||||
self._flush_save()
|
||||
snapshot = self._flush_save()
|
||||
try:
|
||||
self._write_data(snapshot)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save sessions: {e}")
|
||||
with self._lock:
|
||||
self._dirty = True
|
||||
|
||||
def record_message_id(
|
||||
self,
|
||||
|
|
@ -201,7 +214,13 @@ class SessionStore:
|
|||
self._node_to_tree.clear()
|
||||
self._message_log.clear()
|
||||
self._message_log_ids.clear()
|
||||
self._flush_save()
|
||||
snapshot = self._flush_save()
|
||||
try:
|
||||
self._write_data(snapshot)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save sessions: {e}")
|
||||
with self._lock:
|
||||
self._dirty = True
|
||||
|
||||
# ==================== Tree Methods ====================
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
"""Backward-compatible re-export. Use messaging.platforms.telegram for new code."""
|
||||
|
||||
from .platforms.telegram import (
|
||||
TELEGRAM_AVAILABLE,
|
||||
TelegramPlatform,
|
||||
)
|
||||
|
||||
# Re-export telegram.error types when python-telegram-bot is installed
|
||||
__all__ = ["TELEGRAM_AVAILABLE", "TelegramPlatform"]
|
||||
try:
|
||||
from telegram.error import NetworkError, RetryAfter, TelegramError
|
||||
|
||||
__all__ += ["NetworkError", "RetryAfter", "TelegramError"]
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
"""Backward-compatible re-export. Use messaging.rendering.telegram_markdown for new code."""
|
||||
|
||||
from .rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
escape_md_v2_link_url,
|
||||
format_status,
|
||||
mdv2_bold,
|
||||
mdv2_code_inline,
|
||||
render_markdown_to_mdv2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"escape_md_v2",
|
||||
"escape_md_v2_code",
|
||||
"escape_md_v2_link_url",
|
||||
"format_status",
|
||||
"mdv2_bold",
|
||||
"mdv2_code_inline",
|
||||
"render_markdown_to_mdv2",
|
||||
]
|
||||
|
|
@ -105,17 +105,19 @@ class TreeRepository:
|
|||
if not tree:
|
||||
return []
|
||||
|
||||
pending = []
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
pending: list[MessageNode] = []
|
||||
stack = [node_id]
|
||||
|
||||
for child_id in node.children_ids:
|
||||
child = tree.get_node(child_id)
|
||||
if child and child.state == MessageState.PENDING:
|
||||
pending.append(child)
|
||||
# Recursively get children of pending children
|
||||
pending.extend(self.get_pending_children(child_id))
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
node = tree.get_node(current_id)
|
||||
if not node:
|
||||
continue
|
||||
for child_id in node.children_ids:
|
||||
child = tree.get_node(child_id)
|
||||
if child and child.state == MessageState.PENDING:
|
||||
pending.append(child)
|
||||
stack.append(child_id)
|
||||
|
||||
return pending
|
||||
|
||||
|
|
|
|||
|
|
@ -476,13 +476,13 @@
|
|||
"owned_by": "microsoft"
|
||||
},
|
||||
{
|
||||
"id": "minimaxai/minimax-m2",
|
||||
"id": "minimaxai/minimax-m2.1",
|
||||
"object": "model",
|
||||
"created": 735790403,
|
||||
"owned_by": "minimaxai"
|
||||
},
|
||||
{
|
||||
"id": "minimaxai/minimax-m2.1",
|
||||
"id": "minimaxai/minimax-m2.5",
|
||||
"object": "model",
|
||||
"created": 735790403,
|
||||
"owned_by": "minimaxai"
|
||||
|
|
@ -709,12 +709,6 @@
|
|||
"created": 735790403,
|
||||
"owned_by": "nvidia"
|
||||
},
|
||||
{
|
||||
"id": "nvidia/llama-3.2-nemoretriever-300m-embed-v2",
|
||||
"object": "model",
|
||||
"created": 735790403,
|
||||
"owned_by": "nvidia"
|
||||
},
|
||||
{
|
||||
"id": "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
||||
"object": "model",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from .message_converter import (
|
|||
)
|
||||
from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason
|
||||
from .think_parser import ContentChunk, ContentType, ThinkTagParser
|
||||
from .utils import set_if_not_none
|
||||
|
||||
__all__ = [
|
||||
"AnthropicToOpenAIConverter",
|
||||
|
|
@ -22,4 +23,5 @@ __all__ = [
|
|||
"get_block_type",
|
||||
"map_error",
|
||||
"map_stop_reason",
|
||||
"set_if_not_none",
|
||||
]
|
||||
|
|
|
|||
9
providers/common/utils.py
Normal file
9
providers/common/utils.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Shared utility helpers for provider request builders."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
||||
"""Set body[key] = value only when value is not None."""
|
||||
if value is not None:
|
||||
body[key] = value
|
||||
|
|
@ -5,15 +5,11 @@ from typing import Any
|
|||
from loguru import logger
|
||||
|
||||
from providers.common.message_converter import AnthropicToOpenAIConverter
|
||||
from providers.common.utils import set_if_not_none
|
||||
|
||||
LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
|
||||
|
||||
|
||||
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
body[key] = value
|
||||
|
||||
|
||||
def build_request_body(request_data: Any) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request for LM Studio."""
|
||||
logger.debug(
|
||||
|
|
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
|
|||
}
|
||||
|
||||
max_tokens = getattr(request_data, "max_tokens", None)
|
||||
_set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
|
||||
set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
|
||||
|
||||
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
||||
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
||||
set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
||||
set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
||||
|
||||
stop_sequences = getattr(request_data, "stop_sequences", None)
|
||||
if stop_sequences:
|
||||
|
|
|
|||
|
|
@ -1,86 +0,0 @@
|
|||
"""Model name normalization utilities.
|
||||
|
||||
Centralizes model name mapping logic to avoid duplication across the codebase.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Provider prefixes to strip from model names
|
||||
_PROVIDER_PREFIXES = ["anthropic/", "openai/", "gemini/"]
|
||||
|
||||
# Claude model identifiers
|
||||
_CLAUDE_IDENTIFIERS = ["haiku", "sonnet", "opus", "claude"]
|
||||
|
||||
|
||||
def strip_provider_prefixes(model: str) -> str:
|
||||
"""
|
||||
Strip provider prefixes from model name.
|
||||
|
||||
Args:
|
||||
model: The model name, possibly with prefix
|
||||
|
||||
Returns:
|
||||
Model name without provider prefix
|
||||
"""
|
||||
for prefix in _PROVIDER_PREFIXES:
|
||||
if model.startswith(prefix):
|
||||
return model[len(prefix) :]
|
||||
return model
|
||||
|
||||
|
||||
def is_claude_model(model: str) -> bool:
|
||||
"""
|
||||
Check if a model name identifies as a Claude model.
|
||||
|
||||
Args:
|
||||
model: The (prefix-stripped) model name
|
||||
|
||||
Returns:
|
||||
True if this is a Claude model
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
return any(name in model_lower for name in _CLAUDE_IDENTIFIERS)
|
||||
|
||||
|
||||
def normalize_model_name(model: str, default_model: str | None = None) -> str:
|
||||
"""
|
||||
Normalize a model name by stripping prefixes and mapping to default if needed.
|
||||
|
||||
This is the central function for model name normalization across the API.
|
||||
It strips provider prefixes and maps Claude model names to the configured model.
|
||||
|
||||
Args:
|
||||
model: The model name (may include provider prefix)
|
||||
default_model: The default model to use for Claude models.
|
||||
If None, uses settings.model from config.
|
||||
|
||||
Returns:
|
||||
Normalized model name (original if not a Claude model, mapped if Claude)
|
||||
"""
|
||||
# Strip provider prefixes
|
||||
clean = strip_provider_prefixes(model)
|
||||
|
||||
# Map Claude models to default
|
||||
if is_claude_model(clean):
|
||||
if default_model is None:
|
||||
# Use environment/config default
|
||||
default_model = os.getenv("MODEL", "moonshotai/kimi-k2-thinking")
|
||||
return default_model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_original_model(model: str) -> str:
|
||||
"""
|
||||
Get the original model name, storing it before normalization.
|
||||
|
||||
Convenience function that returns the input unchanged, intended to be
|
||||
called alongside normalize_model_name to capture the original.
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
|
||||
Returns:
|
||||
The model name unchanged (for documentation purposes)
|
||||
"""
|
||||
return model
|
||||
|
|
@ -6,11 +6,7 @@ from loguru import logger
|
|||
|
||||
from config.nim import NimSettings
|
||||
from providers.common.message_converter import AnthropicToOpenAIConverter
|
||||
|
||||
|
||||
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
body[key] = value
|
||||
from providers.common.utils import set_if_not_none
|
||||
|
||||
|
||||
def _set_extra(
|
||||
|
|
@ -52,15 +48,15 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
|
|||
max_tokens = nim.max_tokens
|
||||
elif nim.max_tokens:
|
||||
max_tokens = min(max_tokens, nim.max_tokens)
|
||||
_set_if_not_none(body, "max_tokens", max_tokens)
|
||||
set_if_not_none(body, "max_tokens", max_tokens)
|
||||
|
||||
req_temperature = getattr(request_data, "temperature", None)
|
||||
temperature = req_temperature if req_temperature is not None else nim.temperature
|
||||
_set_if_not_none(body, "temperature", temperature)
|
||||
set_if_not_none(body, "temperature", temperature)
|
||||
|
||||
req_top_p = getattr(request_data, "top_p", None)
|
||||
top_p = req_top_p if req_top_p is not None else nim.top_p
|
||||
_set_if_not_none(body, "top_p", top_p)
|
||||
set_if_not_none(body, "top_p", top_p)
|
||||
|
||||
stop_sequences = getattr(request_data, "stop_sequences", None)
|
||||
if stop_sequences:
|
||||
|
|
|
|||
|
|
@ -5,15 +5,11 @@ from typing import Any
|
|||
from loguru import logger
|
||||
|
||||
from providers.common.message_converter import AnthropicToOpenAIConverter
|
||||
from providers.common.utils import set_if_not_none
|
||||
|
||||
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
|
||||
|
||||
|
||||
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
body[key] = value
|
||||
|
||||
|
||||
def build_request_body(request_data: Any) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
|
||||
logger.debug(
|
||||
|
|
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
|
|||
}
|
||||
|
||||
max_tokens = getattr(request_data, "max_tokens", None)
|
||||
_set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
|
||||
set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
|
||||
|
||||
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
||||
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
||||
set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
||||
set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
||||
|
||||
stop_sequences = getattr(request_data, "stop_sequences", None)
|
||||
if stop_sequences:
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
|||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch(
|
||||
"messaging.factory.create_messaging_platform",
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform if messaging_enabled else None,
|
||||
) as create_platform,
|
||||
patch("messaging.session.SessionStore", return_value=session_store),
|
||||
|
|
@ -195,7 +195,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
|||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch(
|
||||
"messaging.factory.create_messaging_platform",
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform,
|
||||
),
|
||||
patch("messaging.session.SessionStore", return_value=session_store),
|
||||
|
|
@ -234,7 +234,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
|||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch(
|
||||
"messaging.factory.create_messaging_platform",
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
side_effect=ImportError("discord not installed"),
|
||||
),
|
||||
TestClient(app),
|
||||
|
|
@ -284,7 +284,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
|||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch(
|
||||
"messaging.factory.create_messaging_platform",
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform,
|
||||
),
|
||||
patch("messaging.session.SessionStore", return_value=session_store),
|
||||
|
|
@ -336,7 +336,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
|||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||
patch(
|
||||
"messaging.factory.create_messaging_platform",
|
||||
"messaging.platforms.factory.create_messaging_platform",
|
||||
return_value=fake_platform,
|
||||
),
|
||||
patch("messaging.session.SessionStore", return_value=session_store),
|
||||
|
|
|
|||
|
|
@ -25,18 +25,6 @@ def test_messages_request_map_model_claude_to_default(mock_settings):
|
|||
assert request.original_model == "claude-3-opus"
|
||||
|
||||
|
||||
def test_messages_request_map_model_non_claude_unchanged(mock_settings):
|
||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||
request = MessagesRequest(
|
||||
model="gpt-4",
|
||||
max_tokens=100,
|
||||
messages=[Message(role="user", content="hello")],
|
||||
)
|
||||
|
||||
# normalize_model_name returns original if not Claude
|
||||
assert request.model == "gpt-4"
|
||||
|
||||
|
||||
def test_messages_request_map_model_with_provider_prefix(mock_settings):
|
||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||
request = MessagesRequest(
|
||||
|
|
|
|||
|
|
@ -112,68 +112,42 @@ class TestQuotaCheckRequest:
|
|||
class TestTitleGenerationRequest:
|
||||
"""Tests for is_title_generation_request function."""
|
||||
|
||||
def test_title_generation_simple(self):
|
||||
"""Test title generation detection with target phrase."""
|
||||
msg = MagicMock(spec=Message)
|
||||
msg.role = "user"
|
||||
msg.content = "Please write a 5-10 word title for this conversation"
|
||||
|
||||
req = MagicMock(spec=MessagesRequest)
|
||||
req.messages = [msg]
|
||||
|
||||
assert is_title_generation_request(req) is True
|
||||
|
||||
def test_title_generation_case_insensitive(self):
|
||||
"""Test title generation is case insensitive."""
|
||||
msg = MagicMock(spec=Message)
|
||||
msg.role = "user"
|
||||
msg.content = "Write a 5-10 Word Title please"
|
||||
|
||||
req = MagicMock(spec=MessagesRequest)
|
||||
req.messages = [msg]
|
||||
|
||||
assert is_title_generation_request(req) is True
|
||||
|
||||
def test_title_generation_list_content(self):
|
||||
"""Test title generation with list content blocks."""
|
||||
def _title_gen_system(self) -> list[MagicMock]:
|
||||
block = MagicMock()
|
||||
block.text = "Write a 5-10 word title"
|
||||
|
||||
msg = MagicMock(spec=Message)
|
||||
msg.role = "user"
|
||||
msg.content = [block]
|
||||
block.text = "Analyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title."
|
||||
return [block]
|
||||
|
||||
def test_title_generation_detected_via_system(self):
|
||||
"""Title gen detected by system prompt containing topic/title keywords."""
|
||||
req = MagicMock(spec=MessagesRequest)
|
||||
req.messages = [msg]
|
||||
req.system = self._title_gen_system()
|
||||
req.tools = None
|
||||
|
||||
assert is_title_generation_request(req) is True
|
||||
|
||||
def test_not_title_generation_no_phrase(self):
|
||||
"""Test not title generation without target phrase."""
|
||||
msg = MagicMock(spec=Message)
|
||||
msg.role = "user"
|
||||
msg.content = "Hello world, how are you?"
|
||||
|
||||
def test_title_generation_not_detected_with_tools(self):
|
||||
"""Not detected when tools are present (main conversation, not title gen)."""
|
||||
req = MagicMock(spec=MessagesRequest)
|
||||
req.messages = [msg]
|
||||
req.system = self._title_gen_system()
|
||||
req.tools = [MagicMock()]
|
||||
|
||||
assert is_title_generation_request(req) is False
|
||||
|
||||
def test_not_title_generation_wrong_role(self):
|
||||
"""Test not title generation when last message is not from user."""
|
||||
msg = MagicMock(spec=Message)
|
||||
msg.role = "assistant"
|
||||
msg.content = "Write a 5-10 word title"
|
||||
|
||||
def test_title_generation_not_detected_no_system(self):
|
||||
"""Not detected when system is absent."""
|
||||
req = MagicMock(spec=MessagesRequest)
|
||||
req.messages = [msg]
|
||||
req.system = None
|
||||
req.tools = None
|
||||
|
||||
assert is_title_generation_request(req) is False
|
||||
|
||||
def test_not_title_generation_empty_messages(self):
|
||||
"""Test not title generation when no messages."""
|
||||
def test_title_generation_not_detected_unrelated_system(self):
|
||||
"""Not detected when system prompt has no topic/title keywords."""
|
||||
block = MagicMock()
|
||||
block.text = "You are a helpful assistant."
|
||||
req = MagicMock(spec=MessagesRequest)
|
||||
req.messages = []
|
||||
req.system = [block]
|
||||
req.tools = None
|
||||
|
||||
assert is_title_generation_request(req) is False
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,12 @@ from typing import Any
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from config.nim import NimSettings
|
||||
from messaging.base import CLISession, MessagingPlatform, SessionManagerInterface
|
||||
from messaging.models import IncomingMessage
|
||||
from messaging.platforms.base import (
|
||||
CLISession,
|
||||
MessagingPlatform,
|
||||
SessionManagerInterface,
|
||||
)
|
||||
from messaging.session import SessionStore
|
||||
from providers.base import ProviderConfig
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for messaging/discord_markdown.py."""
|
||||
"""Tests for messaging/rendering/discord_markdown.py."""
|
||||
|
||||
from messaging.discord_markdown import (
|
||||
from messaging.rendering.discord_markdown import (
|
||||
_is_gfm_table_header_line,
|
||||
_normalize_gfm_tables,
|
||||
discord_bold,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from messaging.discord import (
|
||||
from messaging.platforms.discord import (
|
||||
DISCORD_AVAILABLE,
|
||||
DiscordPlatform,
|
||||
_get_discord,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from messaging.telegram_markdown import (
|
||||
from messaging.rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
mdv2_bold,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import pytest
|
|||
|
||||
from messaging.handler import ClaudeMessageHandler
|
||||
from messaging.models import IncomingMessage
|
||||
from messaging.telegram_markdown import render_markdown_to_mdv2
|
||||
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
|
||||
from messaging.trees.data import MessageNode, MessageState
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class TestMessagingBase:
|
|||
|
||||
def test_platform_is_abstract(self):
|
||||
"""Verify MessagingPlatform cannot be instantiated."""
|
||||
from messaging.base import MessagingPlatform
|
||||
from messaging.platforms.base import MessagingPlatform
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MessagingPlatform()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from messaging.factory import create_messaging_platform
|
||||
from messaging.platforms.factory import create_messaging_platform
|
||||
|
||||
|
||||
class TestCreateMessagingPlatform:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import pytest
|
||||
from telegram.error import NetworkError, RetryAfter, TelegramError
|
||||
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telegram_platform():
|
||||
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
|
||||
return platform
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ async def test_telegram_no_retry_on_bad_request(telegram_platform):
|
|||
|
||||
def test_handler_build_message_hardening():
|
||||
# Formatting hardening now lives in TranscriptBuffer rendering.
|
||||
from messaging.telegram_markdown import (
|
||||
from messaging.rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
mdv2_bold,
|
||||
|
|
@ -112,7 +112,7 @@ def test_handler_build_message_hardening():
|
|||
|
||||
def test_render_output_never_exceeds_4096():
|
||||
"""Transcript render with various status lengths never exceeds Telegram 4096 limit."""
|
||||
from messaging.telegram_markdown import (
|
||||
from messaging.rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
mdv2_bold,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from messaging.telegram_markdown import (
|
||||
from messaging.rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
mdv2_bold,
|
||||
|
|
@ -74,7 +74,7 @@ def test_empty_components_with_status(handler):
|
|||
|
||||
def test_render_markdown_unclosed_markdown():
|
||||
"""Malformed markdown (e.g. unclosed *) does not crash and produces acceptable output."""
|
||||
from messaging.telegram_markdown import render_markdown_to_mdv2
|
||||
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
|
||||
|
||||
md = "*bold without close"
|
||||
out = render_markdown_to_mdv2(md)
|
||||
|
|
@ -84,7 +84,7 @@ def test_render_markdown_unclosed_markdown():
|
|||
|
||||
def test_escape_md_v2_unicode_emoji():
|
||||
"""Unicode and emoji pass through correctly (no special char escaping needed)."""
|
||||
from messaging.telegram_markdown import escape_md_v2, escape_md_v2_code
|
||||
from messaging.rendering.telegram_markdown import escape_md_v2, escape_md_v2_code
|
||||
|
||||
text = "Hello 世界 🎉 café"
|
||||
assert escape_md_v2(text) == text
|
||||
|
|
|
|||
|
|
@ -81,11 +81,13 @@ class TestSessionStoreSaveEdgeCases:
|
|||
"""Tests for save failure handling."""
|
||||
|
||||
def test_save_io_error_handled(self, tmp_store):
|
||||
"""Write failure in _save() is logged but doesn't raise."""
|
||||
"""Write failure in _write_data() raises (callers handle the error)."""
|
||||
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
|
||||
with patch("builtins.open", side_effect=OSError("disk full")):
|
||||
tmp_store._save()
|
||||
# Should not raise
|
||||
with (
|
||||
patch("builtins.open", side_effect=OSError("disk full")),
|
||||
pytest.raises(OSError),
|
||||
):
|
||||
tmp_store._write_data(tmp_store._snapshot())
|
||||
|
||||
|
||||
class TestSessionStoreClearAll:
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telegram_platform():
|
||||
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
|
||||
return platform
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ from datetime import timedelta
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from telegram.error import NetworkError, RetryAfter, TelegramError
|
||||
|
||||
|
||||
def test_telegram_platform_init_raises_when_dependency_missing():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", False):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
with pytest.raises(ImportError):
|
||||
TelegramPlatform(bot_token="x")
|
||||
|
|
@ -19,7 +20,7 @@ async def test_telegram_platform_start_requires_token():
|
|||
patch.dict("os.environ", {}, clear=True),
|
||||
patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True),
|
||||
):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token=None)
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -29,7 +30,7 @@ async def test_telegram_platform_start_requires_token():
|
|||
@pytest.mark.asyncio
|
||||
async def test_telegram_platform_stop_no_application_is_noop():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
platform._application = None
|
||||
|
|
@ -41,7 +42,7 @@ async def test_telegram_platform_stop_no_application_is_noop():
|
|||
@pytest.mark.asyncio
|
||||
async def test_with_retry_returns_none_when_message_not_modified_network_error():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import NetworkError, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
|
||||
|
|
@ -54,7 +55,7 @@ async def test_with_retry_returns_none_when_message_not_modified_network_error()
|
|||
@pytest.mark.asyncio
|
||||
async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import NetworkError, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
|
||||
|
|
@ -75,7 +76,7 @@ async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
|
|||
@pytest.mark.asyncio
|
||||
async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import RetryAfter, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
|
||||
|
|
@ -96,7 +97,7 @@ async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
|
|||
@pytest.mark.asyncio
|
||||
async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramError, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
|
||||
|
|
@ -115,7 +116,7 @@ async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
|
|||
@pytest.mark.asyncio
|
||||
async def test_queue_send_message_without_limiter_calls_send_message():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
platform._limiter = None
|
||||
|
|
@ -130,7 +131,7 @@ async def test_queue_send_message_without_limiter_calls_send_message():
|
|||
@pytest.mark.asyncio
|
||||
async def test_queue_edit_message_without_limiter_calls_edit_message():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
platform._limiter = None
|
||||
|
|
@ -143,7 +144,7 @@ async def test_queue_edit_message_without_limiter_calls_edit_message():
|
|||
|
||||
def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
|
||||
|
|
@ -157,7 +158,7 @@ def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
|
|||
@pytest.mark.asyncio
|
||||
async def test_on_start_command_replies_and_forwards():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
with patch.object(
|
||||
|
|
@ -174,7 +175,7 @@ async def test_on_start_command_replies_and_forwards():
|
|||
@pytest.mark.asyncio
|
||||
async def test_on_telegram_message_handler_error_sends_error_message():
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
|
||||
with patch.object(
|
||||
|
|
@ -200,7 +201,7 @@ async def test_on_telegram_message_handler_error_sends_error_message():
|
|||
@pytest.mark.asyncio
|
||||
async def test_telegram_start_retries_on_network_error(monkeypatch):
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import NetworkError, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="token", allowed_user_id=None)
|
||||
|
||||
|
|
@ -225,7 +226,7 @@ async def test_telegram_start_retries_on_network_error(monkeypatch):
|
|||
async def test_edit_message_with_text_exceeding_4096_raises():
|
||||
"""edit_message with text > 4096 raises TelegramError (BadRequest)."""
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramError, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
platform._application = MagicMock()
|
||||
|
|
@ -242,7 +243,7 @@ async def test_edit_message_with_text_exceeding_4096_raises():
|
|||
async def test_edit_message_empty_string():
|
||||
"""edit_message with empty string - Telegram accepts (no-op edit)."""
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
platform._application = MagicMock()
|
||||
|
|
@ -259,7 +260,7 @@ async def test_edit_message_empty_string():
|
|||
async def test_send_message_empty_string():
|
||||
"""send_message with empty string - Telegram may reject; we pass through."""
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
platform._application = MagicMock()
|
||||
|
|
@ -277,7 +278,7 @@ async def test_send_message_empty_string():
|
|||
async def test_on_telegram_message_non_text_update_ignored():
|
||||
"""Update with message.photo but no text returns early without calling handler."""
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
|
||||
handler = AsyncMock()
|
||||
|
|
@ -299,7 +300,7 @@ async def test_on_telegram_message_non_text_update_ignored():
|
|||
async def test_with_retry_message_not_found_returns_none():
|
||||
"""'message to edit not found' returns None without retry."""
|
||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||
from messaging.telegram import TelegramError, TelegramPlatform
|
||||
from messaging.platforms.telegram import TelegramPlatform
|
||||
|
||||
platform = TelegramPlatform(bot_token="t")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
from messaging.telegram_markdown import (
|
||||
from messaging.rendering.telegram_markdown import (
|
||||
escape_md_v2,
|
||||
escape_md_v2_code,
|
||||
mdv2_bold,
|
||||
|
|
|
|||
|
|
@ -1,133 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from providers.model_utils import (
|
||||
get_original_model,
|
||||
is_claude_model,
|
||||
normalize_model_name,
|
||||
strip_provider_prefixes,
|
||||
)
|
||||
|
||||
|
||||
def test_strip_provider_prefixes():
|
||||
assert strip_provider_prefixes("anthropic/claude-3") == "claude-3"
|
||||
assert strip_provider_prefixes("openai/gpt-4") == "gpt-4"
|
||||
assert strip_provider_prefixes("gemini/gemini-pro") == "gemini-pro"
|
||||
assert strip_provider_prefixes("no-prefix") == "no-prefix"
|
||||
|
||||
|
||||
def test_is_claude_model():
|
||||
assert is_claude_model("claude-3-sonnet") is True
|
||||
assert is_claude_model("claude-3-opus") is True
|
||||
assert is_claude_model("claude-3-haiku") is True
|
||||
assert is_claude_model("claude-2.1") is True
|
||||
assert is_claude_model("gpt-4") is False
|
||||
assert is_claude_model("gemini-pro") is False
|
||||
|
||||
|
||||
def test_normalize_model_name_claude_maps_to_default():
|
||||
default = "target-model"
|
||||
# Strips prefix AND maps to default
|
||||
assert normalize_model_name("anthropic/claude-3-sonnet", default) == default
|
||||
assert normalize_model_name("claude-3-opus", default) == default
|
||||
|
||||
|
||||
def test_normalize_model_name_non_claude_unchanged():
|
||||
default = "target-model"
|
||||
assert normalize_model_name("gpt-4", default) == "gpt-4"
|
||||
assert (
|
||||
normalize_model_name("openai/gpt-3.5-turbo", default) == "openai/gpt-3.5-turbo"
|
||||
)
|
||||
|
||||
|
||||
def test_get_original_model():
|
||||
assert get_original_model("any-model") == "any-model"
|
||||
|
||||
|
||||
def test_normalize_model_name_without_default(monkeypatch):
|
||||
monkeypatch.setenv("MODEL", "env-default-model")
|
||||
assert normalize_model_name("claude-3") == "env-default-model"
|
||||
|
||||
|
||||
# --- Parametrized Edge Case Tests ---
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected",
|
||||
[
|
||||
("anthropic/claude-3", "claude-3"),
|
||||
("openai/gpt-4", "gpt-4"),
|
||||
("gemini/gemini-pro", "gemini-pro"),
|
||||
("no-prefix", "no-prefix"),
|
||||
("", ""),
|
||||
("anthropic/", ""),
|
||||
("anthropic/openai/nested", "openai/nested"),
|
||||
],
|
||||
ids=[
|
||||
"anthropic",
|
||||
"openai",
|
||||
"gemini",
|
||||
"no_prefix",
|
||||
"empty_string",
|
||||
"prefix_only",
|
||||
"nested_prefix",
|
||||
],
|
||||
)
|
||||
def test_strip_provider_prefixes_parametrized(model, expected):
|
||||
"""Parametrized prefix stripping with edge cases."""
|
||||
assert strip_provider_prefixes(model) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,expected",
|
||||
[
|
||||
("claude-3-sonnet", True),
|
||||
("claude-3-opus", True),
|
||||
("claude-3-haiku", True),
|
||||
("claude-2.1", True),
|
||||
("gpt-4", False),
|
||||
("gemini-pro", False),
|
||||
("", False),
|
||||
("my-claude-wrapper", True), # "claude" as substring
|
||||
("CLAUDE-3-SONNET", True), # case insensitive
|
||||
("sonnet-v2", True), # "sonnet" identifier without "claude"
|
||||
("haiku-model", True), # "haiku" identifier
|
||||
],
|
||||
ids=[
|
||||
"sonnet",
|
||||
"opus",
|
||||
"haiku",
|
||||
"claude2",
|
||||
"gpt4",
|
||||
"gemini",
|
||||
"empty",
|
||||
"claude_substring",
|
||||
"uppercase",
|
||||
"sonnet_standalone",
|
||||
"haiku_standalone",
|
||||
],
|
||||
)
|
||||
def test_is_claude_model_parametrized(model, expected):
|
||||
"""Parametrized Claude model detection with edge cases."""
|
||||
assert is_claude_model(model) is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,default,expected",
|
||||
[
|
||||
("claude-3-sonnet", "target", "target"),
|
||||
("anthropic/claude-3-opus", "target", "target"),
|
||||
("gpt-4", "target", "gpt-4"),
|
||||
("openai/gpt-3.5-turbo", "target", "openai/gpt-3.5-turbo"),
|
||||
("", "target", ""), # empty string is not a claude model
|
||||
],
|
||||
ids=[
|
||||
"claude_mapped",
|
||||
"prefixed_claude",
|
||||
"non_claude",
|
||||
"prefixed_non_claude",
|
||||
"empty",
|
||||
],
|
||||
)
|
||||
def test_normalize_model_name_parametrized(model, default, expected):
|
||||
"""Parametrized model normalization."""
|
||||
assert normalize_model_name(model, default) == expected
|
||||
|
|
@ -3,9 +3,9 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from config.nim import NimSettings
|
||||
from providers.common.utils import set_if_not_none
|
||||
from providers.nvidia_nim.request import (
|
||||
_set_extra,
|
||||
_set_if_not_none,
|
||||
build_request_body,
|
||||
)
|
||||
|
||||
|
|
@ -13,12 +13,12 @@ from providers.nvidia_nim.request import (
|
|||
class TestSetIfNotNone:
|
||||
def test_value_not_none_sets(self):
|
||||
body = {}
|
||||
_set_if_not_none(body, "key", "value")
|
||||
set_if_not_none(body, "key", "value")
|
||||
assert body["key"] == "value"
|
||||
|
||||
def test_value_none_skips(self):
|
||||
body = {}
|
||||
_set_if_not_none(body, "key", None)
|
||||
set_if_not_none(body, "key", None)
|
||||
assert "key" not in body
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue