Major Refactor Part 2 with kimi-k2.5 in claude code

Alishahryar1 2026-02-05 16:09:16 -08:00
parent 928e702e71
commit 6102583026
41 changed files with 520 additions and 597 deletions

View file

@ -27,6 +27,8 @@ MESSAGING_RATE_WINDOW=1
NVIDIA_NIM_API_KEY=""
NVIDIA_NIM_RATE_LIMIT=40
NVIDIA_NIM_RATE_WINDOW=60
# Base URL is fixed to https://integrate.api.nvidia.com/v1
# All NVIDIA_NIM_* settings are strictly validated (unknown keys will error).
NVIDIA_NIM_TEMPERATURE=1.0
NVIDIA_NIM_TOP_P=1.0

View file

@ -117,7 +117,6 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
| --------------------------------- | ------------------------------- | ------------------------------------- |
| `NVIDIA_NIM_API_KEY` | Your NVIDIA API key | required |
| `MODEL` | Model to use for all requests | `moonshotai/kimi-k2-thinking` |
| `NVIDIA_NIM_BASE_URL` | NIM endpoint | `https://integrate.api.nvidia.com/v1` |
| `CLAUDE_WORKSPACE` | Directory for agent workspace | `./agent_workspace` |
| `ALLOWED_DIR` | Allowed directories for agent | `""` |
| `MAX_CLI_SESSIONS` | Max concurrent CLI sessions | `10` |
@ -132,10 +131,34 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
| `MESSAGING_RATE_WINDOW` | Messaging window (seconds) | `1` |
| `NVIDIA_NIM_RATE_LIMIT` | API requests per window | `40` |
| `NVIDIA_NIM_RATE_WINDOW` | Rate limit window (seconds) | `60` |
| `NVIDIA_NIM_TEMPERATURE` | Model temperature | `1.0` |
| `NVIDIA_NIM_TOP_P` | Top P sampling | `1.0` |
| `NVIDIA_NIM_TOP_K` | Top K sampling | `-1` |
| `NVIDIA_NIM_MAX_TOKENS` | Max tokens for generation | `81920` |
The NVIDIA NIM base URL is fixed to `https://integrate.api.nvidia.com/v1`.
**NIM Settings (prefix `NVIDIA_NIM_`)**
| Variable | Description | Default |
| ------------------------------------- | ------------------------------------- | --------- |
| `NVIDIA_NIM_TEMPERATURE` | Sampling temperature | `1.0` |
| `NVIDIA_NIM_TOP_P` | Top-p nucleus sampling | `1.0` |
| `NVIDIA_NIM_TOP_K` | Top-k sampling | `-1` |
| `NVIDIA_NIM_MAX_TOKENS` | Max tokens for generation | `81920` |
| `NVIDIA_NIM_PRESENCE_PENALTY` | Presence penalty | `0.0` |
| `NVIDIA_NIM_FREQUENCY_PENALTY` | Frequency penalty | `0.0` |
| `NVIDIA_NIM_MIN_P` | Min-p sampling | `0.0` |
| `NVIDIA_NIM_REPETITION_PENALTY` | Repetition penalty | `1.0` |
| `NVIDIA_NIM_SEED` | RNG seed (blank = unset) | unset |
| `NVIDIA_NIM_STOP` | Stop string (blank = unset) | unset |
| `NVIDIA_NIM_PARALLEL_TOOL_CALLS` | Parallel tool calls | `true` |
| `NVIDIA_NIM_RETURN_TOKENS_AS_TOKEN_IDS` | Return token ids | `false` |
| `NVIDIA_NIM_INCLUDE_STOP_STR_IN_OUTPUT` | Include stop string in output | `false` |
| `NVIDIA_NIM_IGNORE_EOS` | Ignore EOS token | `false` |
| `NVIDIA_NIM_MIN_TOKENS` | Minimum generated tokens | `0` |
| `NVIDIA_NIM_CHAT_TEMPLATE` | Chat template override | unset |
| `NVIDIA_NIM_REQUEST_ID` | Request id override | unset |
| `NVIDIA_NIM_REASONING_EFFORT` | Reasoning effort (`low|medium|high`) | `high` |
| `NVIDIA_NIM_INCLUDE_REASONING` | Include reasoning in response | `true` |
All `NVIDIA_NIM_*` settings are strictly validated; unknown keys with this prefix will cause startup errors.
See [`.env.example`](.env.example) for all supported parameters.
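As a quick illustration of the strict validation, a misspelled variable under the `NVIDIA_NIM_` prefix fails at startup instead of being silently ignored. A minimal sketch, assuming the `NimSettings` model from `config/nim.py` and `pydantic-settings` installed:

```python
import os
from pydantic import ValidationError
from config.nim import NimSettings

os.environ["NVIDIA_NIM_TEMPRATURE"] = "0.7"  # note the typo

try:
    NimSettings()  # reads NVIDIA_NIM_* from the environment
except ValidationError as e:
    print(e)  # extra="forbid" rejects the unknown key
```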

View file

@ -100,9 +100,7 @@ async def lifespan(app: FastAPI):
# Restore tree state if available
saved_trees = session_store.get_all_trees()
if saved_trees:
logger.info(
f"Restoring {len(saved_trees)} conversation trees..."
)
logger.info(f"Restoring {len(saved_trees)} conversation trees...")
from messaging.tree_queue import TreeQueueManager
message_handler.tree_queue = TreeQueueManager.from_dict(
@ -124,7 +122,9 @@ async def lifespan(app: FastAPI):
# Start the platform
await messaging_platform.start()
logger.info(f"{messaging_platform.name} platform started with message handler")
logger.info(
f"{messaging_platform.name} platform started with message handler"
)
except ImportError as e:
logger.warning(f"Messaging module import error: {e}")

View file

@ -1,6 +1,7 @@
"""Dependency injection for FastAPI."""
from typing import Optional
from config.settings import Settings, get_settings as _get_settings, NVIDIA_NIM_BASE_URL
from providers.base import BaseProvider, ProviderConfig
@ -14,31 +15,6 @@ def get_settings() -> Settings:
return _get_settings()
def _build_nim_extra_params(settings: Settings) -> dict:
"""Build NIM-specific extra_params from settings."""
params = {
"temperature": settings.nvidia_nim_temperature,
"top_p": settings.nvidia_nim_top_p,
"max_tokens": settings.nvidia_nim_max_tokens,
}
# Only include non-default values to avoid overriding request-level settings
if settings.nvidia_nim_top_k != -1:
params["top_k"] = settings.nvidia_nim_top_k
if settings.nvidia_nim_presence_penalty != 0.0:
params["presence_penalty"] = settings.nvidia_nim_presence_penalty
if settings.nvidia_nim_frequency_penalty != 0.0:
params["frequency_penalty"] = settings.nvidia_nim_frequency_penalty
if settings.nvidia_nim_min_p != 0.0:
params["min_p"] = settings.nvidia_nim_min_p
if settings.nvidia_nim_repetition_penalty != 1.0:
params["repetition_penalty"] = settings.nvidia_nim_repetition_penalty
if settings.nvidia_nim_seed is not None:
params["seed"] = settings.nvidia_nim_seed
if settings.nvidia_nim_stop:
params["stop"] = settings.nvidia_nim_stop
return params
def get_provider() -> BaseProvider:
"""Get or create the provider instance based on settings.provider_type."""
global _provider
@ -53,7 +29,7 @@ def get_provider() -> BaseProvider:
base_url=NVIDIA_NIM_BASE_URL,
rate_limit=settings.nvidia_nim_rate_limit,
rate_window=settings.nvidia_nim_rate_window,
extra_params=_build_nim_extra_params(settings),
nim_settings=settings.nim,
)
_provider = NvidiaNimProvider(config)
else:

api/models/__init__.py (new file, +35 lines)
View file

@ -0,0 +1,35 @@
"""API models exports."""
from .anthropic import (
Role,
ContentBlockText,
ContentBlockImage,
ContentBlockToolUse,
ContentBlockToolResult,
ContentBlockThinking,
SystemContent,
Message,
Tool,
ThinkingConfig,
MessagesRequest,
TokenCountRequest,
)
from .responses import TokenCountResponse, Usage, MessagesResponse
__all__ = [
"Role",
"ContentBlockText",
"ContentBlockImage",
"ContentBlockToolUse",
"ContentBlockToolResult",
"ContentBlockThinking",
"SystemContent",
"Message",
"Tool",
"ThinkingConfig",
"MessagesRequest",
"TokenCountRequest",
"TokenCountResponse",
"Usage",
"MessagesResponse",
]
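Since the package `__init__` re-exports both modules, imports written against the old single-module layout keep working after the split; a small sketch of the equivalence:

```python
# Old-style package imports and new-style module imports resolve
# to the same classes after the refactor.
from api.models import MessagesRequest, MessagesResponse
from api.models.anthropic import MessagesRequest as Req
from api.models.responses import MessagesResponse as Resp

assert MessagesRequest is Req
assert MessagesResponse is Resp
```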

View file

@ -1,8 +1,9 @@
"""Pydantic models for API requests and responses."""
"""Pydantic models for Anthropic-compatible requests."""
import logging
from enum import Enum
from typing import List, Dict, Any, Optional, Union, Literal
from pydantic import BaseModel, field_validator, model_validator
from config.settings import get_settings
@ -88,18 +89,18 @@ class ThinkingConfig(BaseModel):
# =============================================================================
# Request/Response Models
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
model: str
max_tokens: int
max_tokens: Optional[int] = None
messages: List[Message]
system: Optional[Union[str, List[SystemContent]]] = None
stop_sequences: Optional[List[str]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1.0
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
@ -142,31 +143,3 @@ class TokenCountRequest(BaseModel):
settings = get_settings()
# Use centralized model normalization
return normalize_model_name(v, settings.model)
class TokenCountResponse(BaseModel):
input_tokens: int
class Usage(BaseModel):
input_tokens: int
output_tokens: int
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
class MessagesResponse(BaseModel):
id: str
model: str
role: Literal["assistant"] = "assistant"
content: List[
Union[
ContentBlockText, ContentBlockToolUse, ContentBlockThinking, Dict[str, Any]
]
]
type: Literal["message"] = "message"
stop_reason: Optional[
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
] = None
stop_sequence: Optional[str] = None
usage: Usage

api/models/responses.py (new file, +35 lines)
View file

@ -0,0 +1,35 @@
"""Pydantic models for API responses."""
from typing import List, Dict, Any, Optional, Union, Literal
from pydantic import BaseModel
from .anthropic import ContentBlockText, ContentBlockToolUse, ContentBlockThinking
class TokenCountResponse(BaseModel):
input_tokens: int
class Usage(BaseModel):
input_tokens: int
output_tokens: int
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
class MessagesResponse(BaseModel):
id: str
model: str
role: Literal["assistant"] = "assistant"
content: List[
Union[
ContentBlockText, ContentBlockToolUse, ContentBlockThinking, Dict[str, Any]
]
]
type: Literal["message"] = "message"
stop_reason: Optional[
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
] = None
stop_sequence: Optional[str] = None
usage: Usage

View file

@ -10,30 +10,13 @@ from typing import List, Optional, Tuple, Union
import tiktoken
from .models import MessagesRequest
from .models.anthropic import MessagesRequest
from utils.text import extract_text_from_content
logger = logging.getLogger(__name__)
ENCODER = tiktoken.get_encoding("cl100k_base")
def extract_text_from_content(content) -> str:
"""Extract concatenated text from message content (str or list of content blocks).
Handles the common pattern of content being either a plain string
or a list of content blocks with a .text attribute.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
text = getattr(block, "text", "")
if text and isinstance(text, str):
parts.append(text)
return "".join(parts)
return ""
def is_quota_check_request(request_data: MessagesRequest) -> bool:
"""Check if this is a quota probe request.

View file

@ -6,13 +6,8 @@ import uuid
from fastapi import APIRouter, Request, Depends, HTTPException
from fastapi.responses import StreamingResponse
from .models import (
MessagesRequest,
MessagesResponse,
TokenCountRequest,
TokenCountResponse,
Usage,
)
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import MessagesResponse, TokenCountResponse, Usage
from .dependencies import get_provider, get_settings
from .request_utils import (
is_quota_check_request,

View file

@ -2,6 +2,5 @@
from .session import CLISession
from .manager import CLISessionManager
from .parser import CLIParser
__all__ = ["CLISession", "CLISessionManager", "CLIParser"]
__all__ = ["CLISession", "CLISessionManager"]

View file

@ -1,127 +0,0 @@
"""CLI event parser for Claude Code CLI output."""
import logging
from typing import Dict, List, Any
logger = logging.getLogger(__name__)
class CLIParser:
"""Helper to structure raw CLI events."""
@staticmethod
def parse_event(event: Any) -> List[Dict]:
"""
Parse a CLI event and return a structured result.
Args:
event: Raw event dictionary from CLI
Returns:
List of parsed event dicts. Empty list if not recognized.
"""
if not isinstance(event, dict):
return []
etype = event.get("type")
results = []
# 1. Handle full messages (assistant or result)
msg_obj = None
if etype == "assistant":
msg_obj = event.get("message")
elif etype == "result":
res = event.get("result")
if isinstance(res, dict):
msg_obj = res.get("message")
if not msg_obj:
msg_obj = event.get("message")
if msg_obj and isinstance(msg_obj, dict):
content = msg_obj.get("content", [])
if isinstance(content, list):
parts = []
thinking_parts = []
tool_calls = []
for c in content:
if not isinstance(c, dict):
continue
ctype = c.get("type")
if ctype == "text":
parts.append(c.get("text", ""))
elif ctype == "thinking":
thinking_parts.append(c.get("thinking", ""))
elif ctype == "tool_use":
tool_calls.append(c)
# Prioritize thinking first
if thinking_parts:
results.append(
{"type": "thinking", "text": "\n".join(thinking_parts)}
)
# Then tools or subagents
if tool_calls:
# Check for subagents (Task tool)
subagents = [
t.get("input", {}).get("description", "Subagent")
for t in tool_calls
if t.get("name") == "Task"
]
if subagents:
results.append({"type": "subagent_start", "tasks": subagents})
else:
results.append({"type": "tool_start", "tools": tool_calls})
# Then text content if any
if parts:
results.append({"type": "content", "text": "".join(parts)})
if results:
return results
# 2. Handle streaming deltas
if etype == "content_block_delta":
delta = event.get("delta", {})
if isinstance(delta, dict):
if delta.get("type") == "text_delta":
return [{"type": "content", "text": delta.get("text", "")}]
if delta.get("type") == "thinking_delta":
return [{"type": "thinking", "text": delta.get("thinking", "")}]
# 3. Handle tool usage start
if etype == "content_block_start":
block = event.get("content_block", {})
if isinstance(block, dict) and block.get("type") == "tool_use":
if block.get("name") == "Task":
desc = block.get("input", {}).get("description", "Subagent")
return [{"type": "subagent_start", "tasks": [desc]}]
return [{"type": "tool_start", "tools": [block]}]
# 4. Handle errors and exit
if etype == "error":
err = event.get("error")
msg = err.get("message") if isinstance(err, dict) else str(err)
logger.info(f"CLI_PARSER: Parsed error event: {msg[:100]}")
return [{"type": "error", "message": msg}]
elif etype == "exit":
code = event.get("code", 0)
stderr = event.get("stderr")
if code == 0:
logger.debug(f"CLI_PARSER: Successful exit (code={code})")
return [{"type": "complete", "status": "success"}]
else:
# Non-zero exit is an error
error_msg = stderr if stderr else f"Process exited with code {code}"
logger.warning(
f"CLI_PARSER: Error exit (code={code}): {error_msg[:100]}"
)
return [
{"type": "error", "message": error_msg},
{"type": "complete", "status": "failed"},
]
# Log unrecognized events for debugging
if etype:
logger.debug(f"CLI_PARSER: Unrecognized event type: {etype}")
return []

config/nim.py (new file, +63 lines)
View file

@ -0,0 +1,63 @@
"""NVIDIA NIM settings (strict validation)."""
from typing import Optional, Literal
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class NimSettings(BaseSettings):
"""Strictly validated NVIDIA NIM settings."""
temperature: float = Field(1.0, ge=0.0)
top_p: float = Field(1.0, ge=0.0, le=1.0)
top_k: int = -1
max_tokens: int = Field(81920, ge=1)
presence_penalty: float = Field(0.0, ge=-2.0, le=2.0)
frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)
min_p: float = Field(0.0, ge=0.0, le=1.0)
repetition_penalty: float = Field(1.0, ge=0.0)
seed: Optional[int] = None
stop: Optional[str] = None
parallel_tool_calls: bool = True
return_tokens_as_token_ids: bool = False
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = Field(0, ge=0)
chat_template: Optional[str] = None
request_id: Optional[str] = None
reasoning_effort: Literal["low", "medium", "high"] = "high"
include_reasoning: bool = True
model_config = SettingsConfigDict(
env_prefix="NVIDIA_NIM_",
# Rely on global load_dotenv in config.settings to avoid
# reading unrelated .env keys into this settings model.
extra="forbid",
)
@field_validator("top_k")
@classmethod
def validate_top_k(cls, v):
if v < -1:
raise ValueError("top_k must be -1 or >= 0")
return v
@field_validator("seed", mode="before")
@classmethod
def parse_optional_int(cls, v):
if v == "" or v is None:
return None
return int(v)
@field_validator("stop", "chat_template", "request_id", mode="before")
@classmethod
def parse_optional_str(cls, v):
if v == "":
return None
return v
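A minimal sketch of the coercion behavior above: blank values for `seed`, `stop`, `chat_template`, and `request_id` become `None` instead of failing to parse, and `top_k=-1` passes the custom validator:

```python
import os
from config.nim import NimSettings

os.environ["NVIDIA_NIM_SEED"] = ""     # blank = unset
os.environ["NVIDIA_NIM_STOP"] = ""     # blank = unset
os.environ["NVIDIA_NIM_TOP_K"] = "-1"  # -1 disables top-k

nim = NimSettings()
assert nim.seed is None and nim.stop is None
assert nim.top_k == -1
```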

View file

@ -2,10 +2,13 @@
from functools import lru_cache
from typing import Optional
from pydantic import field_validator
from pydantic import field_validator, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from dotenv import load_dotenv
from .nim import NimSettings
load_dotenv()
# Fixed base URL for NVIDIA NIM
@ -41,33 +44,8 @@ class Settings(BaseSettings):
enable_suggestion_mode_skip: bool = True
enable_filepath_extraction_mock: bool = True
# ==================== NIM Core Parameters ====================
nvidia_nim_temperature: float = 1.0
nvidia_nim_top_p: float = 1.0
nvidia_nim_top_k: int = -1
nvidia_nim_max_tokens: int = 81920
nvidia_nim_presence_penalty: float = 0.0
nvidia_nim_frequency_penalty: float = 0.0
# ==================== NIM Advanced Parameters ====================
nvidia_nim_min_p: float = 0.0
nvidia_nim_repetition_penalty: float = 1.0
nvidia_nim_seed: Optional[int] = None
nvidia_nim_stop: Optional[str] = None
# ==================== NIM Flag Parameters ====================
nvidia_nim_parallel_tool_calls: bool = True
nvidia_nim_return_tokens_as_token_ids: bool = False
nvidia_nim_include_stop_str_in_output: bool = False
nvidia_nim_ignore_eos: bool = False
nvidia_nim_min_tokens: int = 0
nvidia_nim_chat_template: str = ""
nvidia_nim_request_id: str = ""
# ==================== Thinking/Reasoning Parameters ====================
nvidia_nim_reasoning_effort: str = "high"
nvidia_nim_include_reasoning: bool = True
# ==================== NIM Settings ====================
nim: NimSettings = Field(default_factory=NimSettings)
# ==================== Bot Wrapper Config ====================
telegram_bot_token: Optional[str] = None
@ -81,17 +59,8 @@ class Settings(BaseSettings):
port: int = 8082
log_file: str = "server.log"
# Handle empty strings for optional int fields
@field_validator("nvidia_nim_seed", mode="before")
@classmethod
def parse_optional_int(cls, v):
if v == "" or v is None:
return None
return int(v)
# Handle empty strings for optional string fields
@field_validator(
"nvidia_nim_stop",
"telegram_bot_token",
"allowed_telegram_user_id",
mode="before",

View file

@ -45,5 +45,7 @@ def create_messaging_platform(
# from .discord import DiscordPlatform
# return DiscordPlatform(...)
logger.warning(f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram'")
logger.warning(
f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram'"
)
return None

View file

@ -30,5 +30,3 @@ class IncomingMessage:
def is_reply(self) -> bool:
"""Check if this message is a reply to another message."""
return self.reply_to_message_id is not None

View file

@ -1,22 +1,25 @@
"""Base provider interface - extend this to implement your own provider."""
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict, Optional
from pydantic import BaseModel
from typing import Any, AsyncIterator, Optional
from pydantic import BaseModel, Field
from config.nim import NimSettings
class ProviderConfig(BaseModel):
"""Configuration for a provider.
Base fields apply to all providers. Provider-specific parameters
(e.g. NIM temperature, top_p) are passed via extra_params.
(e.g. NIM temperature, top_p) are passed via nim_settings.
"""
api_key: str
base_url: Optional[str] = None
rate_limit: Optional[int] = None
rate_window: int = 60
extra_params: Dict[str, Any] = {}
nim_settings: NimSettings = Field(default_factory=NimSettings)
class BaseProvider(ABC):

View file

@ -9,21 +9,14 @@ import json
import logging
from typing import Any, Dict, List
from utils.text import extract_text_from_content
logger = logging.getLogger(__name__)
def _extract_text_from_content(content) -> str:
"""Extract concatenated text from message content (str or list of content blocks)."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
text = getattr(block, "text", "")
if text and isinstance(text, str):
parts.append(text)
return "".join(parts)
return ""
def _extract_text_from_content(content: Any) -> str:
"""Backward-compatible wrapper for tests and legacy imports."""
return extract_text_from_content(content)
def generate_request_fingerprint(messages: List[Any]) -> str:
@ -56,7 +49,7 @@ def get_last_user_message_preview(messages: List[Any], max_len: int = 100) -> st
"""Extract a preview of the last user message."""
for msg in reversed(messages):
if hasattr(msg, "role") and msg.role == "user":
text = _extract_text_from_content(getattr(msg, "content", ""))
text = extract_text_from_content(getattr(msg, "content", ""))
if text:
preview = text.replace("\n", " ").replace("\r", "")
return preview[:max_len] + "..." if len(preview) > max_len else preview

View file

@ -1,228 +0,0 @@
"""Mixins for NVIDIA NIM provider - decoupling responsibilities.
This module contains focused mixins that handle specific aspects of the
NVIDIA NIM provider functionality:
- RequestBuilderMixin: Builds request bodies
- ErrorMapperMixin: Maps HTTP errors to provider exceptions
- ResponseConverterMixin: Converts responses between formats
"""
import json
import logging
from typing import Any, Dict
from .utils import AnthropicToOpenAIConverter, map_stop_reason, extract_think_content
from .exceptions import (
AuthenticationError,
InvalidRequestError,
RateLimitError,
OverloadedError,
APIError,
)
logger = logging.getLogger(__name__)
class RequestBuilderMixin:
"""Mixin for building OpenAI-format request bodies.
Handles conversion from Anthropic request format to OpenAI format,
including system prompts, tools, thinking mode, and NIM-specific parameters.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._nim_params: Dict[str, Any] = {}
def _build_request_body(self, request_data: Any, stream: bool = False) -> dict:
"""Build OpenAI-format request body from Anthropic request.
Args:
request_data: The incoming Anthropic-format request
stream: Whether this is a streaming request
Returns:
OpenAI-format request body dictionary
"""
messages = AnthropicToOpenAIConverter.convert_messages(request_data.messages)
# Add system prompt
if request_data.system:
system_msg = AnthropicToOpenAIConverter.convert_system_prompt(
request_data.system
)
if system_msg:
messages.insert(0, system_msg)
body = {
"model": request_data.model,
"messages": messages,
"max_tokens": request_data.max_tokens,
}
if request_data.temperature is not None:
body["temperature"] = request_data.temperature
if request_data.top_p is not None:
body["top_p"] = request_data.top_p
if request_data.stop_sequences:
body["stop"] = request_data.stop_sequences
if request_data.tools:
body["tools"] = AnthropicToOpenAIConverter.convert_tools(request_data.tools)
# Handle non-standard parameters via extra_body
extra_params = request_data.extra_body.copy() if request_data.extra_body else {}
# Handle thinking/reasoning mode
if request_data.thinking and getattr(request_data.thinking, "enabled", True):
extra_params.setdefault("thinking", {"type": "enabled"})
extra_params.setdefault("reasoning_split", True)
extra_params.setdefault(
"chat_template_kwargs",
{"thinking": True, "reasoning_split": True, "clear_thinking": False},
)
if extra_params:
body["extra_body"] = extra_params
# Apply NIM defaults
for key, val in self._nim_params.items():
if key not in body and key not in extra_params:
body[key] = val
return body
class ErrorMapperMixin:
"""Mixin for mapping HTTP errors to provider exceptions.
Converts HTTP status codes and error responses to appropriate
ProviderError subclasses for standardized error handling.
"""
def _map_error(self, e: Exception) -> Exception:
"""Map OpenAI exception to specific ProviderError.
Args:
e: The OpenAI exception to map
Returns:
Appropriate ProviderError subclass instance
"""
import openai
if isinstance(e, openai.AuthenticationError):
return AuthenticationError(str(e), raw_error=str(e))
if isinstance(e, openai.RateLimitError):
# Trigger global rate limit block
from .rate_limit import GlobalRateLimiter
GlobalRateLimiter.get_instance().set_blocked(60) # Default 60s cooldown
return RateLimitError(str(e), raw_error=str(e))
if isinstance(e, openai.BadRequestError):
return InvalidRequestError(str(e), raw_error=str(e))
if isinstance(e, openai.InternalServerError):
message = str(e)
if "overloaded" in message.lower() or "capacity" in message.lower():
return OverloadedError(message, raw_error=str(e))
return APIError(message, status_code=500, raw_error=str(e))
if isinstance(e, openai.APIError):
return APIError(
str(e), status_code=getattr(e, "status_code", 500), raw_error=str(e)
)
return e
class ResponseConverterMixin:
"""Mixin for converting OpenAI responses to Anthropic format.
Handles content extraction, reasoning/thinking blocks, tool calls,
and response structure transformation.
"""
def convert_response(self, response_json: dict, original_request: Any) -> dict:
"""Convert OpenAI response to Anthropic format.
Args:
response_json: OpenAI-format response JSON
original_request: Original Anthropic-format request
Returns:
Anthropic-format response dictionary
"""
import uuid
choice = response_json["choices"][0]
message = choice["message"]
content = []
# Extract reasoning from various sources
reasoning = message.get("reasoning_content")
if not reasoning:
reasoning_details = message.get("reasoning_details")
if reasoning_details and isinstance(reasoning_details, list):
reasoning = "\n".join(
item.get("text", "")
for item in reasoning_details
if isinstance(item, dict)
)
if reasoning:
content.append({"type": "thinking", "thinking": reasoning})
# Extract text content (with think tag handling)
if message.get("content"):
raw_content = message["content"]
if isinstance(raw_content, str):
if not reasoning:
think_content, raw_content = extract_think_content(raw_content)
if think_content:
content.append({"type": "thinking", "thinking": think_content})
if raw_content:
content.append({"type": "text", "text": raw_content})
elif isinstance(raw_content, list):
for item in raw_content:
if isinstance(item, dict) and item.get("type") == "text":
content.append(item)
# Extract tool calls
if message.get("tool_calls"):
for tc in message["tool_calls"]:
try:
args = json.loads(tc["function"]["arguments"])
except Exception:
args = tc["function"].get("arguments", {})
content.append(
{
"type": "tool_use",
"id": tc["id"],
"name": tc["function"]["name"],
"input": args,
}
)
if not content:
# NIM models (especially Mistral-based) often require non-empty content.
# Adding a single space satisfies this requirement while avoiding
# the "(no content)" display issue in Claude Code.
content.append({"type": "text", "text": " "})
usage = response_json.get("usage", {})
return {
"id": response_json.get("id", f"msg_{uuid.uuid4()}"),
"type": "message",
"role": "assistant",
"model": original_request.model,
"content": content,
"stop_reason": map_stop_reason(choice.get("finish_reason")),
"stop_sequence": None,
"usage": {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
}

View file

@ -0,0 +1,5 @@
"""NVIDIA NIM provider package."""
from .client import NvidiaNimProvider
__all__ = ["NvidiaNimProvider"]

View file

@ -1,3 +1,5 @@
"""NVIDIA NIM provider implementation."""
import logging
import json
import uuid
@ -5,41 +7,32 @@ from typing import Any, AsyncIterator
from openai import AsyncOpenAI
from .base import BaseProvider, ProviderConfig
from providers.base import BaseProvider, ProviderConfig
from providers.rate_limit import GlobalRateLimiter
from .request import build_request_body
from .response import convert_response
from .errors import map_error
from .utils import (
SSEBuilder,
map_stop_reason,
ThinkTagParser,
HeuristicToolParser,
ContentType,
extract_reasoning_from_delta,
)
from .exceptions import (
APIError,
)
from .nvidia_mixins import (
RequestBuilderMixin,
ErrorMapperMixin,
ResponseConverterMixin,
)
from .rate_limit import GlobalRateLimiter
logger = logging.getLogger(__name__)
class NvidiaNimProvider(
RequestBuilderMixin,
ErrorMapperMixin,
ResponseConverterMixin,
BaseProvider,
):
class NvidiaNimProvider(BaseProvider):
"""NVIDIA NIM provider using official OpenAI client."""
def __init__(self, config: ProviderConfig):
super().__init__(config)
self._api_key = config.api_key
self._base_url = (config.base_url or "https://integrate.api.nvidia.com/v1").rstrip("/")
self._nim_params = config.extra_params.copy() if config.extra_params else {}
self._base_url = (
config.base_url or "https://integrate.api.nvidia.com/v1"
).rstrip("/")
self._nim_settings = config.nim_settings
self._global_rate_limiter = GlobalRateLimiter.get_instance(
rate_limit=config.rate_limit,
rate_window=config.rate_window,
@ -51,6 +44,10 @@ class NvidiaNimProvider(
timeout=300.0,
)
def _build_request_body(self, request: Any, stream: bool = False) -> dict:
"""Internal helper for tests and shared building."""
return build_request_body(request, self._nim_settings, stream=stream)
async def stream_response(
self, request: Any, input_tokens: int = 0
) -> AsyncIterator[str]:
@ -163,7 +160,7 @@ class NvidiaNimProvider(
except Exception as e:
logger.error(f"NIM_ERROR: {type(e).__name__}: {e}")
mapped_e = self._map_error(e)
mapped_e = map_error(e)
error_occurred = True
error_message = str(mapped_e)
logger.info(f"NIM_STREAM: Emitting SSE error event for {type(e).__name__}")
@ -235,11 +232,15 @@ class NvidiaNimProvider(
try:
response = await self._client.chat.completions.create(**body)
# ResponseconverterMixin expects a dict
# convert_response expects a dict
return response.model_dump()
except Exception as e:
logger.error(f"NIM_ERROR: {type(e).__name__}: {e}")
raise self._map_error(e)
raise map_error(e)
def convert_response(self, response_json: dict, original_request: Any) -> Any:
"""Convert provider response to Anthropic format."""
return convert_response(response_json, original_request)
def _process_tool_call(self, tc: dict, sse: Any):
"""Process a single tool call delta and yield SSE events.

View file

@ -0,0 +1,35 @@
"""Error mapping for NVIDIA NIM provider."""
import openai
from providers.exceptions import (
AuthenticationError,
InvalidRequestError,
RateLimitError,
OverloadedError,
APIError,
)
from providers.rate_limit import GlobalRateLimiter
def map_error(e: Exception) -> Exception:
"""Map OpenAI exception to specific ProviderError."""
if isinstance(e, openai.AuthenticationError):
return AuthenticationError(str(e), raw_error=str(e))
if isinstance(e, openai.RateLimitError):
# Trigger global rate limit block
GlobalRateLimiter.get_instance().set_blocked(60) # Default 60s cooldown
return RateLimitError(str(e), raw_error=str(e))
if isinstance(e, openai.BadRequestError):
return InvalidRequestError(str(e), raw_error=str(e))
if isinstance(e, openai.InternalServerError):
message = str(e)
if "overloaded" in message.lower() or "capacity" in message.lower():
return OverloadedError(message, raw_error=str(e))
return APIError(message, status_code=500, raw_error=str(e))
if isinstance(e, openai.APIError):
return APIError(
str(e), status_code=getattr(e, "status_code", 500), raw_error=str(e)
)
return e
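A sketch of the intended call-site pattern (mirroring `client.py`): wrap the `openai` SDK call and re-raise the mapped provider error:

```python
import openai
from providers.nvidia_nim.errors import map_error

async def create_completion(client: openai.AsyncOpenAI, body: dict) -> dict:
    try:
        response = await client.chat.completions.create(**body)
        return response.model_dump()
    except Exception as e:
        # Normalize SDK exceptions into ProviderError subclasses
        raise map_error(e)
```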

View file

@ -0,0 +1,117 @@
"""Request builder for NVIDIA NIM provider."""
from typing import Any, Dict
from config.nim import NimSettings
from .utils.message_converter import AnthropicToOpenAIConverter
def _set_if_not_none(body: Dict[str, Any], key: str, value: Any) -> None:
if value is not None:
body[key] = value
def _set_extra(
extra_body: Dict[str, Any], key: str, value: Any, ignore_value: Any = None
) -> None:
if key in extra_body:
return
if value is None:
return
if ignore_value is not None and value == ignore_value:
return
extra_body[key] = value
def build_request_body(
request_data: Any, nim: NimSettings, stream: bool = False
) -> dict:
"""Build OpenAI-format request body from Anthropic request."""
messages = AnthropicToOpenAIConverter.convert_messages(request_data.messages)
# Add system prompt
system = getattr(request_data, "system", None)
if system:
system_msg = AnthropicToOpenAIConverter.convert_system_prompt(system)
if system_msg:
messages.insert(0, system_msg)
body: Dict[str, Any] = {
"model": request_data.model,
"messages": messages,
}
# max_tokens with optional cap
max_tokens = getattr(request_data, "max_tokens", None)
if max_tokens is None:
max_tokens = nim.max_tokens
elif nim.max_tokens:
max_tokens = min(max_tokens, nim.max_tokens)
_set_if_not_none(body, "max_tokens", max_tokens)
req_temperature = getattr(request_data, "temperature", None)
temperature = req_temperature if req_temperature is not None else nim.temperature
_set_if_not_none(body, "temperature", temperature)
req_top_p = getattr(request_data, "top_p", None)
top_p = req_top_p if req_top_p is not None else nim.top_p
_set_if_not_none(body, "top_p", top_p)
stop_sequences = getattr(request_data, "stop_sequences", None)
if stop_sequences:
body["stop"] = stop_sequences
elif nim.stop:
body["stop"] = nim.stop
tools = getattr(request_data, "tools", None)
if tools:
body["tools"] = AnthropicToOpenAIConverter.convert_tools(tools)
tool_choice = getattr(request_data, "tool_choice", None)
if tool_choice:
body["tool_choice"] = tool_choice
if nim.presence_penalty != 0.0:
body["presence_penalty"] = nim.presence_penalty
if nim.frequency_penalty != 0.0:
body["frequency_penalty"] = nim.frequency_penalty
if nim.seed is not None:
body["seed"] = nim.seed
body["parallel_tool_calls"] = nim.parallel_tool_calls
# Handle non-standard parameters via extra_body
extra_body: Dict[str, Any] = {}
request_extra = getattr(request_data, "extra_body", None)
if request_extra:
extra_body.update(request_extra)
# Handle thinking/reasoning mode
thinking = getattr(request_data, "thinking", None)
if thinking and getattr(thinking, "enabled", True):
extra_body.setdefault("thinking", {"type": "enabled"})
extra_body.setdefault("reasoning_split", True)
extra_body.setdefault(
"chat_template_kwargs",
{"thinking": True, "reasoning_split": True, "clear_thinking": False},
)
req_top_k = getattr(request_data, "top_k", None)
top_k = req_top_k if req_top_k is not None else nim.top_k
_set_extra(extra_body, "top_k", top_k, ignore_value=-1)
_set_extra(extra_body, "min_p", nim.min_p, ignore_value=0.0)
_set_extra(
extra_body, "repetition_penalty", nim.repetition_penalty, ignore_value=1.0
)
_set_extra(extra_body, "min_tokens", nim.min_tokens, ignore_value=0)
_set_extra(extra_body, "chat_template", nim.chat_template)
_set_extra(extra_body, "request_id", nim.request_id)
_set_extra(extra_body, "return_tokens_as_token_ids", nim.return_tokens_as_token_ids)
_set_extra(extra_body, "include_stop_str_in_output", nim.include_stop_str_in_output)
_set_extra(extra_body, "ignore_eos", nim.ignore_eos)
_set_extra(extra_body, "reasoning_effort", nim.reasoning_effort)
_set_extra(extra_body, "include_reasoning", nim.include_reasoning)
if extra_body:
body["extra_body"] = extra_body
return body
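A worked sketch of the precedence rules above: request-level values win, `NimSettings` defaults fill the gaps, and `max_tokens` is capped by the configured limit. The request object here is a hypothetical stand-in; any object exposing the expected attributes works:

```python
from types import SimpleNamespace
from config.nim import NimSettings
from providers.nvidia_nim.request import build_request_body

request = SimpleNamespace(
    model="moonshotai/kimi-k2-thinking",
    messages=[],          # converted by AnthropicToOpenAIConverter
    system=None,
    max_tokens=200_000,   # above the NIM cap, so it is clamped
    temperature=0.2,      # request value overrides nim.temperature
    top_p=None,           # falls back to nim.top_p
    top_k=None,           # falls back to nim.top_k (-1 = omitted)
    stop_sequences=None,
    tools=None,
    tool_choice=None,
    extra_body=None,
    thinking=None,
)

body = build_request_body(request, NimSettings())
assert body["max_tokens"] == 81920  # min(200_000, nim.max_tokens)
assert body["temperature"] == 0.2
assert body["top_p"] == 1.0
assert body["extra_body"]["reasoning_effort"] == "high"
```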

View file

@ -0,0 +1,83 @@
"""Response conversion for NVIDIA NIM provider."""
import json
import uuid
from typing import Any
from .utils import map_stop_reason, extract_think_content
def convert_response(response_json: dict, original_request: Any) -> dict:
"""Convert OpenAI response to Anthropic format."""
choice = response_json["choices"][0]
message = choice["message"]
content = []
# Extract reasoning from various sources
reasoning = message.get("reasoning_content")
if not reasoning:
reasoning_details = message.get("reasoning_details")
if reasoning_details and isinstance(reasoning_details, list):
reasoning = "\n".join(
item.get("text", "")
for item in reasoning_details
if isinstance(item, dict)
)
if reasoning:
content.append({"type": "thinking", "thinking": reasoning})
# Extract text content (with think tag handling)
if message.get("content"):
raw_content = message["content"]
if isinstance(raw_content, str):
if not reasoning:
think_content, raw_content = extract_think_content(raw_content)
if think_content:
content.append({"type": "thinking", "thinking": think_content})
if raw_content:
content.append({"type": "text", "text": raw_content})
elif isinstance(raw_content, list):
for item in raw_content:
if isinstance(item, dict) and item.get("type") == "text":
content.append(item)
# Extract tool calls
if message.get("tool_calls"):
for tc in message["tool_calls"]:
try:
args = json.loads(tc["function"]["arguments"])
except Exception:
args = tc["function"].get("arguments", {})
content.append(
{
"type": "tool_use",
"id": tc["id"],
"name": tc["function"]["name"],
"input": args,
}
)
if not content:
# NIM models (especially Mistral-based) often require non-empty content.
# Adding a single space satisfies this requirement while avoiding
# the "(no content)" display issue in Claude Code.
content.append({"type": "text", "text": " "})
usage = response_json.get("usage", {})
return {
"id": response_json.get("id", f"msg_{uuid.uuid4()}"),
"type": "message",
"role": "assistant",
"model": original_request.model,
"content": content,
"stop_reason": map_stop_reason(choice.get("finish_reason")),
"stop_sequence": None,
"usage": {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
}
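A minimal round-trip sketch for the converter, using a hand-built OpenAI-style response dict:

```python
from types import SimpleNamespace
from providers.nvidia_nim.response import convert_response

openai_response = {
    "id": "chatcmpl-123",
    "choices": [{
        "finish_reason": "stop",
        "message": {
            "content": "Hello!",
            "reasoning_content": "A short greeting is fine.",
        },
    }],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3},
}

result = convert_response(
    openai_response, SimpleNamespace(model="moonshotai/kimi-k2-thinking")
)
# Reasoning is emitted first as a thinking block, then the text content
assert result["content"][0]["type"] == "thinking"
assert result["content"][1] == {"type": "text", "text": "Hello!"}
assert result["usage"]["input_tokens"] == 12
```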

View file

@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")
from unittest.mock import AsyncMock, MagicMock
from providers.base import ProviderConfig
from providers.nvidia_nim import NvidiaNimProvider
from config.nim import NimSettings
from messaging.base import CLISession, SessionManagerInterface, MessagingPlatform
from messaging.models import IncomingMessage
from messaging.session import SessionStore
@ -26,6 +27,7 @@ def provider_config():
base_url="https://test.api.nvidia.com/v1",
rate_limit=10,
rate_window=60,
nim_settings=NimSettings(),
)

View file

@ -4,37 +4,34 @@ import pytest
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
from messaging.event_parser import parse_cli_event
# --- Existing Parser Tests ---
class TestCLIParser:
"""Test CLIParser event parsing."""
"""Test CLI event parsing."""
def test_parse_text_content(self):
"""Test parsing text content from assistant message."""
from cli.parser import CLIParser
event = {
"type": "assistant",
"message": {"content": [{"type": "text", "text": "Hello, world!"}]},
}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert len(result) == 1
assert result[0]["type"] == "content"
assert result[0]["text"] == "Hello, world!"
def test_parse_thinking_content(self):
"""Test parsing thinking content."""
from cli.parser import CLIParser
event = {
"type": "assistant",
"message": {
"content": [{"type": "thinking", "thinking": "Let me think..."}]
},
}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert len(result) == 1
assert result[0]["type"] == "thinking"
assert (
@ -44,8 +41,6 @@ class TestCLIParser:
def test_parse_multiple_content(self):
"""Test parsing mixed content (thinking + tools)."""
from cli.parser import CLIParser
event = {
"type": "assistant",
"message": {
@ -55,7 +50,7 @@ class TestCLIParser:
]
},
}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert len(result) == 2
assert result[0]["type"] == "thinking"
assert result[0]["text"] == "Thinking..."
@ -63,8 +58,6 @@ class TestCLIParser:
def test_parse_tool_use(self):
"""Test parsing tool use content."""
from cli.parser import CLIParser
event = {
"type": "assistant",
"message": {
@ -77,7 +70,7 @@ class TestCLIParser:
]
},
}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert len(result) == 1
assert result[0]["type"] == "tool_start"
assert len(result[0]["tools"]) == 1
@ -85,54 +78,44 @@ class TestCLIParser:
def test_parse_text_delta(self):
"""Test parsing streaming text delta."""
from cli.parser import CLIParser
event = {
"type": "content_block_delta",
"delta": {"type": "text_delta", "text": "streaming text"},
}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert len(result) == 1
assert result[0]["type"] == "content"
assert result[0]["text"] == "streaming text"
def test_parse_thinking_delta(self):
"""Test parsing streaming thinking delta."""
from cli.parser import CLIParser
event = {
"type": "content_block_delta",
"delta": {"type": "thinking_delta", "thinking": "thinking..."},
}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert len(result) == 1
assert result[0]["type"] == "thinking"
assert result[0]["text"] == "thinking..."
def test_parse_error(self):
"""Test parsing error event."""
from cli.parser import CLIParser
event = {"type": "error", "error": {"message": "Something went wrong"}}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert result[0]["type"] == "error"
assert result[0]["message"] == "Something went wrong"
def test_parse_exit_success(self):
"""Test parsing exit event with success."""
from cli.parser import CLIParser
event = {"type": "exit", "code": 0}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
assert result[0]["type"] == "complete"
assert result[0]["status"] == "success"
def test_parse_exit_failure(self):
"""Test parsing exit event with failure returns error then complete."""
from cli.parser import CLIParser
event = {"type": "exit", "code": 1}
result = CLIParser.parse_event(event)
result = parse_cli_event(event)
# Non-zero exit now returns error first, then complete
assert len(result) == 2
assert result[0]["type"] == "error"
@ -145,16 +128,12 @@ class TestCLIParser:
def test_parse_invalid_event(self):
"""Test parsing returns empty list for unrecognized event."""
from cli.parser import CLIParser
result = CLIParser.parse_event({"type": "unknown"})
result = parse_cli_event({"type": "unknown"})
assert result == []
def test_parse_non_dict(self):
"""Test parsing returns empty list for non-dict input."""
from cli.parser import CLIParser
result = CLIParser.parse_event("not a dict")
result = parse_cli_event("not a dict")
assert result == []

View file

@ -18,6 +18,7 @@ class TestSettings:
settings = Settings()
assert isinstance(settings.nvidia_nim_rate_limit, int)
assert isinstance(settings.nvidia_nim_rate_window, int)
assert isinstance(settings.nim.temperature, float)
assert isinstance(settings.fast_prefix_detection, bool)
assert isinstance(settings.max_cli_sessions, int)
@ -35,9 +36,7 @@ class TestSettings:
# Settings should handle NVIDIA_NIM_SEED="" gracefully
settings = Settings()
assert settings.nvidia_nim_seed is None or isinstance(
settings.nvidia_nim_seed, int
)
assert settings.nim.seed is None or isinstance(settings.nim.seed, int)
def test_model_setting(self):
"""Test model setting exists and is a string."""

View file

@ -1,5 +1,5 @@
import json
from providers.utils.message_converter import AnthropicToOpenAIConverter
from providers.nvidia_nim.utils.message_converter import AnthropicToOpenAIConverter
# --- Mock Classes ---
@ -265,7 +265,7 @@ def test_convert_mixed_blocks_and_types_and_roles():
def test_get_block_attr_defaults():
# Test helper directly
from providers.utils.message_converter import get_block_attr
from providers.nvidia_nim.utils.message_converter import get_block_attr
assert get_block_attr({}, "missing", "default") == "default"
assert get_block_attr(object(), "missing", "default") == "default"

View file

@ -2,6 +2,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from api.dependencies import get_provider, get_settings, cleanup_provider
from providers.nvidia_nim import NvidiaNimProvider
from config.nim import NimSettings
def _make_mock_settings(**overrides):
@ -9,19 +10,9 @@ def _make_mock_settings(**overrides):
mock = MagicMock()
mock.provider_type = "nvidia_nim"
mock.nvidia_nim_api_key = "test_key"
mock.nvidia_nim_base_url = None
mock.nvidia_nim_rate_limit = 40
mock.nvidia_nim_rate_window = 60
mock.nvidia_nim_temperature = 0.6
mock.nvidia_nim_top_p = 0.95
mock.nvidia_nim_max_tokens = 16000
mock.nvidia_nim_top_k = -1
mock.nvidia_nim_presence_penalty = 0.0
mock.nvidia_nim_frequency_penalty = 0.0
mock.nvidia_nim_min_p = 0.0
mock.nvidia_nim_repetition_penalty = 1.0
mock.nvidia_nim_seed = None
mock.nvidia_nim_stop = None
mock.nim = NimSettings()
for key, value in overrides.items():
setattr(mock, key, value)
return mock

View file

@ -44,7 +44,6 @@ class TestMessagingModels:
assert msg.reply_to_message_id == "100"
class TestMessagingBase:
"""Test MessagingPlatform ABC."""

View file

@ -13,7 +13,9 @@ class TestCreateMessagingPlatform:
"""Create Telegram platform when bot_token is provided."""
mock_platform = MagicMock()
with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
with patch("messaging.telegram.TelegramPlatform", return_value=mock_platform):
with patch(
"messaging.telegram.TelegramPlatform", return_value=mock_platform
):
result = create_messaging_platform(
"telegram",
bot_token="test_token",

View file

@ -1,6 +1,6 @@
import pytest
from unittest.mock import patch
from api.models import MessagesRequest, TokenCountRequest, Message
from api.models.anthropic import MessagesRequest, TokenCountRequest, Message
from config.settings import Settings
@ -12,7 +12,7 @@ def mock_settings():
def test_messages_request_map_model_claude_to_default(mock_settings):
with patch("api.models.get_settings", return_value=mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest(
model="claude-3-opus",
max_tokens=100,
@ -24,7 +24,7 @@ def test_messages_request_map_model_claude_to_default(mock_settings):
def test_messages_request_map_model_non_claude_unchanged(mock_settings):
with patch("api.models.get_settings", return_value=mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest(
model="gpt-4",
max_tokens=100,
@ -36,7 +36,7 @@ def test_messages_request_map_model_non_claude_unchanged(mock_settings):
def test_messages_request_map_model_with_provider_prefix(mock_settings):
with patch("api.models.get_settings", return_value=mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = MessagesRequest(
model="anthropic/claude-3-haiku",
max_tokens=100,
@ -47,7 +47,7 @@ def test_messages_request_map_model_with_provider_prefix(mock_settings):
def test_token_count_request_model_validation(mock_settings):
with patch("api.models.get_settings", return_value=mock_settings):
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
request = TokenCountRequest(
model="claude-3-sonnet", messages=[Message(role="user", content="hello")]
)
@ -57,8 +57,8 @@ def test_token_count_request_model_validation(mock_settings):
def test_messages_request_model_mapping_logs(mock_settings):
with (
patch("api.models.get_settings", return_value=mock_settings),
patch("api.models.logger.debug") as mock_log,
patch("api.models.anthropic.get_settings", return_value=mock_settings),
patch("api.models.anthropic.logger.debug") as mock_log,
):
MessagesRequest(
model="claude-2.1",

View file

@ -1,10 +1,8 @@
import pytest
import json
from unittest.mock import MagicMock, AsyncMock, patch
from providers.nvidia_nim import (
NvidiaNimProvider,
APIError,
)
from providers.nvidia_nim import NvidiaNimProvider
from providers.exceptions import APIError
# Mock data classes
@ -41,7 +39,7 @@ class MockRequest:
@pytest.fixture(autouse=True)
def mock_rate_limiter():
"""Mock the global rate limiter to prevent waiting."""
with patch("providers.nvidia_nim.GlobalRateLimiter") as mock:
with patch("providers.nvidia_nim.client.GlobalRateLimiter") as mock:
instance = mock.get_instance.return_value
instance.wait_if_blocked = AsyncMock(return_value=False)
yield instance
@ -50,7 +48,7 @@ def mock_rate_limiter():
@pytest.mark.asyncio
async def test_init(provider_config):
"""Test provider initialization."""
with patch("providers.nvidia_nim.AsyncOpenAI") as mock_openai:
with patch("providers.nvidia_nim.client.AsyncOpenAI") as mock_openai:
provider = NvidiaNimProvider(provider_config)
assert provider._api_key == "test_key"
assert provider._base_url == "https://test.api.nvidia.com/v1"

View file

@ -1,5 +1,5 @@
from providers.utils.think_parser import ThinkTagParser, ContentType
from providers.utils.heuristic_tool_parser import HeuristicToolParser
from providers.nvidia_nim.utils.think_parser import ThinkTagParser, ContentType
from providers.nvidia_nim.utils.heuristic_tool_parser import HeuristicToolParser
def test_think_tag_parser_basic():

View file

@ -9,7 +9,7 @@ from api.request_utils import (
is_prefix_detection_request,
get_token_count,
)
from api.models import MessagesRequest, Message
from api.models.anthropic import MessagesRequest, Message
class TestQuotaCheckRequest:

utils/__init__.py (new file, +1 line)
View file

@ -0,0 +1 @@
"""Shared utilities."""

utils/text.py (new file, +17 lines)
View file

@ -0,0 +1,17 @@
"""Shared text extraction utilities."""
from typing import Any
def extract_text_from_content(content: Any) -> str:
"""Extract concatenated text from message content (str or list of content blocks)."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
text = getattr(block, "text", "")
if text and isinstance(text, str):
parts.append(text)
return "".join(parts)
return ""