Backup/before cleanup 20260222 230402 (#58)

This commit is contained in:
Ali Khokhar 2026-02-27 19:50:21 -08:00 committed by GitHub
parent e2840095ce
commit c4d8681000
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 253 additions and 584 deletions

View file

@ -29,7 +29,7 @@ HTTP_CONNECT_TIMEOUT=2
# Messaging Platform: "telegram" | "discord"
MESSAGING_PLATFORM=discord
MESSAGING_PLATFORM="discord"
MESSAGING_RATE_LIMIT=1
MESSAGING_RATE_WINDOW=1

View file

@ -5,6 +5,7 @@
## CODING ENVIRONMENT
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
- Always use `uv run` to run files instead of the global `python` command.
- Current uv ruff linter is set to py314 which supports multiple exception types without parentheses (except TypeError, ValueError:)
- Read `.env.example` for environment variables.
- All CI checks must pass; failing checks block merge.
- Add tests for new changes (including edge cases), then run `uv run pytest`.
@ -38,7 +39,7 @@
## SUMMARY STANDARDS
- Summaries must be technical and granular.
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks].
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
## TOOLS
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.

View file

@ -5,6 +5,7 @@
## CODING ENVIRONMENT
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
- Always use `uv run` to run files instead of the global `python` command.
- Current uv ruff linter is set to py314 which supports multiple exception types without parentheses (except TypeError, ValueError:)
- Read `.env.example` for environment variables.
- All CI checks must pass; failing checks block merge.
- Add tests for new changes (including edge cases), then run `uv run pytest`.
@ -38,7 +39,7 @@
## SUMMARY STANDARDS
- Summaries must be technical and granular.
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks].
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
## TOOLS
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.

View file

@ -289,11 +289,11 @@ uv sync --extra voice
Full list in [`nvidia_nim_models.json`](nvidia_nim_models.json).
Popular models:
- `qwen/qwen3.5-397b-a17b`
- `z-ai/glm5`
- `stepfun-ai/step-3.5-flash`
- `moonshotai/kimi-k2.5`
- `minimaxai/minimax-m2.1`
- `nvidia_nim/minimaxai/minimax-m2.5`
- `nvidia_nim/qwen/qwen3.5-397b-a17b`
- `nvidia_nim/z-ai/glm5`
- `nvidia_nim/stepfun-ai/step-3.5-flash`
- `nvidia_nim/moonshotai/kimi-k2.5`
Browse: [build.nvidia.com](https://build.nvidia.com/explore/discover)
@ -310,9 +310,9 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
Hundreds of models from StepFun, OpenAI, Anthropic, Google, and more.
Popular models:
- `stepfun/step-3.5-flash:free`
- `deepseek/deepseek-r1-0528:free`
- `openai/gpt-oss-120b:free`
- `open_router/stepfun/step-3.5-flash:free`
- `open_router/deepseek/deepseek-r1-0528:free`
- `open_router/openai/gpt-oss-120b:free`
Browse: [openrouter.ai/models](https://openrouter.ai/models)
@ -385,7 +385,6 @@ free-claude-code/
├── messaging/ # MessagingPlatform ABC + Discord/Telegram bots, session management
├── config/ # Settings, NIM config, logging
├── cli/ # CLI session and process management
├── utils/ # Text utilities
└── tests/ # Pytest test suite
```

View file

@ -51,7 +51,7 @@ async def lifespan(app: FastAPI):
try:
# Use the messaging factory to create the right platform
from messaging.factory import create_messaging_platform
from messaging.platforms.factory import create_messaging_platform
messaging_platform = create_messaging_platform(
platform_type=settings.messaging_platform,

View file

@ -16,88 +16,84 @@ def get_settings() -> Settings:
return _get_settings()
def _create_provider(settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance from settings."""
if settings.provider_type == "nvidia_nim":
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
raise HTTPException(
status_code=503,
detail=(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
"Get a key at https://build.nvidia.com/settings/api-keys"
),
)
from providers.nvidia_nim import NvidiaNimProvider
config = ProviderConfig(
api_key=settings.nvidia_nim_api_key,
base_url=NVIDIA_NIM_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
elif settings.provider_type == "open_router":
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
raise HTTPException(
status_code=503,
detail=(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
"Get a key at https://openrouter.ai/keys"
),
)
from providers.open_router import OpenRouterProvider
config = ProviderConfig(
api_key=settings.open_router_api_key,
base_url="https://openrouter.ai/api/v1",
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = OpenRouterProvider(config)
elif settings.provider_type == "lmstudio":
from providers.lmstudio import LMStudioProvider
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = LMStudioProvider(config)
else:
logger.error(
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
logger.info("Provider initialized: %s", settings.provider_type)
return provider
def get_provider() -> BaseProvider:
"""Get or create the provider instance based on settings.provider_type."""
global _provider
if _provider is None:
settings = get_settings()
if settings.provider_type == "nvidia_nim":
if (
not settings.nvidia_nim_api_key
or not settings.nvidia_nim_api_key.strip()
):
raise HTTPException(
status_code=503,
detail=(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
"Get a key at https://build.nvidia.com/settings/api-keys"
),
)
from providers.nvidia_nim import NvidiaNimProvider
config = ProviderConfig(
api_key=settings.nvidia_nim_api_key,
base_url=NVIDIA_NIM_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
_provider = NvidiaNimProvider(config, nim_settings=settings.nim)
logger.info("Provider initialized: %s", settings.provider_type)
elif settings.provider_type == "open_router":
if (
not settings.open_router_api_key
or not settings.open_router_api_key.strip()
):
raise HTTPException(
status_code=503,
detail=(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
"Get a key at https://openrouter.ai/keys"
),
)
from providers.open_router import OpenRouterProvider
config = ProviderConfig(
api_key=settings.open_router_api_key,
base_url="https://openrouter.ai/api/v1",
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
_provider = OpenRouterProvider(config)
logger.info("Provider initialized: %s", settings.provider_type)
elif settings.provider_type == "lmstudio":
from providers.lmstudio import LMStudioProvider
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
_provider = LMStudioProvider(config)
logger.info("Provider initialized: %s", settings.provider_type)
else:
logger.error(
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
_provider = _create_provider(get_settings())
return _provider
@ -108,5 +104,10 @@ async def cleanup_provider():
client = getattr(_provider, "_client", None)
if client and hasattr(client, "aclose"):
await client.aclose()
elif client:
logger.warning(
"Provider client %r has no aclose(); skipping async cleanup",
type(client).__name__,
)
_provider = None
logger.debug("Provider cleanup completed")

View file

@ -29,14 +29,13 @@ def is_quota_check_request(request_data: MessagesRequest) -> bool:
def is_title_generation_request(request_data: MessagesRequest) -> bool:
"""Check if this is a conversation title generation request.
Title generation requests typically contain the phrase
"write a 5-10 word title" in the user's message.
Title generation requests are detected by a system prompt containing
title extraction instructions, no tools, and a single user message.
"""
if len(request_data.messages) > 0 and request_data.messages[-1].role == "user":
text = extract_text_from_content(request_data.messages[-1].content)
if "write a 5-10 word title" in text.lower():
return True
return False
if not request_data.system or request_data.tools:
return False
system_text = extract_text_from_content(request_data.system).lower()
return "new conversation topic" in system_text and "title" in system_text
def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:

View file

@ -7,7 +7,6 @@ from loguru import logger
from pydantic import BaseModel, field_validator, model_validator
from config.settings import get_settings
from providers.model_utils import normalize_model_name
# =============================================================================
# Content Block Types
@ -112,10 +111,7 @@ class MessagesRequest(BaseModel):
if self.original_model is None:
self.original_model = self.model
# Use centralized model normalization
normalized = normalize_model_name(self.model, settings.model_name)
if normalized != self.model:
self.model = normalized
self.model = settings.model_name
if self.model != self.original_model:
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
@ -136,5 +132,4 @@ class TokenCountRequest(BaseModel):
def validate_model_field(cls, v, info):
"""Map any Claude model name to the configured model."""
settings = get_settings()
# Use centralized model normalization
return normalize_model_name(v, settings.model_name)
return settings.model_name

View file

@ -80,6 +80,7 @@ def configure_logging(log_file: str, *, force: bool = False) -> None:
format=_serialize_with_context,
encoding="utf-8",
mode="a",
rotation="50 MB",
)
# Intercept stdlib logging: route all root logger output to loguru

View file

@ -1,9 +1,9 @@
"""Platform-agnostic messaging layer."""
from .base import CLISession, MessagingPlatform, SessionManagerInterface
from .event_parser import parse_cli_event
from .handler import ClaudeMessageHandler
from .models import IncomingMessage
from .platforms.base import CLISession, MessagingPlatform, SessionManagerInterface
from .session import SessionStore
from .trees.data import MessageNode, MessageState, MessageTree
from .trees.queue_manager import TreeQueueManager

View file

@ -1,9 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.base for new code."""
from .platforms.base import (
CLISession,
MessagingPlatform,
SessionManagerInterface,
)
__all__ = ["CLISession", "MessagingPlatform", "SessionManagerInterface"]

View file

@ -1,17 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.discord for new code."""
from .platforms.discord import (
DISCORD_AVAILABLE,
DISCORD_MESSAGE_LIMIT,
DiscordPlatform,
_get_discord,
_parse_allowed_channels,
)
__all__ = [
"DISCORD_AVAILABLE",
"DISCORD_MESSAGE_LIMIT",
"DiscordPlatform",
"_get_discord",
"_parse_allowed_channels",
]

View file

@ -1,25 +0,0 @@
"""Backward-compatible re-export. Use messaging.rendering.discord_markdown for new code."""
from .rendering.discord_markdown import (
_is_gfm_table_header_line,
_normalize_gfm_tables,
discord_bold,
discord_code_inline,
escape_discord,
escape_discord_code,
format_status,
format_status_discord,
render_markdown_to_discord,
)
__all__ = [
"_is_gfm_table_header_line",
"_normalize_gfm_tables",
"discord_bold",
"discord_code_inline",
"escape_discord",
"escape_discord_code",
"format_status",
"format_status_discord",
"render_markdown_to_discord",
]

View file

@ -1,5 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.factory for new code."""
from .platforms.factory import create_messaging_platform
__all__ = ["create_messaging_platform"]

View file

@ -95,18 +95,18 @@ class SessionStore:
except Exception as e:
logger.error(f"Failed to load sessions: {e}")
def _save(self) -> None:
"""Persist sessions and trees to disk. Caller must hold self._lock."""
try:
data = {
"trees": self._trees,
"node_to_tree": self._node_to_tree,
"message_log": self._message_log,
}
with open(self.storage_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
def _snapshot(self) -> dict:
"""Snapshot current state for serialization. Caller must hold self._lock."""
return {
"trees": dict(self._trees),
"node_to_tree": dict(self._node_to_tree),
"message_log": {k: list(v) for k, v in self._message_log.items()},
}
def _write_data(self, data: dict) -> None:
"""Write data dict to disk. Must be called WITHOUT holding self._lock."""
with open(self.storage_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
def _schedule_save(self) -> None:
"""Schedule a debounced save. Caller must hold self._lock."""
@ -126,22 +126,35 @@ class SessionStore:
if not self._dirty:
self._save_timer = None
return
self._save()
snapshot = self._snapshot()
self._dirty = False
self._save_timer = None
try:
self._write_data(snapshot)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
with self._lock:
self._dirty = True
def _flush_save(self) -> None:
"""Immediate save, cancel any pending debounced save. Caller must hold self._lock."""
def _flush_save(self) -> dict:
"""Cancel pending timer and snapshot current state. Caller must hold self._lock.
Returns snapshot dict; caller must call _write_data(snapshot) outside the lock."""
if self._save_timer is not None:
self._save_timer.cancel()
self._save_timer = None
self._dirty = False
self._save()
return self._snapshot()
def flush_pending_save(self) -> None:
"""Flush any pending debounced save. Call on shutdown to avoid losing data."""
with self._lock:
self._flush_save()
snapshot = self._flush_save()
try:
self._write_data(snapshot)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
with self._lock:
self._dirty = True
def record_message_id(
self,
@ -201,7 +214,13 @@ class SessionStore:
self._node_to_tree.clear()
self._message_log.clear()
self._message_log_ids.clear()
self._flush_save()
snapshot = self._flush_save()
try:
self._write_data(snapshot)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
with self._lock:
self._dirty = True
# ==================== Tree Methods ====================

View file

@ -1,15 +0,0 @@
"""Backward-compatible re-export. Use messaging.platforms.telegram for new code."""
from .platforms.telegram import (
TELEGRAM_AVAILABLE,
TelegramPlatform,
)
# Re-export telegram.error types when python-telegram-bot is installed
__all__ = ["TELEGRAM_AVAILABLE", "TelegramPlatform"]
try:
from telegram.error import NetworkError, RetryAfter, TelegramError
__all__ += ["NetworkError", "RetryAfter", "TelegramError"]
except ImportError:
pass

View file

@ -1,21 +0,0 @@
"""Backward-compatible re-export. Use messaging.rendering.telegram_markdown for new code."""
from .rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
escape_md_v2_link_url,
format_status,
mdv2_bold,
mdv2_code_inline,
render_markdown_to_mdv2,
)
__all__ = [
"escape_md_v2",
"escape_md_v2_code",
"escape_md_v2_link_url",
"format_status",
"mdv2_bold",
"mdv2_code_inline",
"render_markdown_to_mdv2",
]

View file

@ -105,17 +105,19 @@ class TreeRepository:
if not tree:
return []
pending = []
node = tree.get_node(node_id)
if not node:
return []
pending: list[MessageNode] = []
stack = [node_id]
for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
# Recursively get children of pending children
pending.extend(self.get_pending_children(child_id))
while stack:
current_id = stack.pop()
node = tree.get_node(current_id)
if not node:
continue
for child_id in node.children_ids:
child = tree.get_node(child_id)
if child and child.state == MessageState.PENDING:
pending.append(child)
stack.append(child_id)
return pending

View file

@ -476,13 +476,13 @@
"owned_by": "microsoft"
},
{
"id": "minimaxai/minimax-m2",
"id": "minimaxai/minimax-m2.1",
"object": "model",
"created": 735790403,
"owned_by": "minimaxai"
},
{
"id": "minimaxai/minimax-m2.1",
"id": "minimaxai/minimax-m2.5",
"object": "model",
"created": 735790403,
"owned_by": "minimaxai"
@ -709,12 +709,6 @@
"created": 735790403,
"owned_by": "nvidia"
},
{
"id": "nvidia/llama-3.2-nemoretriever-300m-embed-v2",
"object": "model",
"created": 735790403,
"owned_by": "nvidia"
},
{
"id": "nvidia/llama-3.2-nv-embedqa-1b-v1",
"object": "model",

View file

@ -9,6 +9,7 @@ from .message_converter import (
)
from .sse_builder import ContentBlockManager, SSEBuilder, map_stop_reason
from .think_parser import ContentChunk, ContentType, ThinkTagParser
from .utils import set_if_not_none
__all__ = [
"AnthropicToOpenAIConverter",
@ -22,4 +23,5 @@ __all__ = [
"get_block_type",
"map_error",
"map_stop_reason",
"set_if_not_none",
]

View file

@ -0,0 +1,9 @@
"""Shared utility helpers for provider request builders."""
from typing import Any
def set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
"""Set body[key] = value only when value is not None."""
if value is not None:
body[key] = value

View file

@ -5,15 +5,11 @@ from typing import Any
from loguru import logger
from providers.common.message_converter import AnthropicToOpenAIConverter
from providers.common.utils import set_if_not_none
LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
def build_request_body(request_data: Any) -> dict:
"""Build OpenAI-format request body from Anthropic request for LM Studio."""
logger.debug(
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
}
max_tokens = getattr(request_data, "max_tokens", None)
_set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS)
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences:

View file

@ -1,86 +0,0 @@
"""Model name normalization utilities.
Centralizes model name mapping logic to avoid duplication across the codebase.
"""
import os
# Provider prefixes to strip from model names
_PROVIDER_PREFIXES = ["anthropic/", "openai/", "gemini/"]
# Claude model identifiers
_CLAUDE_IDENTIFIERS = ["haiku", "sonnet", "opus", "claude"]
def strip_provider_prefixes(model: str) -> str:
"""
Strip provider prefixes from model name.
Args:
model: The model name, possibly with prefix
Returns:
Model name without provider prefix
"""
for prefix in _PROVIDER_PREFIXES:
if model.startswith(prefix):
return model[len(prefix) :]
return model
def is_claude_model(model: str) -> bool:
"""
Check if a model name identifies as a Claude model.
Args:
model: The (prefix-stripped) model name
Returns:
True if this is a Claude model
"""
model_lower = model.lower()
return any(name in model_lower for name in _CLAUDE_IDENTIFIERS)
def normalize_model_name(model: str, default_model: str | None = None) -> str:
"""
Normalize a model name by stripping prefixes and mapping to default if needed.
This is the central function for model name normalization across the API.
It strips provider prefixes and maps Claude model names to the configured model.
Args:
model: The model name (may include provider prefix)
default_model: The default model to use for Claude models.
If None, uses settings.model from config.
Returns:
Normalized model name (original if not a Claude model, mapped if Claude)
"""
# Strip provider prefixes
clean = strip_provider_prefixes(model)
# Map Claude models to default
if is_claude_model(clean):
if default_model is None:
# Use environment/config default
default_model = os.getenv("MODEL", "moonshotai/kimi-k2-thinking")
return default_model
return model
def get_original_model(model: str) -> str:
"""
Get the original model name, storing it before normalization.
Convenience function that returns the input unchanged, intended to be
called alongside normalize_model_name to capture the original.
Args:
model: The model name
Returns:
The model name unchanged (for documentation purposes)
"""
return model

View file

@ -6,11 +6,7 @@ from loguru import logger
from config.nim import NimSettings
from providers.common.message_converter import AnthropicToOpenAIConverter
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
from providers.common.utils import set_if_not_none
def _set_extra(
@ -52,15 +48,15 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
max_tokens = nim.max_tokens
elif nim.max_tokens:
max_tokens = min(max_tokens, nim.max_tokens)
_set_if_not_none(body, "max_tokens", max_tokens)
set_if_not_none(body, "max_tokens", max_tokens)
req_temperature = getattr(request_data, "temperature", None)
temperature = req_temperature if req_temperature is not None else nim.temperature
_set_if_not_none(body, "temperature", temperature)
set_if_not_none(body, "temperature", temperature)
req_top_p = getattr(request_data, "top_p", None)
top_p = req_top_p if req_top_p is not None else nim.top_p
_set_if_not_none(body, "top_p", top_p)
set_if_not_none(body, "top_p", top_p)
stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences:

View file

@ -5,15 +5,11 @@ from typing import Any
from loguru import logger
from providers.common.message_converter import AnthropicToOpenAIConverter
from providers.common.utils import set_if_not_none
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
def _set_if_not_none(body: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
def build_request_body(request_data: Any) -> dict:
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
logger.debug(
@ -38,10 +34,10 @@ def build_request_body(request_data: Any) -> dict:
}
max_tokens = getattr(request_data, "max_tokens", None)
_set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS)
_set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
_set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
set_if_not_none(body, "temperature", getattr(request_data, "temperature", None))
set_if_not_none(body, "top_p", getattr(request_data, "top_p", None))
stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences:

View file

@ -124,7 +124,7 @@ def test_app_lifespan_sets_state_and_cleans_up(tmp_path, messaging_enabled):
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch(
"messaging.factory.create_messaging_platform",
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform if messaging_enabled else None,
) as create_platform,
patch("messaging.session.SessionStore", return_value=session_store),
@ -195,7 +195,7 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch(
"messaging.factory.create_messaging_platform",
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform,
),
patch("messaging.session.SessionStore", return_value=session_store),
@ -234,7 +234,7 @@ def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch(
"messaging.factory.create_messaging_platform",
"messaging.platforms.factory.create_messaging_platform",
side_effect=ImportError("discord not installed"),
),
TestClient(app),
@ -284,7 +284,7 @@ def test_app_lifespan_platform_start_exception_cleanup_still_runs(tmp_path):
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch(
"messaging.factory.create_messaging_platform",
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform,
),
patch("messaging.session.SessionStore", return_value=session_store),
@ -336,7 +336,7 @@ def test_app_lifespan_flush_pending_save_exception_warning_only(tmp_path):
patch.object(api_app_mod, "get_settings", return_value=settings),
patch.object(api_app_mod, "cleanup_provider", new=cleanup_provider),
patch(
"messaging.factory.create_messaging_platform",
"messaging.platforms.factory.create_messaging_platform",
return_value=fake_platform,
),
patch("messaging.session.SessionStore", return_value=session_store),

View file

@ -25,18 +25,6 @@ def test_messages_request_map_model_claude_to_default(mock_settings):
assert request.original_model == "claude-3-opus"
def test_messages_request_map_model_non_claude_unchanged(mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest(
model="gpt-4",
max_tokens=100,
messages=[Message(role="user", content="hello")],
)
# normalize_model_name returns original if not Claude
assert request.model == "gpt-4"
def test_messages_request_map_model_with_provider_prefix(mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest(

View file

@ -112,68 +112,42 @@ class TestQuotaCheckRequest:
class TestTitleGenerationRequest:
"""Tests for is_title_generation_request function."""
def test_title_generation_simple(self):
"""Test title generation detection with target phrase."""
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = "Please write a 5-10 word title for this conversation"
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
assert is_title_generation_request(req) is True
def test_title_generation_case_insensitive(self):
"""Test title generation is case insensitive."""
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = "Write a 5-10 Word Title please"
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
assert is_title_generation_request(req) is True
def test_title_generation_list_content(self):
"""Test title generation with list content blocks."""
def _title_gen_system(self) -> list[MagicMock]:
block = MagicMock()
block.text = "Write a 5-10 word title"
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = [block]
block.text = "Analyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title."
return [block]
def test_title_generation_detected_via_system(self):
"""Title gen detected by system prompt containing topic/title keywords."""
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
req.system = self._title_gen_system()
req.tools = None
assert is_title_generation_request(req) is True
def test_not_title_generation_no_phrase(self):
"""Test not title generation without target phrase."""
msg = MagicMock(spec=Message)
msg.role = "user"
msg.content = "Hello world, how are you?"
def test_title_generation_not_detected_with_tools(self):
"""Not detected when tools are present (main conversation, not title gen)."""
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
req.system = self._title_gen_system()
req.tools = [MagicMock()]
assert is_title_generation_request(req) is False
def test_not_title_generation_wrong_role(self):
"""Test not title generation when last message is not from user."""
msg = MagicMock(spec=Message)
msg.role = "assistant"
msg.content = "Write a 5-10 word title"
def test_title_generation_not_detected_no_system(self):
"""Not detected when system is absent."""
req = MagicMock(spec=MessagesRequest)
req.messages = [msg]
req.system = None
req.tools = None
assert is_title_generation_request(req) is False
def test_not_title_generation_empty_messages(self):
"""Test not title generation when no messages."""
def test_title_generation_not_detected_unrelated_system(self):
"""Not detected when system prompt has no topic/title keywords."""
block = MagicMock()
block.text = "You are a helpful assistant."
req = MagicMock(spec=MessagesRequest)
req.messages = []
req.system = [block]
req.tools = None
assert is_title_generation_request(req) is False

View file

@ -18,8 +18,12 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock
from config.nim import NimSettings
from messaging.base import CLISession, MessagingPlatform, SessionManagerInterface
from messaging.models import IncomingMessage
from messaging.platforms.base import (
CLISession,
MessagingPlatform,
SessionManagerInterface,
)
from messaging.session import SessionStore
from providers.base import ProviderConfig
from providers.nvidia_nim import NvidiaNimProvider

View file

@ -1,6 +1,6 @@
"""Tests for messaging/discord_markdown.py."""
"""Tests for messaging/rendering/discord_markdown.py."""
from messaging.discord_markdown import (
from messaging.rendering.discord_markdown import (
_is_gfm_table_header_line,
_normalize_gfm_tables,
discord_bold,

View file

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from messaging.discord import (
from messaging.platforms.discord import (
DISCORD_AVAILABLE,
DiscordPlatform,
_get_discord,

View file

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
import pytest
from messaging.telegram_markdown import (
from messaging.rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
mdv2_bold,

View file

@ -4,7 +4,7 @@ import pytest
from messaging.handler import ClaudeMessageHandler
from messaging.models import IncomingMessage
from messaging.telegram_markdown import render_markdown_to_mdv2
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
from messaging.trees.data import MessageNode, MessageState

View file

@ -49,7 +49,7 @@ class TestMessagingBase:
def test_platform_is_abstract(self):
"""Verify MessagingPlatform cannot be instantiated."""
from messaging.base import MessagingPlatform
from messaging.platforms.base import MessagingPlatform
with pytest.raises(TypeError):
MessagingPlatform()

View file

@ -2,7 +2,7 @@
from unittest.mock import MagicMock, patch
from messaging.factory import create_messaging_platform
from messaging.platforms.factory import create_messaging_platform
class TestCreateMessagingPlatform:

View file

@ -3,12 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from telegram.error import NetworkError, RetryAfter, TelegramError
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
@pytest.fixture
def telegram_platform():
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
return platform
@ -76,7 +76,7 @@ async def test_telegram_no_retry_on_bad_request(telegram_platform):
def test_handler_build_message_hardening():
# Formatting hardening now lives in TranscriptBuffer rendering.
from messaging.telegram_markdown import (
from messaging.rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
mdv2_bold,
@ -112,7 +112,7 @@ def test_handler_build_message_hardening():
def test_render_output_never_exceeds_4096():
"""Transcript render with various status lengths never exceeds Telegram 4096 limit."""
from messaging.telegram_markdown import (
from messaging.rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
mdv2_bold,

View file

@ -2,7 +2,7 @@ from unittest.mock import MagicMock
import pytest
from messaging.telegram_markdown import (
from messaging.rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
mdv2_bold,
@ -74,7 +74,7 @@ def test_empty_components_with_status(handler):
def test_render_markdown_unclosed_markdown():
"""Malformed markdown (e.g. unclosed *) does not crash and produces acceptable output."""
from messaging.telegram_markdown import render_markdown_to_mdv2
from messaging.rendering.telegram_markdown import render_markdown_to_mdv2
md = "*bold without close"
out = render_markdown_to_mdv2(md)
@ -84,7 +84,7 @@ def test_render_markdown_unclosed_markdown():
def test_escape_md_v2_unicode_emoji():
"""Unicode and emoji pass through correctly (no special char escaping needed)."""
from messaging.telegram_markdown import escape_md_v2, escape_md_v2_code
from messaging.rendering.telegram_markdown import escape_md_v2, escape_md_v2_code
text = "Hello 世界 🎉 café"
assert escape_md_v2(text) == text

View file

@ -81,11 +81,13 @@ class TestSessionStoreSaveEdgeCases:
"""Tests for save failure handling."""
def test_save_io_error_handled(self, tmp_store):
"""Write failure in _save() is logged but doesn't raise."""
"""Write failure in _write_data() raises (callers handle the error)."""
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
with patch("builtins.open", side_effect=OSError("disk full")):
tmp_store._save()
# Should not raise
with (
patch("builtins.open", side_effect=OSError("disk full")),
pytest.raises(OSError),
):
tmp_store._write_data(tmp_store._snapshot())
class TestSessionStoreClearAll:

View file

@ -2,12 +2,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
@pytest.fixture
def telegram_platform():
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
platform = TelegramPlatform(bot_token="test_token", allowed_user_id="12345")
return platform

View file

@ -3,11 +3,12 @@ from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from telegram.error import NetworkError, RetryAfter, TelegramError
def test_telegram_platform_init_raises_when_dependency_missing():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", False):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
with pytest.raises(ImportError):
TelegramPlatform(bot_token="x")
@ -19,7 +20,7 @@ async def test_telegram_platform_start_requires_token():
patch.dict("os.environ", {}, clear=True),
patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True),
):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token=None)
with pytest.raises(ValueError):
@ -29,7 +30,7 @@ async def test_telegram_platform_start_requires_token():
@pytest.mark.asyncio
async def test_telegram_platform_stop_no_application_is_noop():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
platform._application = None
@ -41,7 +42,7 @@ async def test_telegram_platform_stop_no_application_is_noop():
@pytest.mark.asyncio
async def test_with_retry_returns_none_when_message_not_modified_network_error():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import NetworkError, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
@ -54,7 +55,7 @@ async def test_with_retry_returns_none_when_message_not_modified_network_error()
@pytest.mark.asyncio
async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import NetworkError, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
@ -75,7 +76,7 @@ async def test_with_retry_retries_network_error_then_succeeds(monkeypatch):
@pytest.mark.asyncio
async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import RetryAfter, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
@ -96,7 +97,7 @@ async def test_with_retry_honors_retry_after_timedelta(monkeypatch):
@pytest.mark.asyncio
async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramError, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
@ -115,7 +116,7 @@ async def test_with_retry_drops_parse_mode_on_markdown_entity_error():
@pytest.mark.asyncio
async def test_queue_send_message_without_limiter_calls_send_message():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
platform._limiter = None
@ -130,7 +131,7 @@ async def test_queue_send_message_without_limiter_calls_send_message():
@pytest.mark.asyncio
async def test_queue_edit_message_without_limiter_calls_edit_message():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
platform._limiter = None
@ -143,7 +144,7 @@ async def test_queue_edit_message_without_limiter_calls_edit_message():
def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
@ -157,7 +158,7 @@ def test_fire_and_forget_non_coroutine_uses_ensure_future(monkeypatch):
@pytest.mark.asyncio
async def test_on_start_command_replies_and_forwards():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
with patch.object(
@ -174,7 +175,7 @@ async def test_on_start_command_replies_and_forwards():
@pytest.mark.asyncio
async def test_on_telegram_message_handler_error_sends_error_message():
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
with patch.object(
@ -200,7 +201,7 @@ async def test_on_telegram_message_handler_error_sends_error_message():
@pytest.mark.asyncio
async def test_telegram_start_retries_on_network_error(monkeypatch):
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import NetworkError, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="token", allowed_user_id=None)
@ -225,7 +226,7 @@ async def test_telegram_start_retries_on_network_error(monkeypatch):
async def test_edit_message_with_text_exceeding_4096_raises():
"""edit_message with text > 4096 raises TelegramError (BadRequest)."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramError, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
platform._application = MagicMock()
@ -242,7 +243,7 @@ async def test_edit_message_with_text_exceeding_4096_raises():
async def test_edit_message_empty_string():
"""edit_message with empty string - Telegram accepts (no-op edit)."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
platform._application = MagicMock()
@ -259,7 +260,7 @@ async def test_edit_message_empty_string():
async def test_send_message_empty_string():
"""send_message with empty string - Telegram may reject; we pass through."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")
platform._application = MagicMock()
@ -277,7 +278,7 @@ async def test_send_message_empty_string():
async def test_on_telegram_message_non_text_update_ignored():
"""Update with message.photo but no text returns early without calling handler."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t", allowed_user_id="123")
handler = AsyncMock()
@ -299,7 +300,7 @@ async def test_on_telegram_message_non_text_update_ignored():
async def test_with_retry_message_not_found_returns_none():
"""'message to edit not found' returns None without retry."""
with patch("messaging.platforms.telegram.TELEGRAM_AVAILABLE", True):
from messaging.telegram import TelegramError, TelegramPlatform
from messaging.platforms.telegram import TelegramPlatform
platform = TelegramPlatform(bot_token="t")

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
from messaging.telegram_markdown import (
from messaging.rendering.telegram_markdown import (
escape_md_v2,
escape_md_v2_code,
mdv2_bold,

View file

@ -1,133 +0,0 @@
import pytest
from providers.model_utils import (
get_original_model,
is_claude_model,
normalize_model_name,
strip_provider_prefixes,
)
def test_strip_provider_prefixes():
assert strip_provider_prefixes("anthropic/claude-3") == "claude-3"
assert strip_provider_prefixes("openai/gpt-4") == "gpt-4"
assert strip_provider_prefixes("gemini/gemini-pro") == "gemini-pro"
assert strip_provider_prefixes("no-prefix") == "no-prefix"
def test_is_claude_model():
assert is_claude_model("claude-3-sonnet") is True
assert is_claude_model("claude-3-opus") is True
assert is_claude_model("claude-3-haiku") is True
assert is_claude_model("claude-2.1") is True
assert is_claude_model("gpt-4") is False
assert is_claude_model("gemini-pro") is False
def test_normalize_model_name_claude_maps_to_default():
default = "target-model"
# Strips prefix AND maps to default
assert normalize_model_name("anthropic/claude-3-sonnet", default) == default
assert normalize_model_name("claude-3-opus", default) == default
def test_normalize_model_name_non_claude_unchanged():
default = "target-model"
assert normalize_model_name("gpt-4", default) == "gpt-4"
assert (
normalize_model_name("openai/gpt-3.5-turbo", default) == "openai/gpt-3.5-turbo"
)
def test_get_original_model():
assert get_original_model("any-model") == "any-model"
def test_normalize_model_name_without_default(monkeypatch):
monkeypatch.setenv("MODEL", "env-default-model")
assert normalize_model_name("claude-3") == "env-default-model"
# --- Parametrized Edge Case Tests ---
@pytest.mark.parametrize(
"model,expected",
[
("anthropic/claude-3", "claude-3"),
("openai/gpt-4", "gpt-4"),
("gemini/gemini-pro", "gemini-pro"),
("no-prefix", "no-prefix"),
("", ""),
("anthropic/", ""),
("anthropic/openai/nested", "openai/nested"),
],
ids=[
"anthropic",
"openai",
"gemini",
"no_prefix",
"empty_string",
"prefix_only",
"nested_prefix",
],
)
def test_strip_provider_prefixes_parametrized(model, expected):
"""Parametrized prefix stripping with edge cases."""
assert strip_provider_prefixes(model) == expected
@pytest.mark.parametrize(
"model,expected",
[
("claude-3-sonnet", True),
("claude-3-opus", True),
("claude-3-haiku", True),
("claude-2.1", True),
("gpt-4", False),
("gemini-pro", False),
("", False),
("my-claude-wrapper", True), # "claude" as substring
("CLAUDE-3-SONNET", True), # case insensitive
("sonnet-v2", True), # "sonnet" identifier without "claude"
("haiku-model", True), # "haiku" identifier
],
ids=[
"sonnet",
"opus",
"haiku",
"claude2",
"gpt4",
"gemini",
"empty",
"claude_substring",
"uppercase",
"sonnet_standalone",
"haiku_standalone",
],
)
def test_is_claude_model_parametrized(model, expected):
"""Parametrized Claude model detection with edge cases."""
assert is_claude_model(model) is expected
@pytest.mark.parametrize(
"model,default,expected",
[
("claude-3-sonnet", "target", "target"),
("anthropic/claude-3-opus", "target", "target"),
("gpt-4", "target", "gpt-4"),
("openai/gpt-3.5-turbo", "target", "openai/gpt-3.5-turbo"),
("", "target", ""), # empty string is not a claude model
],
ids=[
"claude_mapped",
"prefixed_claude",
"non_claude",
"prefixed_non_claude",
"empty",
],
)
def test_normalize_model_name_parametrized(model, default, expected):
"""Parametrized model normalization."""
assert normalize_model_name(model, default) == expected

View file

@ -3,9 +3,9 @@
from unittest.mock import MagicMock
from config.nim import NimSettings
from providers.common.utils import set_if_not_none
from providers.nvidia_nim.request import (
_set_extra,
_set_if_not_none,
build_request_body,
)
@ -13,12 +13,12 @@ from providers.nvidia_nim.request import (
class TestSetIfNotNone:
def test_value_not_none_sets(self):
body = {}
_set_if_not_none(body, "key", "value")
set_if_not_none(body, "key", "value")
assert body["key"] == "value"
def test_value_none_skips(self):
body = {}
_set_if_not_none(body, "key", None)
set_if_not_none(body, "key", None)
assert "key" not in body