mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Backup/before cleanup 20260222 230402 (#58)
This commit is contained in:
parent
e2840095ce
commit
c4d8681000
43 changed files with 253 additions and 584 deletions
|
|
@ -29,7 +29,7 @@ HTTP_CONNECT_TIMEOUT=2
|
||||||
|
|
||||||
|
|
||||||
# Messaging Platform: "telegram" | "discord"
|
# Messaging Platform: "telegram" | "discord"
|
||||||
MESSAGING_PLATFORM=discord
|
MESSAGING_PLATFORM="discord"
|
||||||
MESSAGING_RATE_LIMIT=1
|
MESSAGING_RATE_LIMIT=1
|
||||||
MESSAGING_RATE_WINDOW=1
|
MESSAGING_RATE_WINDOW=1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
## CODING ENVIRONMENT
|
## CODING ENVIRONMENT
|
||||||
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
||||||
- Always use `uv run` to run files instead of the global `python` command.
|
- Always use `uv run` to run files instead of the global `python` command.
|
||||||
|
- Current uv ruff linter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:)
|
||||||
- Read `.env.example` for environment variables.
|
- Read `.env.example` for environment variables.
|
||||||
- All CI checks must pass; failing checks block merge.
|
- All CI checks must pass; failing checks block merge.
|
||||||
- Add tests for new changes (including edge cases), then run `uv run pytest`.
|
- Add tests for new changes (including edge cases), then run `uv run pytest`.
|
||||||
|
|
@ -38,7 +39,7 @@
|
||||||
|
|
||||||
## SUMMARY STANDARDS
|
## SUMMARY STANDARDS
|
||||||
- Summaries must be technical and granular.
|
- Summaries must be technical and granular.
|
||||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks].
|
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
|
||||||
|
|
||||||
## TOOLS
|
## TOOLS
|
||||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
## CODING ENVIRONMENT
|
## CODING ENVIRONMENT
|
||||||
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
||||||
- Always use `uv run` to run files instead of the global `python` command.
|
- Always use `uv run` to run files instead of the global `python` command.
|
||||||
|
- Current uv ruff linter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:)
|
||||||
- Read `.env.example` for environment variables.
|
- Read `.env.example` for environment variables.
|
||||||
- All CI checks must pass; failing checks block merge.
|
- All CI checks must pass; failing checks block merge.
|
||||||
- Add tests for new changes (including edge cases), then run `uv run pytest`.
|
- Add tests for new changes (including edge cases), then run `uv run pytest`.
|
||||||
|
|
@ -38,7 +39,7 @@
|
||||||
|
|
||||||
## SUMMARY STANDARDS
|
## SUMMARY STANDARDS
|
||||||
- Summaries must be technical and granular.
|
- Summaries must be technical and granular.
|
||||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks].
|
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
|
||||||
|
|
||||||
## TOOLS
|
## TOOLS
|
||||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||||
17
README.md
17
README.md
|
|
@ -289,11 +289,11 @@ uv sync --extra voice
|
||||||
Full list in [`nvidia_nim_models.json`](nvidia_nim_models.json).
|
Full list in [`nvidia_nim_models.json`](nvidia_nim_models.json).
|
||||||
|
|
||||||
Popular models:
|
Popular models:
|
||||||
- `qwen/qwen3.5-397b-a17b`
|
- `nvidia_nim/minimaxai/minimax-m2.5`
|
||||||
- `z-ai/glm5`
|
- `nvidia_nim/qwen/qwen3.5-397b-a17b`
|
||||||
- `stepfun-ai/step-3.5-flash`
|
- `nvidia_nim/z-ai/glm5`
|
||||||
- `moonshotai/kimi-k2.5`
|
- `nvidia_nim/stepfun-ai/step-3.5-flash`
|
||||||
- `minimaxai/minimax-m2.1`
|
- `nvidia_nim/moonshotai/kimi-k2.5`
|
||||||
|
|
||||||
Browse: [build.nvidia.com](https://build.nvidia.com/explore/discover)
|
Browse: [build.nvidia.com](https://build.nvidia.com/explore/discover)
|
||||||
|
|
||||||
|
|
@ -310,9 +310,9 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
|
||||||
Hundreds of models from StepFun, OpenAI, Anthropic, Google, and more.
|
Hundreds of models from StepFun, OpenAI, Anthropic, Google, and more.
|
||||||
|
|
||||||
Popular models:
|
Popular models:
|
||||||
- `stepfun/step-3.5-flash:free`
|
- `open_router/stepfun/step-3.5-flash:free`
|
||||||
- `deepseek/deepseek-r1-0528:free`
|
- `open_router/deepseek/deepseek-r1-0528:free`
|
||||||
- `openai/gpt-oss-120b:free`
|
- `open_router/openai/gpt-oss-120b:free`
|
||||||
|
|
||||||
Browse: [openrouter.ai/models](https://openrouter.ai/models)
|
Browse: [openrouter.ai/models](https://openrouter.ai/models)
|
||||||
|
|
||||||
|
|
@ -385,7 +385,6 @@ free-claude-code/
|
||||||
├── messaging/ # MessagingPlatform ABC + Discord/Telegram bots, session management
|
├── messaging/ # MessagingPlatform ABC + Discord/Telegram bots, session management
|
||||||
├── config/ # Settings, NIM config, logging
|
├── config/ # Settings, NIM config, logging
|
||||||
├── cli/ # CLI session and process management
|
├── cli/ # CLI session and process management
|
||||||
├── utils/ # Text utilities
|
|
||||||
└── tests/ # Pytest test suite
|
└── tests/ # Pytest test suite
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use the messaging factory to create the right platform
|
# Use the messaging factory to create the right platform
|
||||||
from messaging.factory import create_messaging_platform
|
from messaging.platforms.factory import create_messaging_platform
|
||||||
|
|
||||||
messaging_platform = create_messaging_platform(
|
messaging_platform = create_messaging_platform(
|
||||||
platform_type=settings.messaging_platform,
|
platform_type=settings.messaging_platform,
|
||||||
|
|
|
||||||
|
|
@ -16,88 +16,84 @@ def get_settings() -> Settings:
|
||||||
return _get_settings()
|
return _get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_provider(settings: Settings) -> BaseProvider:
|
||||||
|
"""Construct and return a new provider instance from settings."""
|
||||||
|
if settings.provider_type == "nvidia_nim":
|
||||||
|
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail=(
|
||||||
|
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
||||||
|
"Get a key at https://build.nvidia.com/settings/api-keys"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
from providers.nvidia_nim import NvidiaNimProvider
|
||||||
|
|
||||||
|
config = ProviderConfig(
|
||||||
|
api_key=settings.nvidia_nim_api_key,
|
||||||
|
base_url=NVIDIA_NIM_BASE_URL,
|
||||||
|
rate_limit=settings.provider_rate_limit,
|
||||||
|
rate_window=settings.provider_rate_window,
|
||||||
|
max_concurrency=settings.provider_max_concurrency,
|
||||||
|
http_read_timeout=settings.http_read_timeout,
|
||||||
|
http_write_timeout=settings.http_write_timeout,
|
||||||
|
http_connect_timeout=settings.http_connect_timeout,
|
||||||
|
)
|
||||||
|
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||||
|
elif settings.provider_type == "open_router":
|
||||||
|
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail=(
|
||||||
|
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
||||||
|
"Get a key at https://openrouter.ai/keys"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
from providers.open_router import OpenRouterProvider
|
||||||
|
|
||||||
|
config = ProviderConfig(
|
||||||
|
api_key=settings.open_router_api_key,
|
||||||
|
base_url="https://openrouter.ai/api/v1",
|
||||||
|
rate_limit=settings.provider_rate_limit,
|
||||||
|
rate_window=settings.provider_rate_window,
|
||||||
|
max_concurrency=settings.provider_max_concurrency,
|
||||||
|
http_read_timeout=settings.http_read_timeout,
|
||||||
|
http_write_timeout=settings.http_write_timeout,
|
||||||
|
http_connect_timeout=settings.http_connect_timeout,
|
||||||
|
)
|
||||||
|
provider = OpenRouterProvider(config)
|
||||||
|
elif settings.provider_type == "lmstudio":
|
||||||
|
from providers.lmstudio import LMStudioProvider
|
||||||
|
|
||||||
|
config = ProviderConfig(
|
||||||
|
api_key="lm-studio",
|
||||||
|
base_url=settings.lm_studio_base_url,
|
||||||
|
rate_limit=settings.provider_rate_limit,
|
||||||
|
rate_window=settings.provider_rate_window,
|
||||||
|
max_concurrency=settings.provider_max_concurrency,
|
||||||
|
http_read_timeout=settings.http_read_timeout,
|
||||||
|
http_write_timeout=settings.http_write_timeout,
|
||||||
|
http_connect_timeout=settings.http_connect_timeout,
|
||||||
|
)
|
||||||
|
provider = LMStudioProvider(config)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||||
|
settings.provider_type,
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown provider_type: '{settings.provider_type}'. "
|
||||||
|
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||||
|
)
|
||||||
|
logger.info("Provider initialized: %s", settings.provider_type)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
def get_provider() -> BaseProvider:
|
def get_provider() -> BaseProvider:
|
||||||
"""Get or create the provider instance based on settings.provider_type."""
|
"""Get or create the provider instance based on settings.provider_type."""
|
||||||
global _provider
|
global _provider
|
||||||
if _provider is None:
|
if _provider is None:
|
||||||
settings = get_settings()
|
_provider = _create_provider(get_settings())
|
||||||
|
|
||||||
if settings.provider_type == "nvidia_nim":
|
|
||||||
if (
|
|
||||||
not settings.nvidia_nim_api_key
|
|
||||||
or not settings.nvidia_nim_api_key.strip()
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail=(
|
|
||||||
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
|
||||||
"Get a key at https://build.nvidia.com/settings/api-keys"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
from providers.nvidia_nim import NvidiaNimProvider
|
|
||||||
|
|
||||||
config = ProviderConfig(
|
|
||||||
api_key=settings.nvidia_nim_api_key,
|
|
||||||
base_url=NVIDIA_NIM_BASE_URL,
|
|
||||||
rate_limit=settings.provider_rate_limit,
|
|
||||||
rate_window=settings.provider_rate_window,
|
|
||||||
max_concurrency=settings.provider_max_concurrency,
|
|
||||||
http_read_timeout=settings.http_read_timeout,
|
|
||||||
http_write_timeout=settings.http_write_timeout,
|
|
||||||
http_connect_timeout=settings.http_connect_timeout,
|
|
||||||
)
|
|
||||||
_provider = NvidiaNimProvider(config, nim_settings=settings.nim)
|
|
||||||
logger.info("Provider initialized: %s", settings.provider_type)
|
|
||||||
elif settings.provider_type == "open_router":
|
|
||||||
if (
|
|
||||||
not settings.open_router_api_key
|
|
||||||
or not settings.open_router_api_key.strip()
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail=(
|
|
||||||
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
|
||||||
"Get a key at https://openrouter.ai/keys"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
from providers.open_router import OpenRouterProvider
|
|
||||||
|
|
||||||
config = ProviderConfig(
|
|
||||||
api_key=settings.open_router_api_key,
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
rate_limit=settings.provider_rate_limit,
|
|
||||||
rate_window=settings.provider_rate_window,
|
|
||||||
max_concurrency=settings.provider_max_concurrency,
|
|
||||||
http_read_timeout=settings.http_read_timeout,
|
|
||||||
http_write_timeout=settings.http_write_timeout,
|
|
||||||
http_connect_timeout=settings.http_connect_timeout,
|
|
||||||
)
|
|
||||||
_provider = OpenRouterProvider(config)
|
|
||||||
logger.info("Provider initialized: %s", settings.provider_type)
|
|
||||||
elif settings.provider_type == "lmstudio":
|
|
||||||
from providers.lmstudio import LMStudioProvider
|
|
||||||
|
|
||||||
config = ProviderConfig(
|
|
||||||
api_key="lm-studio",
|
|
||||||
base_url=settings.lm_studio_base_url,
|
|
||||||
rate_limit=settings.provider_rate_limit,
|
|
||||||
rate_window=settings.provider_rate_window,
|
|
||||||
max_concurrency=settings.provider_max_concurrency,
|
|
||||||
http_read_timeout=settings.http_read_timeout,
|
|
||||||
http_write_timeout=settings.http_write_timeout,
|
|
||||||
http_connect_timeout=settings.http_connect_timeout,
|
|
||||||
)
|
|
||||||
_provider = LMStudioProvider(config)
|
|
||||||
logger.info("Provider initialized: %s", settings.provider_type)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
|
||||||
settings.provider_type,
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown provider_type: '{settings.provider_type}'. "
|
|
||||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
|
||||||
)
|
|
||||||
return _provider
|
return _provider
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,5 +104,10 @@ async def cleanup_provider():
|
||||||
client = getattr(_provider, "_client", None)
|
client = getattr(_provider, "_client", None)
|
||||||
if client and hasattr(client, "aclose"):
|
if client and hasattr(client, "aclose"):
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
elif client:
|
||||||
|
logger.warning(
|
||||||
|
"Provider client %r has no aclose(); skipping async cleanup",
|
||||||
|
type(client).__name__,
|
||||||
|
)
|
||||||
_provider = None
|
_provider = None
|
||||||
logger.debug("Provider cleanup completed")
|
logger.debug("Provider cleanup completed")
|
||||||
|
|
|
||||||
|
|
@ -29,14 +29,13 @@ def is_quota_check_request(request_data: MessagesRequest) -> bool:
|
||||||
def is_title_generation_request(request_data: MessagesRequest) -> bool:
|
def is_title_generation_request(request_data: MessagesRequest) -> bool:
|
||||||
"""Check if this is a conversation title generation request.
|
"""Check if this is a conversation title generation request.
|
||||||
|
|
||||||
Title generation requests typically contain the phrase
|
Title generation requests are detected by a system prompt containing
|
||||||
"write a 5-10 word title" in the user's message.
|
title extraction instructions, no tools, and a single user message.
|
||||||
"""
|
"""
|
||||||
if len(request_data.messages) > 0 and request_data.messages[-1].role == "user":
|
if not request_data.system or request_data.tools:
|
||||||
text = extract_text_from_content(request_data.messages[-1].content)
|
return False
|
||||||
if "write a 5-10 word title" in text.lower():
|
system_text = extract_text_from_content(request_data.system).lower()
|
||||||
return True
|
return "new conversation topic" in system_text and "title" in system_text
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
|
def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from loguru import logger
|
||||||
from pydantic import BaseModel, field_validator, model_validator
|
from pydantic import BaseModel, field_validator, model_validator
|
||||||
|
|
||||||
from config.settings import get_settings
|
from config.settings import get_settings
|
||||||
from providers.model_utils import normalize_model_name
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Content Block Types
|
# Content Block Types
|
||||||
|
|
@ -112,10 +111,7 @@ class MessagesRequest(BaseModel):
|
||||||
if self.original_model is None:
|
if self.original_model is None:
|
||||||
self.original_model = self.model
|
self.original_model = self.model
|
||||||
|
|
||||||
# Use centralized model normalization
|
self.model = settings.model_name
|
||||||
normalized = normalize_model_name(self.model, settings.model_name)
|
|
||||||
if normalized != self.model:
|
|
||||||
self.model = normalized
|
|
||||||
|
|
||||||
if self.model != self.original_model:
|
if self.model != self.original_model:
|
||||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||||
|
|
@ -136,5 +132,4 @@ class TokenCountRequest(BaseModel):
|
||||||
def validate_model_field(cls, v, info):
|
def validate_model_field(cls, v, info):
|
||||||
"""Map any Claude model name to the configured model."""
|
"""Map any Claude model name to the configured model."""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
# Use centralized model normalization
|
return settings.model_name
|
||||||
return normalize_model_name(v, settings.model_name)
|
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
|
||||||
format=_serialize_with_context,
|
format=_serialize_with_context,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
mode="a",
|
mode="a",
|
||||||
|
rotation="50 MB",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Intercept stdlib logging: route all root logger output to loguru
|
# Intercept stdlib logging: route all root logger output to loguru
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
"""Platform-agnostic messaging layer."""
|
"""Platform-agnostic messaging layer."""
|
||||||
|
|
||||||
from .base import CLISession, MessagingPlatform, SessionManagerInterface
|
|
||||||
from .event_parser import parse_cli_event
|
from .event_parser import parse_cli_event
|
||||||
from .handler import ClaudeMessageHandler
|
from .handler import ClaudeMessageHandler
|
||||||
from .models import IncomingMessage
|
from .models import IncomingMessage
|
||||||
|
from .platforms.base import CLISession, MessagingPlatform, SessionManagerInterface
|
||||||
from .session import SessionStore
|
from .session import SessionStore
|
||||||
from .trees.data import MessageNode, MessageState, MessageTree
|
from .trees.data import MessageNode, MessageState, MessageTree
|
||||||
from .trees.queue_manager import TreeQueueManager
|
from .trees.queue_manager import TreeQueueManager
|
||||||
|
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
"""Backward-compatible re-export. Use messaging.platforms.base for new code."""
|
|
||||||
|
|
||||||
from .platforms.base import (
|
|
||||||
CLISession,
|
|
||||||
MessagingPlatform,
|
|
||||||
SessionManagerInterface,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["CLISession", "MessagingPlatform", "SessionManagerInterface"]
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
"""Backward-compatible re-export. Use messaging.platforms.discord for new code."""
|
|
||||||
|
|
||||||
from .platforms.discord import (
|
|
||||||
DISCORD_AVAILABLE,
|
|
||||||
DISCORD_MESSAGE_LIMIT,
|
|
||||||
DiscordPlatform,
|
|
||||||
_get_discord,
|
|
||||||
_parse_allowed_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"DISCORD_AVAILABLE",
|
|
||||||
"DISCORD_MESSAGE_LIMIT",
|
|
||||||
"DiscordPlatform",
|
|
||||||
"_get_discord",
|
|
||||||
"_parse_allowed_channels",
|
|
||||||
]
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
"""Backward-compatible re-export. Use messaging.rendering.discord_markdown for new code."""
|
|
||||||
|
|
||||||
from .rendering.discord_markdown import (
|
|
||||||
_is_gfm_table_header_line,
|
|
||||||
_normalize_gfm_tables,
|
|
||||||
discord_bold,
|
|
||||||
discord_code_inline,
|
|
||||||
escape_discord,
|
|
||||||
escape_discord_code,
|
|
||||||
format_status,
|
|
||||||
format_status_discord,
|
|
||||||
render_markdown_to_discord,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"_is_gfm_table_header_line",
|
|
||||||
"_normalize_gfm_tables",
|
|
||||||
"discord_bold",
|
|
||||||
"discord_code_inline",
|
|
||||||
"escape_discord",
|
|
||||||
"escape_discord_code",
|
|
||||||
"format_status",
|
|
||||||
"format_status_discord",
|
|
||||||
"render_markdown_to_discord",
|
|
||||||
]
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
"""Backward-compatible re-export. Use messaging.platforms.factory for new code."""
|
|
||||||
|
|
||||||
from .platforms.factory import create_messaging_platform
|
|
||||||
|
|
||||||
__all__ = ["create_messaging_platform"]
|
|
||||||
|
|
@ -95,18 +95,18 @@ class SessionStore:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load sessions: {e}")
|
logger.error(f"Failed to load sessions: {e}")
|
||||||
|
|
||||||
def _save(self) -> None:
|
def _snapshot(self) -> dict:
|
||||||
"""Persist sessions and trees to disk. Caller must hold self._lock."""
|
"""Snapshot current state for serialization. Caller must hold self._lock."""
|
||||||
try:
|
return {
|
||||||
data = {
|
"trees": dict(self._trees),
|
||||||
"trees": self._trees,
|
"node_to_tree": dict(self._node_to_tree),
|
||||||
"node_to_tree": self._node_to_tree,
|
"message_log": {k: list(v) for k, v in self._message_log.items()},
|
||||||
"message_log": self._message_log,
|
}
|
||||||
}
|
|
||||||
with open(self.storage_path, "w", encoding="utf-8") as f:
|
def _write_data(self, data: dict) -> None:
|
||||||
json.dump(data, f, indent=2)
|
"""Write data dict to disk. Must be called WITHOUT holding self._lock."""
|
||||||
except Exception as e:
|
with open(self.storage_path, "w", encoding="utf-8") as f:
|
||||||
logger.error(f"Failed to save sessions: {e}")
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
def _schedule_save(self) -> None:
|
def _schedule_save(self) -> None:
|
||||||
"""Schedule a debounced save. Caller must hold self._lock."""
|
"""Schedule a debounced save. Caller must hold self._lock."""
|
||||||
|
|
@ -126,22 +126,35 @@ class SessionStore:
|
||||||
if not self._dirty:
|
if not self._dirty:
|
||||||
self._save_timer = None
|
self._save_timer = None
|
||||||
return
|
return
|
||||||
self._save()
|
snapshot = self._snapshot()
|
||||||
self._dirty = False
|
self._dirty = False
|
||||||
self._save_timer = None
|
self._save_timer = None
|
||||||
|
try:
|
||||||
|
self._write_data(snapshot)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save sessions: {e}")
|
||||||
|
with self._lock:
|
||||||
|
self._dirty = True
|
||||||
|
|
||||||
def _flush_save(self) -> None:
|
def _flush_save(self) -> dict:
|
||||||
"""Immediate save, cancel any pending debounced save. Caller must hold self._lock."""
|
"""Cancel pending timer and snapshot current state. Caller must hold self._lock.
|
||||||
|
Returns snapshot dict; caller must call _write_data(snapshot) outside the lock."""
|
||||||
if self._save_timer is not None:
|
if self._save_timer is not None:
|
||||||
self._save_timer.cancel()
|
self._save_timer.cancel()
|
||||||
self._save_timer = None
|
self._save_timer = None
|
||||||
self._dirty = False
|
self._dirty = False
|
||||||
self._save()
|
return self._snapshot()
|
||||||
|
|
||||||
def flush_pending_save(self) -> None:
|
def flush_pending_save(self) -> None:
|
||||||
"""Flush any pending debounced save. Call on shutdown to avoid losing data."""
|
"""Flush any pending debounced save. Call on shutdown to avoid losing data."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._flush_save()
|
snapshot = self._flush_save()
|
||||||
|
try:
|
||||||
|
self._write_data(snapshot)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save sessions: {e}")
|
||||||
|
with self._lock:
|
||||||
|
self._dirty = True
|
||||||
|
|
||||||
def record_message_id(
|
def record_message_id(
|
||||||
self,
|
self,
|
||||||
|
|
@ -201,7 +214,13 @@ class SessionStore:
|
||||||
self._node_to_tree.clear()
|
self._node_to_tree.clear()
|
||||||
self._message_log.clear()
|
self._message_log.clear()
|
||||||
self._message_log_ids.clear()
|
self._message_log_ids.clear()
|
||||||
self._flush_save()
|
snapshot = self._flush_save()
|
||||||
|
try:
|
||||||
|
self._write_data(snapshot)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save sessions: {e}")
|
||||||
|
with self._lock:
|
||||||
|
self._dirty = True
|
||||||
|
|
||||||
# ==================== Tree Methods ====================
|
# ==================== Tree Methods ====================
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
"""Backward-compatible re-export. Use messaging.platforms.telegram for new code."""
|
|
||||||
|
|
||||||
from .platforms.telegram import (
|
|
||||||
TELEGRAM_AVAILABLE,
|
|
||||||
TelegramPlatform,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Re-export telegram.error types when python-telegram-bot is installed
|
|
||||||
__all__ = ["TELEGRAM_AVAILABLE", "TelegramPlatform"]
|
|
||||||
try:
|
|
||||||
from telegram.error import NetworkError, RetryAfter, TelegramError
|
|
||||||
|
|
||||||
__all__ += ["NetworkError", "RetryAfter", "TelegramError"]
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
"""Backward-compatible re-export. Use messaging.rendering.telegram_markdown for new code."""
|
|
||||||
|
|
||||||
from .rendering.telegram_markdown import (
|
|
||||||
escape_md_v2,
|
|
||||||
escape_md_v2_code,
|
|
||||||
escape_md_v2_link_url,
|
|
||||||
format_status,
|
|
||||||
mdv2_bold,
|
|
||||||
mdv2_code_inline,
|
|
||||||
render_markdown_to_mdv2,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"escape_md_v2",
|
|
||||||
"escape_md_v2_code",
|
|
||||||
"escape_md_v2_link_url",
|
|
||||||
"format_status",
|
|
||||||
"mdv2_bold",
|
|
||||||
"mdv2_code_inline",
|
|
||||||
"render_markdown_to_mdv2",
|
|
||||||
]
|
|
||||||
|
|
@ -105,17 +105,19 @@ class TreeRepository:
|
||||||
if not tree:
|
if not tree:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
pending = []
|
pending: list[MessageNode] = []
|
||||||
node = tree.get_node(node_id)
|
stack = [node_id]
|
||||||
if not node:
|
|
||||||
return []
|
|
||||||
|
|
||||||
for child_id in node.children_ids:
|
while stack:
|
||||||
child = tree.get_node(child_id)
|
current_id = stack.pop()
|
||||||
if child and child.state == MessageState.PENDING:
|
node = tree.get_node(current_id)
|
||||||
pending.append(child)
|
if not node:
|
||||||
# Recursively get children of pending children
|
continue
|
||||||
pending.extend(self.get_pending_children(child_id))
|
for child_id in node.children_ids:
|
||||||
|
child = tree.get_node(child_id)
|
||||||
|
if child and child.state == MessageState.PENDING:
|
||||||
|
pending.append(child)
|
||||||
|
stack.append(child_id)
|
||||||
|
|
||||||
return pending
|
return pending
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -476,13 +476,13 @@
|
||||||
"owned_by": "microsoft"
|
"owned_by": "microsoft"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "minimaxai/minimax-m2",
|
"id": "minimaxai/minimax-m2.1",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": 735790403,
|
"created": 735790403,
|
||||||
"owned_by": "minimaxai"
|
"owned_by": "minimaxai"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "minimaxai/minimax-m2.1",
|
"id": "minimaxai/minimax-m2.5",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"created": 735790403,
|
"created": 735790403,
|
||||||
"owned_by": "minimaxai"
|
"owned_by": "minimaxai"
|
||||||
|
|
@ -709,12 +709,6 @@
|
||||||
"created": 735790403,
|
"created": 735790403,
|
||||||
"owned_by": "nvidia"
|
"owned_by": "nvidia"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": "nvidia/llama-3.2-nemoretriever-300m-embed-v2",
|
|
||||||
"object": "model",
|
|
||||||
"created": 735790403,
|
|
||||||
"owned_by": "nvidia"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
"id": "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from .message_converter import (
|
||||||
)
|
)
|
||||||
from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason
|
from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason
|
||||||
from .think_parser import ContentChunk, ContentType, ThinkTagParser
|
from .think_parser import ContentChunk, ContentType, ThinkTagParser
|
||||||
|
from .utils import set_if_not_none
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnthropicToOpenAIConverter",
|
"AnthropicToOpenAIConverter",
|
||||||
|
|
@ -22,4 +23,5 @@ __all__ = [
|
||||||
"get_block_type",
|
"get_block_type",
|
||||||
"map_error",
|
"map_error",
|
||||||
"map_stop_reason",
|
"map_stop_reason",
|
||||||
|
"set_if_not_none",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
9
providers/common/utils.py
Normal file
9
providers/common/utils.py
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""Shared utility helpers for provider request builders."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
||||||
|
"""Set body[key] = value only when value is not None."""
|
||||||
|
if value is not None:
|
||||||
|
body[key] = value
|
||||||
|
|
@ -5,15 +5,11 @@ from typing import Any
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from providers.common.message_converter import AnthropicToOpenAIConverter
|
from providers.common.message_converter import AnthropicToOpenAIConverter
|
||||||
|
from providers.common.utils import set_if_not_none
|
||||||
|
|
||||||
LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
|
LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
|
||||||
|
|
||||||
|
|
||||||
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
|
||||||
if value is not None:
|
|
||||||
body[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def build_request_body(request_data: Any) -> dict:
|
def build_request_body(request_data: Any) -> dict:
|
||||||
"""Build OpenAI-format request body from Anthropic request for LM Studio."""
|
"""Build OpenAI-format request body from Anthropic request for LM Studio."""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
max_tokens = getattr(request_data, "max_tokens", None)
|
max_tokens = getattr(request_data, "max_tokens", None)
|
||||||
_set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
|
set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
|
||||||
|
|
||||||
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
||||||
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
||||||
|
|
||||||
stop_sequences = getattr(request_data, "stop_sequences", None)
|
stop_sequences = getattr(request_data, "stop_sequences", None)
|
||||||
if stop_sequences:
|
if stop_sequences:
|
||||||
|
|
|
||||||
|
|
@ -1,86 +0,0 @@
|
||||||
"""Model name normalization utilities.
|
|
||||||
|
|
||||||
Centralizes model name mapping logic to avoid duplication across the codebase.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Provider prefixes to strip from model names
|
|
||||||
_PROVIDER_PREFIXES = ["anthropic/", "openai/", "gemini/"]
|
|
||||||
|
|
||||||
# Claude model identifiers
|
|
||||||
_CLAUDE_IDENTIFIERS = ["haiku", "sonnet", "opus", "claude"]
|
|
||||||
|
|
||||||
|
|
||||||
def strip_provider_prefixes(model: str) -> str:
|
|
||||||
"""
|
|
||||||
Strip provider prefixes from model name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model name, possibly with prefix
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model name without provider prefix
|
|
||||||
"""
|
|
||||||
for prefix in _PROVIDER_PREFIXES:
|
|
||||||
if model.startswith(prefix):
|
|
||||||
return model[len(prefix) :]
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def is_claude_model(model: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a model name identifies as a Claude model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The (prefix-stripped) model name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if this is a Claude model
|
|
||||||
"""
|
|
||||||
model_lower = model.lower()
|
|
||||||
return any(name in model_lower for name in _CLAUDE_IDENTIFIERS)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_model_name(model: str, default_model: str | None = None) -> str:
|
|
||||||
"""
|
|
||||||
Normalize a model name by stripping prefixes and mapping to default if needed.
|
|
||||||
|
|
||||||
This is the central function for model name normalization across the API.
|
|
||||||
It strips provider prefixes and maps Claude model names to the configured model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model name (may include provider prefix)
|
|
||||||
default_model: The default model to use for Claude models.
|
|
||||||
If None, uses settings.model from config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized model name (original if not a Claude model, mapped if Claude)
|
|
||||||
"""
|
|
||||||
# Strip provider prefixes
|
|
||||||
clean = strip_provider_prefixes(model)
|
|
||||||
|
|
||||||
# Map Claude models to default
|
|
||||||
if is_claude_model(clean):
|
|
||||||
if default_model is None:
|
|
||||||
# Use environment/config default
|
|
||||||
default_model = os.getenv("MODEL", "moonshotai/kimi-k2-thinking")
|
|
||||||
return default_model
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_original_model(model: str) -> str:
|
|
||||||
"""
|
|
||||||
Get the original model name, storing it before normalization.
|
|
||||||
|
|
||||||
Convenience function that returns the input unchanged, intended to be
|
|
||||||
called alongside normalize_model_name to capture the original.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The model name unchanged (for documentation purposes)
|
|
||||||
"""
|
|
||||||
return model
|
|
||||||
|
|
@ -6,11 +6,7 @@ from loguru import logger
|
||||||
|
|
||||||
from config.nim import NimSettings
|
from config.nim import NimSettings
|
||||||
from providers.common.message_converter import AnthropicToOpenAIConverter
|
from providers.common.message_converter import AnthropicToOpenAIConverter
|
||||||
|
from providers.common.utils import set_if_not_none
|
||||||
|
|
||||||
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
|
||||||
if value is not None:
|
|
||||||
body[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def _set_extra(
|
def _set_extra(
|
||||||
|
|
@ -52,15 +48,15 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
|
||||||
max_tokens = nim.max_tokens
|
max_tokens = nim.max_tokens
|
||||||
elif nim.max_tokens:
|
elif nim.max_tokens:
|
||||||
max_tokens = min(max_tokens, nim.max_tokens)
|
max_tokens = min(max_tokens, nim.max_tokens)
|
||||||
_set_if_not_none(body, "max_tokens", max_tokens)
|
set_if_not_none(body, "max_tokens", max_tokens)
|
||||||
|
|
||||||
req_temperature = getattr(request_data, "temperature", None)
|
req_temperature = getattr(request_data, "temperature", None)
|
||||||
temperature = req_temperature if req_temperature is not None else nim.temperature
|
temperature = req_temperature if req_temperature is not None else nim.temperature
|
||||||
_set_if_not_none(body, "temperature", temperature)
|
set_if_not_none(body, "temperature", temperature)
|
||||||
|
|
||||||
req_top_p = getattr(request_data, "top_p", None)
|
req_top_p = getattr(request_data, "top_p", None)
|
||||||
top_p = req_top_p if req_top_p is not None else nim.top_p
|
top_p = req_top_p if req_top_p is not None else nim.top_p
|
||||||
_set_if_not_none(body, "top_p", top_p)
|
set_if_not_none(body, "top_p", top_p)
|
||||||
|
|
||||||
stop_sequences = getattr(request_data, "stop_sequences", None)
|
stop_sequences = getattr(request_data, "stop_sequences", None)
|
||||||
if stop_sequences:
|
if stop_sequences:
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,11 @@ from typing import Any
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from providers.common.message_converter import AnthropicToOpenAIConverter
|
from providers.common.message_converter import AnthropicToOpenAIConverter
|
||||||
|
from providers.common.utils import set_if_not_none
|
||||||
|
|
||||||
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
|
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
|
||||||
|
|
||||||
|
|
||||||
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
|
|
||||||
if value is not None:
|
|
||||||
body[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def build_request_body(request_data: Any) -> dict:
|
def build_request_body(request_data: Any) -> dict:
|
||||||
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
|
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
max_tokens = getattr(request_data, "max_tokens", None)
|
max_tokens = getattr(request_data, "max_tokens", None)
|
||||||
_set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
|
set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
|
||||||
|
|
||||||
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
|
||||||
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
|
||||||
|
|
||||||
stop_sequences = getattr(request_data, "stop_sequences", None)
|
stop_sequences = getattr(request_data, "stop_sequences", None)
|
||||||
if stop_sequences:
|
if stop_sequences:
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||||
patch(
|
patch(
|
||||||
"messaging.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform if messaging_enabled else None,
|
return_value=fake_platform if messaging_enabled else None,
|
||||||
) as create_platform,
|
) as create_platform,
|
||||||
patch("messaging.session.SessionStore", return_value=session_store),
|
patch("messaging.session.SessionStore", return_value=session_store),
|
||||||
|
|
@ -195,7 +195,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||||
patch(
|
patch(
|
||||||
"messaging.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform,
|
return_value=fake_platform,
|
||||||
),
|
),
|
||||||
patch("messaging.session.SessionStore", return_value=session_store),
|
patch("messaging.session.SessionStore", return_value=session_store),
|
||||||
|
|
@ -234,7 +234,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||||
patch(
|
patch(
|
||||||
"messaging.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
side_effect=ImportError("discord not installed"),
|
side_effect=ImportError("discord not installed"),
|
||||||
),
|
),
|
||||||
TestClient(app),
|
TestClient(app),
|
||||||
|
|
@ -284,7 +284,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||||
patch(
|
patch(
|
||||||
"messaging.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform,
|
return_value=fake_platform,
|
||||||
),
|
),
|
||||||
patch("messaging.session.SessionStore", return_value=session_store),
|
patch("messaging.session.SessionStore", return_value=session_store),
|
||||||
|
|
@ -336,7 +336,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
|
||||||
patch.object(api_app_mod, "get_settings", return_value=settings),
|
patch.object(api_app_mod, "get_settings", return_value=settings),
|
||||||
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
|
||||||
patch(
|
patch(
|
||||||
"messaging.factory.create_messaging_platform",
|
"messaging.platforms.factory.create_messaging_platform",
|
||||||
return_value=fake_platform,
|
return_value=fake_platform,
|
||||||
),
|
),
|
||||||
patch("messaging.session.SessionStore", return_value=session_store),
|
patch("messaging.session.SessionStore", return_value=session_store),
|
||||||
|
|
|
||||||
|
|
@ -25,18 +25,6 @@ def test_messages_request_map_model_claude_to_default(mock_settings):
|
||||||
assert request.original_model == "claude-3-opus"
|
assert request.original_model == "claude-3-opus"
|
||||||
|
|
||||||
|
|
||||||
def test_messages_request_map_model_non_claude_unchanged(mock_settings):
|
|
||||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
|
||||||
request = MessagesRequest(
|
|
||||||
model="gpt-4",
|
|
||||||
max_tokens=100,
|
|
||||||
messages=[Message(role="user", content="hello")],
|
|
||||||
)
|
|
||||||
|
|
||||||
# normalize_model_name returns original if not Claude
|
|
||||||
assert request.model == "gpt-4"
|
|
||||||
|
|
||||||
|
|
||||||
def test_messages_request_map_model_with_provider_prefix(mock_settings):
|
def test_messages_request_map_model_with_provider_prefix(mock_settings):
|
||||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||||
request = MessagesRequest(
|
request = MessagesRequest(
|
||||||
|
|
|
||||||
|
|
@ -112,68 +112,42 @@ class TestQuotaCheckRequest:
|
||||||
class TestTitleGenerationRequest:
|
class TestTitleGenerationRequest:
|
||||||
"""Tests for is_title_generation_request function."""
|
"""Tests for is_title_generation_request function."""
|
||||||
|
|
||||||
def test_title_generation_simple(self):
|
def _title_gen_system(self) -> list[MagicMock]:
|
||||||
"""Test title generation detection with target phrase."""
|
|
||||||
msg = MagicMock(spec=Message)
|
|
||||||
msg.role = "user"
|
|
||||||
msg.content = "Please write a 5-10 word title for this conversation"
|
|
||||||
|
|
||||||
req = MagicMock(spec=MessagesRequest)
|
|
||||||
req.messages = [msg]
|
|
||||||
|
|
||||||
assert is_title_generation_request(req) is True
|
|
||||||
|
|
||||||
def test_title_generation_case_insensitive(self):
|
|
||||||
"""Test title generation is case insensitive."""
|
|
||||||
msg = MagicMock(spec=Message)
|
|
||||||
msg.role = "user"
|
|
||||||
msg.content = "Write a 5-10 Word Title please"
|
|
||||||
|
|
||||||
req = MagicMock(spec=MessagesRequest)
|
|
||||||
req.messages = [msg]
|
|
||||||
|
|
||||||
assert is_title_generation_request(req) is True
|
|
||||||
|
|
||||||
def test_title_generation_list_content(self):
|
|
||||||
"""Test title generation with list content blocks."""
|
|
||||||
block = MagicMock()
|
block = MagicMock()
|
||||||
block.text = "Write a 5-10 word title"
|
block.text = "Analyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title."
|
||||||
|
return [block]
|
||||||
msg = MagicMock(spec=Message)
|
|
||||||
msg.role = "user"
|
|
||||||
msg.content = [block]
|
|
||||||
|
|
||||||
|
def test_title_generation_detected_via_system(self):
|
||||||
|
"""Title gen detected by system prompt containing topic/title keywords."""
|
||||||
req = MagicMock(spec=MessagesRequest)
|
req = MagicMock(spec=MessagesRequest)
|
||||||
req.messages = [msg]
|
req.system = self._title_gen_system()
|
||||||
|
req.tools = None
|
||||||
|
|
||||||
assert is_title_generation_request(req) is True
|
assert is_title_generation_request(req) is True
|
||||||
|
|
||||||
def test_not_title_generation_no_phrase(self):
|
def test_title_generation_not_detected_with_tools(self):
|
||||||
"""Test not title generation without target phrase."""
|
"""Not detected when tools are present (main conversation, not title gen)."""
|
||||||
msg = MagicMock(spec=Message)
|
|
||||||
msg.role = "user"
|
|
||||||
msg.content = "Hello world, how are you?"
|
|
||||||
|
|
||||||
req = MagicMock(spec=MessagesRequest)
|
req = MagicMock(spec=MessagesRequest)
|
||||||
req.messages = [msg]
|
req.system = self._title_gen_system()
|
||||||
|
req.tools = [MagicMock()]
|
||||||
|
|
||||||
assert is_title_generation_request(req) is False
|
assert is_title_generation_request(req) is False
|
||||||
|
|
||||||
def test_not_title_generation_wrong_role(self):
|
def test_title_generation_not_detected_no_system(self):
|
||||||
"""Test not title generation when last message is not from user."""
|
"""Not detected when system is absent."""
|
||||||
msg = MagicMock(spec=Message)
|
|
||||||
msg.role = "assistant"
|
|
||||||
msg.content = "Write a 5-10 word title"
|
|
||||||
|
|
||||||
req = MagicMock(spec=MessagesRequest)
|
req = MagicMock(spec=MessagesRequest)
|
||||||
req.messages = [msg]
|
req.system = None
|
||||||
|
req.tools = None
|
||||||
|
|
||||||
assert is_title_generation_request(req) is False
|
assert is_title_generation_request(req) is False
|
||||||
|
|
||||||
def test_not_title_generation_empty_messages(self):
|
def test_title_generation_not_detected_unrelated_system(self):
|
||||||
"""Test not title generation when no messages."""
|
"""Not detected when system prompt has no topic/title keywords."""
|
||||||
|
block = MagicMock()
|
||||||
|
block.text = "You are a helpful assistant."
|
||||||
req = MagicMock(spec=MessagesRequest)
|
req = MagicMock(spec=MessagesRequest)
|
||||||
req.messages = []
|
req.system = [block]
|
||||||
|
req.tools = None
|
||||||
|
|
||||||
assert is_title_generation_request(req) is False
|
assert is_title_generation_request(req) is False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,12 @@ from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
from config.nim import NimSettings
|
from config.nim import NimSettings
|
||||||
from messaging.base import CLISession, MessagingPlatform, SessionManagerInterface
|
|
||||||
from messaging.models import IncomingMessage
|
from messaging.models import IncomingMessage
|
||||||
|
from messaging.platforms.base import (
|
||||||
|
CLISession,
|
||||||
|
MessagingPlatform,
|
||||||
|
SessionManagerInterface,
|
||||||
|
)
|
||||||
from messaging.session import SessionStore
|
from messaging.session import SessionStore
|
||||||
from providers.base import ProviderConfig
|
from providers.base import ProviderConfig
|
||||||
from providers.nvidia_nim import NvidiaNimProvider
|
from providers.nvidia_nim import NvidiaNimProvider
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Tests for messaging/discord_markdown.py."""
|
"""Tests for messaging/rendering/discord_markdown.py."""
|
||||||
|
|
||||||
from messaging.discord_markdown import (
|
from messaging.rendering.discord_markdown import (
|
||||||
_is_gfm_table_header_line,
|
_is_gfm_table_header_line,
|
||||||
_normalize_gfm_tables,
|
_normalize_gfm_tables,
|
||||||
discord_bold,
|
discord_bold,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from messaging.discord import (
|
from messaging.platforms.discord import (
|
||||||
DISCORD_AVAILABLE,
|
DISCORD_AVAILABLE,
|
||||||
DiscordPlatform,
|
DiscordPlatform,
|
||||||
_get_discord,
|
_get_discord,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from messaging.telegram_markdown import (
|
from messaging.rendering.telegram_markdown import (
|
||||||
escape_md_v2,
|
escape_md_v2,
|
||||||
escape_md_v2_code,
|
escape_md_v2_code,
|
||||||
mdv2_bold,
|
mdv2_bold,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
|
|
||||||
from messaging.handler import ClaudeMessageHandler
|
from messaging.handler import ClaudeMessageHandler
|
||||||
from messaging.models import IncomingMessage
|
from messaging.models import IncomingMessage
|
||||||
from messaging.telegram_markdown import render_markdown_to_mdv2
|
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
|
||||||
from messaging.trees.data import MessageNode, MessageState
|
from messaging.trees.data import MessageNode, MessageState
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ class TestMessagingBase:
|
||||||
|
|
||||||
def test_platform_is_abstract(self):
|
def test_platform_is_abstract(self):
|
||||||
"""Verify MessagingPlatform cannot be instantiated."""
|
"""Verify MessagingPlatform cannot be instantiated."""
|
||||||
from messaging.base import MessagingPlatform
|
from messaging.platforms.base import MessagingPlatform
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
MessagingPlatform()
|
MessagingPlatform()
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from messaging.factory import create_messaging_platform
|
from messaging.platforms.factory import create_messaging_platform
|
||||||
|
|
||||||
|
|
||||||
class TestCreateMessagingPlatform:
|
class TestCreateMessagingPlatform:
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
import pytest
|
import pytest
|
||||||
from telegram.error import NetworkError, RetryAfter, TelegramError
|
from telegram.error import NetworkError, RetryAfter, TelegramError
|
||||||
|
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def telegram_platform():
|
def telegram_platform():
|
||||||
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
|
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
|
||||||
return platform
|
return platform
|
||||||
|
|
||||||
|
|
@ -76,7 +76,7 @@ async def test_telegram_no_retry_on_bad_request(telegram_platform):
|
||||||
|
|
||||||
def test_handler_build_message_hardening():
|
def test_handler_build_message_hardening():
|
||||||
# Formatting hardening now lives in TranscriptBuffer rendering.
|
# Formatting hardening now lives in TranscriptBuffer rendering.
|
||||||
from messaging.telegram_markdown import (
|
from messaging.rendering.telegram_markdown import (
|
||||||
escape_md_v2,
|
escape_md_v2,
|
||||||
escape_md_v2_code,
|
escape_md_v2_code,
|
||||||
mdv2_bold,
|
mdv2_bold,
|
||||||
|
|
@ -112,7 +112,7 @@ def test_handler_build_message_hardening():
|
||||||
|
|
||||||
def test_render_output_never_exceeds_4096():
|
def test_render_output_never_exceeds_4096():
|
||||||
"""Transcript render with various status lengths never exceeds Telegram 4096 limit."""
|
"""Transcript render with various status lengths never exceeds Telegram 4096 limit."""
|
||||||
from messaging.telegram_markdown import (
|
from messaging.rendering.telegram_markdown import (
|
||||||
escape_md_v2,
|
escape_md_v2,
|
||||||
escape_md_v2_code,
|
escape_md_v2_code,
|
||||||
mdv2_bold,
|
mdv2_bold,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from messaging.telegram_markdown import (
|
from messaging.rendering.telegram_markdown import (
|
||||||
escape_md_v2,
|
escape_md_v2,
|
||||||
escape_md_v2_code,
|
escape_md_v2_code,
|
||||||
mdv2_bold,
|
mdv2_bold,
|
||||||
|
|
@ -74,7 +74,7 @@ def test_empty_components_with_status(handler):
|
||||||
|
|
||||||
def test_render_markdown_unclosed_markdown():
|
def test_render_markdown_unclosed_markdown():
|
||||||
"""Malformed markdown (e.g. unclosed *) does not crash and produces acceptable output."""
|
"""Malformed markdown (e.g. unclosed *) does not crash and produces acceptable output."""
|
||||||
from messaging.telegram_markdown import render_markdown_to_mdv2
|
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
|
||||||
|
|
||||||
md = "*bold without close"
|
md = "*bold without close"
|
||||||
out = render_markdown_to_mdv2(md)
|
out = render_markdown_to_mdv2(md)
|
||||||
|
|
@ -84,7 +84,7 @@ def test_render_markdown_unclosed_markdown():
|
||||||
|
|
||||||
def test_escape_md_v2_unicode_emoji():
|
def test_escape_md_v2_unicode_emoji():
|
||||||
"""Unicode and emoji pass through correctly (no special char escaping needed)."""
|
"""Unicode and emoji pass through correctly (no special char escaping needed)."""
|
||||||
from messaging.telegram_markdown import escape_md_v2, escape_md_v2_code
|
from messaging.rendering.telegram_markdown import escape_md_v2, escape_md_v2_code
|
||||||
|
|
||||||
text = "Hello 世界 🎉 café"
|
text = "Hello 世界 🎉 café"
|
||||||
assert escape_md_v2(text) == text
|
assert escape_md_v2(text) == text
|
||||||
|
|
|
||||||
|
|
@ -81,11 +81,13 @@ class TestSessionStoreSaveEdgeCases:
|
||||||
"""Tests for save failure handling."""
|
"""Tests for save failure handling."""
|
||||||
|
|
||||||
def test_save_io_error_handled(self, tmp_store):
|
def test_save_io_error_handled(self, tmp_store):
|
||||||
"""Write failure in _save() is logged but doesn't raise."""
|
"""Write failure in _write_data() raises (callers handle the error)."""
|
||||||
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
|
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
|
||||||
with patch("builtins.open", side_effect=OSError("disk full")):
|
with (
|
||||||
tmp_store._save()
|
patch("builtins.open", side_effect=OSError("disk full")),
|
||||||
# Should not raise
|
pytest.raises(OSError),
|
||||||
|
):
|
||||||
|
tmp_store._write_data(tmp_store._snapshot())
|
||||||
|
|
||||||
|
|
||||||
class TestSessionStoreClearAll:
|
class TestSessionStoreClearAll:
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def telegram_platform():
|
def telegram_platform():
|
||||||
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
|
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
|
||||||
return platform
|
return platform
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,12 @@ from datetime import timedelta
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from telegram.error import NetworkError, RetryAfter, TelegramError
|
||||||
|
|
||||||
|
|
||||||
def test_telegram_platform_init_raises_when_dependency_missing():
|
def test_telegram_platform_init_raises_when_dependency_missing():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", False):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", False):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
with pytest.raises(ImportError):
|
with pytest.raises(ImportError):
|
||||||
TelegramPlatform(bot_token="x")
|
TelegramPlatform(bot_token="x")
|
||||||
|
|
@ -19,7 +20,7 @@ async def test_telegram_platform_start_requires_token():
|
||||||
patch.dict("os.environ", {}, clear=True),
|
patch.dict("os.environ", {}, clear=True),
|
||||||
patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True),
|
patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True),
|
||||||
):
|
):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token=None)
|
platform = TelegramPlatform(bot_token=None)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|
@ -29,7 +30,7 @@ async def test_telegram_platform_start_requires_token():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_telegram_platform_stop_no_application_is_noop():
|
async def test_telegram_platform_stop_no_application_is_noop():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
platform._application = None
|
platform._application = None
|
||||||
|
|
@ -41,7 +42,7 @@ async def test_telegram_platform_stop_no_application_is_noop():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_retry_returns_none_when_message_not_modified_network_error():
|
async def test_with_retry_returns_none_when_message_not_modified_network_error():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import NetworkError, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
|
|
||||||
|
|
@ -54,7 +55,7 @@ async def test_with_retry_returns_none_when_message_not_modified_network_error()
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
|
async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import NetworkError, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
|
|
||||||
|
|
@ -75,7 +76,7 @@ async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
|
async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import RetryAfter, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
|
|
||||||
|
|
@ -96,7 +97,7 @@ async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
|
async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramError, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
|
|
||||||
|
|
@ -115,7 +116,7 @@ async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_queue_send_message_without_limiter_calls_send_message():
|
async def test_queue_send_message_without_limiter_calls_send_message():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
platform._limiter = None
|
platform._limiter = None
|
||||||
|
|
@ -130,7 +131,7 @@ async def test_queue_send_message_without_limiter_calls_send_message():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_queue_edit_message_without_limiter_calls_edit_message():
|
async def test_queue_edit_message_without_limiter_calls_edit_message():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
platform._limiter = None
|
platform._limiter = None
|
||||||
|
|
@ -143,7 +144,7 @@ async def test_queue_edit_message_without_limiter_calls_edit_message():
|
||||||
|
|
||||||
def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
|
def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
|
|
||||||
|
|
@ -157,7 +158,7 @@ def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_start_command_replies_and_forwards():
|
async def test_on_start_command_replies_and_forwards():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
with patch.object(
|
with patch.object(
|
||||||
|
|
@ -174,7 +175,7 @@ async def test_on_start_command_replies_and_forwards():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_telegram_message_handler_error_sends_error_message():
|
async def test_on_telegram_message_handler_error_sends_error_message():
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
|
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
|
||||||
with patch.object(
|
with patch.object(
|
||||||
|
|
@ -200,7 +201,7 @@ async def test_on_telegram_message_handler_error_sends_error_message():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_telegram_start_retries_on_network_error(monkeypatch):
|
async def test_telegram_start_retries_on_network_error(monkeypatch):
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import NetworkError, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="token", allowed_user_id=None)
|
platform = TelegramPlatform(bot_token="token", allowed_user_id=None)
|
||||||
|
|
||||||
|
|
@ -225,7 +226,7 @@ async def test_telegram_start_retries_on_network_error(monkeypatch):
|
||||||
async def test_edit_message_with_text_exceeding_4096_raises():
|
async def test_edit_message_with_text_exceeding_4096_raises():
|
||||||
"""edit_message with text > 4096 raises TelegramError (BadRequest)."""
|
"""edit_message with text > 4096 raises TelegramError (BadRequest)."""
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramError, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
platform._application = MagicMock()
|
platform._application = MagicMock()
|
||||||
|
|
@ -242,7 +243,7 @@ async def test_edit_message_with_text_exceeding_4096_raises():
|
||||||
async def test_edit_message_empty_string():
|
async def test_edit_message_empty_string():
|
||||||
"""edit_message with empty string - Telegram accepts (no-op edit)."""
|
"""edit_message with empty string - Telegram accepts (no-op edit)."""
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
platform._application = MagicMock()
|
platform._application = MagicMock()
|
||||||
|
|
@ -259,7 +260,7 @@ async def test_edit_message_empty_string():
|
||||||
async def test_send_message_empty_string():
|
async def test_send_message_empty_string():
|
||||||
"""send_message with empty string - Telegram may reject; we pass through."""
|
"""send_message with empty string - Telegram may reject; we pass through."""
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
platform._application = MagicMock()
|
platform._application = MagicMock()
|
||||||
|
|
@ -277,7 +278,7 @@ async def test_send_message_empty_string():
|
||||||
async def test_on_telegram_message_non_text_update_ignored():
|
async def test_on_telegram_message_non_text_update_ignored():
|
||||||
"""Update with message.photo but no text returns early without calling handler."""
|
"""Update with message.photo but no text returns early without calling handler."""
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
|
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
|
||||||
handler = AsyncMock()
|
handler = AsyncMock()
|
||||||
|
|
@ -299,7 +300,7 @@ async def test_on_telegram_message_non_text_update_ignored():
|
||||||
async def test_with_retry_message_not_found_returns_none():
|
async def test_with_retry_message_not_found_returns_none():
|
||||||
"""'message to edit not found' returns None without retry."""
|
"""'message to edit not found' returns None without retry."""
|
||||||
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
|
||||||
from messaging.telegram import TelegramError, TelegramPlatform
|
from messaging.platforms.telegram import TelegramPlatform
|
||||||
|
|
||||||
platform = TelegramPlatform(bot_token="t")
|
platform = TelegramPlatform(bot_token="t")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from messaging.telegram_markdown import (
|
from messaging.rendering.telegram_markdown import (
|
||||||
escape_md_v2,
|
escape_md_v2,
|
||||||
escape_md_v2_code,
|
escape_md_v2_code,
|
||||||
mdv2_bold,
|
mdv2_bold,
|
||||||
|
|
|
||||||
|
|
@ -1,133 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
from providers.model_utils import (
|
|
||||||
get_original_model,
|
|
||||||
is_claude_model,
|
|
||||||
normalize_model_name,
|
|
||||||
strip_provider_prefixes,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_strip_provider_prefixes():
|
|
||||||
assert strip_provider_prefixes("anthropic/claude-3") == "claude-3"
|
|
||||||
assert strip_provider_prefixes("openai/gpt-4") == "gpt-4"
|
|
||||||
assert strip_provider_prefixes("gemini/gemini-pro") == "gemini-pro"
|
|
||||||
assert strip_provider_prefixes("no-prefix") == "no-prefix"
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_claude_model():
|
|
||||||
assert is_claude_model("claude-3-sonnet") is True
|
|
||||||
assert is_claude_model("claude-3-opus") is True
|
|
||||||
assert is_claude_model("claude-3-haiku") is True
|
|
||||||
assert is_claude_model("claude-2.1") is True
|
|
||||||
assert is_claude_model("gpt-4") is False
|
|
||||||
assert is_claude_model("gemini-pro") is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_model_name_claude_maps_to_default():
|
|
||||||
default = "target-model"
|
|
||||||
# Strips prefix AND maps to default
|
|
||||||
assert normalize_model_name("anthropic/claude-3-sonnet", default) == default
|
|
||||||
assert normalize_model_name("claude-3-opus", default) == default
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_model_name_non_claude_unchanged():
|
|
||||||
default = "target-model"
|
|
||||||
assert normalize_model_name("gpt-4", default) == "gpt-4"
|
|
||||||
assert (
|
|
||||||
normalize_model_name("openai/gpt-3.5-turbo", default) == "openai/gpt-3.5-turbo"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_original_model():
|
|
||||||
assert get_original_model("any-model") == "any-model"
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_model_name_without_default(monkeypatch):
|
|
||||||
monkeypatch.setenv("MODEL", "env-default-model")
|
|
||||||
assert normalize_model_name("claude-3") == "env-default-model"
|
|
||||||
|
|
||||||
|
|
||||||
# --- Parametrized Edge Case Tests ---
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model,expected",
|
|
||||||
[
|
|
||||||
("anthropic/claude-3", "claude-3"),
|
|
||||||
("openai/gpt-4", "gpt-4"),
|
|
||||||
("gemini/gemini-pro", "gemini-pro"),
|
|
||||||
("no-prefix", "no-prefix"),
|
|
||||||
("", ""),
|
|
||||||
("anthropic/", ""),
|
|
||||||
("anthropic/openai/nested", "openai/nested"),
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"anthropic",
|
|
||||||
"openai",
|
|
||||||
"gemini",
|
|
||||||
"no_prefix",
|
|
||||||
"empty_string",
|
|
||||||
"prefix_only",
|
|
||||||
"nested_prefix",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_strip_provider_prefixes_parametrized(model, expected):
|
|
||||||
"""Parametrized prefix stripping with edge cases."""
|
|
||||||
assert strip_provider_prefixes(model) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model,expected",
|
|
||||||
[
|
|
||||||
("claude-3-sonnet", True),
|
|
||||||
("claude-3-opus", True),
|
|
||||||
("claude-3-haiku", True),
|
|
||||||
("claude-2.1", True),
|
|
||||||
("gpt-4", False),
|
|
||||||
("gemini-pro", False),
|
|
||||||
("", False),
|
|
||||||
("my-claude-wrapper", True), # "claude" as substring
|
|
||||||
("CLAUDE-3-SONNET", True), # case insensitive
|
|
||||||
("sonnet-v2", True), # "sonnet" identifier without "claude"
|
|
||||||
("haiku-model", True), # "haiku" identifier
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"sonnet",
|
|
||||||
"opus",
|
|
||||||
"haiku",
|
|
||||||
"claude2",
|
|
||||||
"gpt4",
|
|
||||||
"gemini",
|
|
||||||
"empty",
|
|
||||||
"claude_substring",
|
|
||||||
"uppercase",
|
|
||||||
"sonnet_standalone",
|
|
||||||
"haiku_standalone",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_is_claude_model_parametrized(model, expected):
|
|
||||||
"""Parametrized Claude model detection with edge cases."""
|
|
||||||
assert is_claude_model(model) is expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model,default,expected",
|
|
||||||
[
|
|
||||||
("claude-3-sonnet", "target", "target"),
|
|
||||||
("anthropic/claude-3-opus", "target", "target"),
|
|
||||||
("gpt-4", "target", "gpt-4"),
|
|
||||||
("openai/gpt-3.5-turbo", "target", "openai/gpt-3.5-turbo"),
|
|
||||||
("", "target", ""), # empty string is not a claude model
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"claude_mapped",
|
|
||||||
"prefixed_claude",
|
|
||||||
"non_claude",
|
|
||||||
"prefixed_non_claude",
|
|
||||||
"empty",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_normalize_model_name_parametrized(model, default, expected):
|
|
||||||
"""Parametrized model normalization."""
|
|
||||||
assert normalize_model_name(model, default) == expected
|
|
||||||
|
|
@ -3,9 +3,9 @@
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from config.nim import NimSettings
|
from config.nim import NimSettings
|
||||||
|
from providers.common.utils import set_if_not_none
|
||||||
from providers.nvidia_nim.request import (
|
from providers.nvidia_nim.request import (
|
||||||
_set_extra,
|
_set_extra,
|
||||||
_set_if_not_none,
|
|
||||||
build_request_body,
|
build_request_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -13,12 +13,12 @@ from providers.nvidia_nim.request import (
|
||||||
class TestSetIfNotNone:
|
class TestSetIfNotNone:
|
||||||
def test_value_not_none_sets(self):
|
def test_value_not_none_sets(self):
|
||||||
body = {}
|
body = {}
|
||||||
_set_if_not_none(body, "key", "value")
|
set_if_not_none(body, "key", "value")
|
||||||
assert body["key"] == "value"
|
assert body["key"] == "value"
|
||||||
|
|
||||||
def test_value_none_skips(self):
|
def test_value_none_skips(self):
|
||||||
body = {}
|
body = {}
|
||||||
_set_if_not_none(body, "key", None)
|
set_if_not_none(body, "key", None)
|
||||||
assert "key" not in body
|
assert "key" not in body
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue