Mirror of https://github.com/Alishahryar1/free-claude-code.git (synced 2026-04-28 03:20:01 +00:00)
Major Refactor Part 2 with kimi-k2.5 in claude code
This commit is contained in:
parent 928e702e71
commit 6102583026
41 changed files with 520 additions and 597 deletions

@@ -27,6 +27,8 @@ MESSAGING_RATE_WINDOW=1
NVIDIA_NIM_API_KEY=""
NVIDIA_NIM_RATE_LIMIT=40
NVIDIA_NIM_RATE_WINDOW=60
# Base URL is fixed to https://integrate.api.nvidia.com/v1
# All NVIDIA_NIM_* settings are strictly validated (unknown keys will error).

NVIDIA_NIM_TEMPERATURE=1.0
NVIDIA_NIM_TOP_P=1.0

33 README.md
@@ -117,7 +117,6 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
| --------------------------------- | ------------------------------- | ------------------------------------- |
| `NVIDIA_NIM_API_KEY`              | Your NVIDIA API key             | required                               |
| `MODEL`                           | Model to use for all requests   | `moonshotai/kimi-k2-thinking`          |
| `NVIDIA_NIM_BASE_URL`             | NIM endpoint                    | `https://integrate.api.nvidia.com/v1`  |
| `CLAUDE_WORKSPACE`                | Directory for agent workspace   | `./agent_workspace`                    |
| `ALLOWED_DIR`                     | Allowed directories for agent   | `""`                                   |
| `MAX_CLI_SESSIONS`                | Max concurrent CLI sessions     | `10`                                   |

@@ -132,10 +131,34 @@ curl "https://integrate.api.nvidia.com/v1/models" > nvidia_nim_models.json
| `MESSAGING_RATE_WINDOW`  | Messaging window (seconds)  | `1`     |
| `NVIDIA_NIM_RATE_LIMIT`  | API requests per window     | `40`    |
| `NVIDIA_NIM_RATE_WINDOW` | Rate limit window (seconds) | `60`    |
| `NVIDIA_NIM_TEMPERATURE` | Model temperature           | `1.0`   |
| `NVIDIA_NIM_TOP_P`       | Top P sampling              | `1.0`   |
| `NVIDIA_NIM_TOP_K`       | Top K sampling              | `-1`    |
| `NVIDIA_NIM_MAX_TOKENS`  | Max tokens for generation   | `81920` |

The NVIDIA NIM base URL is fixed to `https://integrate.api.nvidia.com/v1`.

**NIM Settings (prefix `NVIDIA_NIM_`)**

| Variable                                | Description                          | Default |
| --------------------------------------- | ------------------------------------ | ------- |
| `NVIDIA_NIM_TEMPERATURE`                 | Sampling temperature                 | `1.0`   |
| `NVIDIA_NIM_TOP_P`                       | Top-p nucleus sampling               | `1.0`   |
| `NVIDIA_NIM_TOP_K`                       | Top-k sampling                       | `-1`    |
| `NVIDIA_NIM_MAX_TOKENS`                  | Max tokens for generation            | `81920` |
| `NVIDIA_NIM_PRESENCE_PENALTY`            | Presence penalty                     | `0.0`   |
| `NVIDIA_NIM_FREQUENCY_PENALTY`           | Frequency penalty                    | `0.0`   |
| `NVIDIA_NIM_MIN_P`                       | Min-p sampling                       | `0.0`   |
| `NVIDIA_NIM_REPETITION_PENALTY`          | Repetition penalty                   | `1.0`   |
| `NVIDIA_NIM_SEED`                        | RNG seed (blank = unset)             | unset   |
| `NVIDIA_NIM_STOP`                        | Stop string (blank = unset)          | unset   |
| `NVIDIA_NIM_PARALLEL_TOOL_CALLS`         | Parallel tool calls                  | `true`  |
| `NVIDIA_NIM_RETURN_TOKENS_AS_TOKEN_IDS`  | Return token ids                     | `false` |
| `NVIDIA_NIM_INCLUDE_STOP_STR_IN_OUTPUT`  | Include stop string in output        | `false` |
| `NVIDIA_NIM_IGNORE_EOS`                  | Ignore EOS token                     | `false` |
| `NVIDIA_NIM_MIN_TOKENS`                  | Minimum generated tokens             | `0`     |
| `NVIDIA_NIM_CHAT_TEMPLATE`               | Chat template override               | unset   |
| `NVIDIA_NIM_REQUEST_ID`                  | Request id override                  | unset   |
| `NVIDIA_NIM_REASONING_EFFORT`            | Reasoning effort (`low|medium|high`) | `high`  |
| `NVIDIA_NIM_INCLUDE_REASONING`           | Include reasoning in response        | `true`  |

All `NVIDIA_NIM_*` settings are strictly validated; unknown keys with this prefix will cause startup errors.

See [`.env.example`](.env.example) for all supported parameters.
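The README's "strictly validated" claim maps to a pydantic-settings model with `extra="forbid"` (defined in `config/nim.py` later in this diff). A minimal sketch of the failure mode, using an illustrative stand-in class rather than the real `NimSettings`:

```python
import os
from pydantic import ValidationError
from pydantic_settings import BaseSettings, SettingsConfigDict

class StrictNimSettings(BaseSettings):
    """Illustrative stand-in for config.nim.NimSettings."""
    temperature: float = 1.0
    model_config = SettingsConfigDict(env_prefix="NVIDIA_NIM_", extra="forbid")

os.environ["NVIDIA_NIM_TEMPERATURE"] = "0.7"
print(StrictNimSettings().temperature)  # 0.7

os.environ["NVIDIA_NIM_TEMPERATUR"] = "0.7"  # misspelled key
try:
    StrictNimSettings()
except ValidationError as exc:
    print(f"rejected {exc.error_count()} unknown NVIDIA_NIM_* key(s) at startup")
```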
@@ -100,9 +100,7 @@ async def lifespan(app: FastAPI):
        # Restore tree state if available
        saved_trees = session_store.get_all_trees()
        if saved_trees:
            logger.info(
                f"Restoring {len(saved_trees)} conversation trees..."
            )
            logger.info(f"Restoring {len(saved_trees)} conversation trees...")
            from messaging.tree_queue import TreeQueueManager

            message_handler.tree_queue = TreeQueueManager.from_dict(

@@ -124,7 +122,9 @@ async def lifespan(app: FastAPI):

        # Start the platform
        await messaging_platform.start()
        logger.info(f"{messaging_platform.name} platform started with message handler")
        logger.info(
            f"{messaging_platform.name} platform started with message handler"
        )

    except ImportError as e:
        logger.warning(f"Messaging module import error: {e}")
@@ -1,6 +1,7 @@
"""Dependency injection for FastAPI."""

from typing import Optional

from config.settings import Settings, get_settings as _get_settings, NVIDIA_NIM_BASE_URL
from providers.base import BaseProvider, ProviderConfig


@@ -14,31 +15,6 @@ def get_settings() -> Settings:
    return _get_settings()


def _build_nim_extra_params(settings: Settings) -> dict:
    """Build NIM-specific extra_params from settings."""
    params = {
        "temperature": settings.nvidia_nim_temperature,
        "top_p": settings.nvidia_nim_top_p,
        "max_tokens": settings.nvidia_nim_max_tokens,
    }
    # Only include non-default values to avoid overriding request-level settings
    if settings.nvidia_nim_top_k != -1:
        params["top_k"] = settings.nvidia_nim_top_k
    if settings.nvidia_nim_presence_penalty != 0.0:
        params["presence_penalty"] = settings.nvidia_nim_presence_penalty
    if settings.nvidia_nim_frequency_penalty != 0.0:
        params["frequency_penalty"] = settings.nvidia_nim_frequency_penalty
    if settings.nvidia_nim_min_p != 0.0:
        params["min_p"] = settings.nvidia_nim_min_p
    if settings.nvidia_nim_repetition_penalty != 1.0:
        params["repetition_penalty"] = settings.nvidia_nim_repetition_penalty
    if settings.nvidia_nim_seed is not None:
        params["seed"] = settings.nvidia_nim_seed
    if settings.nvidia_nim_stop:
        params["stop"] = settings.nvidia_nim_stop
    return params


def get_provider() -> BaseProvider:
    """Get or create the provider instance based on settings.provider_type."""
    global _provider

@@ -53,7 +29,7 @@ def get_provider() -> BaseProvider:
            base_url=NVIDIA_NIM_BASE_URL,
            rate_limit=settings.nvidia_nim_rate_limit,
            rate_window=settings.nvidia_nim_rate_window,
            extra_params=_build_nim_extra_params(settings),
            nim_settings=settings.nim,
        )
        _provider = NvidiaNimProvider(config)
    else:

35 api/models/__init__.py (new file)

@@ -0,0 +1,35 @@
"""API models exports."""

from .anthropic import (
    Role,
    ContentBlockText,
    ContentBlockImage,
    ContentBlockToolUse,
    ContentBlockToolResult,
    ContentBlockThinking,
    SystemContent,
    Message,
    Tool,
    ThinkingConfig,
    MessagesRequest,
    TokenCountRequest,
)
from .responses import TokenCountResponse, Usage, MessagesResponse

__all__ = [
    "Role",
    "ContentBlockText",
    "ContentBlockImage",
    "ContentBlockToolUse",
    "ContentBlockToolResult",
    "ContentBlockThinking",
    "SystemContent",
    "Message",
    "Tool",
    "ThinkingConfig",
    "MessagesRequest",
    "TokenCountRequest",
    "TokenCountResponse",
    "Usage",
    "MessagesResponse",
]
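The package re-export means existing `api.models` imports keep working while the definitions move; a quick sketch:

```python
from api.models import MessagesRequest as FromPackage
from api.models.anthropic import MessagesRequest as FromModule

# Both import paths resolve to the same class object after this commit.
assert FromPackage is FromModule
```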
@@ -1,8 +1,9 @@
"""Pydantic models for API requests and responses."""
"""Pydantic models for Anthropic-compatible requests."""

import logging
from enum import Enum
from typing import List, Dict, Any, Optional, Union, Literal

from pydantic import BaseModel, field_validator, model_validator

from config.settings import get_settings

@@ -88,18 +89,18 @@ class ThinkingConfig(BaseModel):


# =============================================================================
# Request/Response Models
# Request Models
# =============================================================================


class MessagesRequest(BaseModel):
    model: str
    max_tokens: int
    max_tokens: Optional[int] = None
    messages: List[Message]
    system: Optional[Union[str, List[SystemContent]]] = None
    stop_sequences: Optional[List[str]] = None
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    metadata: Optional[Dict[str, Any]] = None
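With `max_tokens` and `temperature` now optional, a request may omit sampling values entirely; they stay `None` on the model and are resolved against the NIM defaults when the provider builds the upstream request (see `providers/nvidia_nim/request.py` below). A sketch, assuming settings are loadable so model-name normalization can run:

```python
req = MessagesRequest(
    model="claude-3-opus",  # normalized to the configured MODEL by the validator
    messages=[Message(role="user", content="ping")],
    # max_tokens and temperature omitted: resolved later from NimSettings
)
assert req.max_tokens is None and req.temperature is None
```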
@@ -142,31 +143,3 @@ class TokenCountRequest(BaseModel):
        settings = get_settings()
        # Use centralized model normalization
        return normalize_model_name(v, settings.model)


class TokenCountResponse(BaseModel):
    input_tokens: int


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0


class MessagesResponse(BaseModel):
    id: str
    model: str
    role: Literal["assistant"] = "assistant"
    content: List[
        Union[
            ContentBlockText, ContentBlockToolUse, ContentBlockThinking, Dict[str, Any]
        ]
    ]
    type: Literal["message"] = "message"
    stop_reason: Optional[
        Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
    ] = None
    stop_sequence: Optional[str] = None
    usage: Usage

35 api/models/responses.py (new file)

@@ -0,0 +1,35 @@
"""Pydantic models for API responses."""

from typing import List, Dict, Any, Optional, Union, Literal

from pydantic import BaseModel

from .anthropic import ContentBlockText, ContentBlockToolUse, ContentBlockThinking


class TokenCountResponse(BaseModel):
    input_tokens: int


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0


class MessagesResponse(BaseModel):
    id: str
    model: str
    role: Literal["assistant"] = "assistant"
    content: List[
        Union[
            ContentBlockText, ContentBlockToolUse, ContentBlockThinking, Dict[str, Any]
        ]
    ]
    type: Literal["message"] = "message"
    stop_reason: Optional[
        Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
    ] = None
    stop_sequence: Optional[str] = None
    usage: Usage
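A construction sketch for the relocated response models (field values illustrative):

```python
resp = MessagesResponse(
    id="msg_01",
    model="moonshotai/kimi-k2-thinking",
    content=[{"type": "text", "text": "hi"}],  # the content Union also accepts plain dicts
    stop_reason="end_turn",
    usage=Usage(input_tokens=5, output_tokens=2),
)
assert resp.type == "message" and resp.role == "assistant"
```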
@@ -10,30 +10,13 @@ from typing import List, Optional, Tuple, Union

import tiktoken

from .models import MessagesRequest
from .models.anthropic import MessagesRequest
from utils.text import extract_text_from_content

logger = logging.getLogger(__name__)
ENCODER = tiktoken.get_encoding("cl100k_base")


def extract_text_from_content(content) -> str:
    """Extract concatenated text from message content (str or list of content blocks).

    Handles the common pattern of content being either a plain string
    or a list of content blocks with a .text attribute.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            text = getattr(block, "text", "")
            if text and isinstance(text, str):
                parts.append(text)
        return "".join(parts)
    return ""


def is_quota_check_request(request_data: MessagesRequest) -> bool:
    """Check if this is a quota probe request.
@@ -6,13 +6,8 @@ import uuid
from fastapi import APIRouter, Request, Depends, HTTPException
from fastapi.responses import StreamingResponse

from .models import (
    MessagesRequest,
    MessagesResponse,
    TokenCountRequest,
    TokenCountResponse,
    Usage,
)
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import MessagesResponse, TokenCountResponse, Usage
from .dependencies import get_provider, get_settings
from .request_utils import (
    is_quota_check_request,
@@ -2,6 +2,5 @@

from .session import CLISession
from .manager import CLISessionManager
from .parser import CLIParser

__all__ = ["CLISession", "CLISessionManager", "CLIParser"]
__all__ = ["CLISession", "CLISessionManager"]

127 cli/parser.py (deleted)

@@ -1,127 +0,0 @@
"""CLI event parser for Claude Code CLI output."""

import logging
from typing import Dict, List, Any

logger = logging.getLogger(__name__)


class CLIParser:
    """Helper to structure raw CLI events."""

    @staticmethod
    def parse_event(event: Any) -> List[Dict]:
        """
        Parse a CLI event and return a structured result.

        Args:
            event: Raw event dictionary from CLI

        Returns:
            List of parsed event dicts. Empty list if not recognized.
        """
        if not isinstance(event, dict):
            return []

        etype = event.get("type")
        results = []

        # 1. Handle full messages (assistant or result)
        msg_obj = None
        if etype == "assistant":
            msg_obj = event.get("message")
        elif etype == "result":
            res = event.get("result")
            if isinstance(res, dict):
                msg_obj = res.get("message")
            if not msg_obj:
                msg_obj = event.get("message")

        if msg_obj and isinstance(msg_obj, dict):
            content = msg_obj.get("content", [])
            if isinstance(content, list):
                parts = []
                thinking_parts = []
                tool_calls = []
                for c in content:
                    if not isinstance(c, dict):
                        continue
                    ctype = c.get("type")
                    if ctype == "text":
                        parts.append(c.get("text", ""))
                    elif ctype == "thinking":
                        thinking_parts.append(c.get("thinking", ""))
                    elif ctype == "tool_use":
                        tool_calls.append(c)

                # Prioritize thinking first
                if thinking_parts:
                    results.append(
                        {"type": "thinking", "text": "\n".join(thinking_parts)}
                    )

                # Then tools or subagents
                if tool_calls:
                    # Check for subagents (Task tool)
                    subagents = [
                        t.get("input", {}).get("description", "Subagent")
                        for t in tool_calls
                        if t.get("name") == "Task"
                    ]
                    if subagents:
                        results.append({"type": "subagent_start", "tasks": subagents})
                    else:
                        results.append({"type": "tool_start", "tools": tool_calls})

                # Then text content if any
                if parts:
                    results.append({"type": "content", "text": "".join(parts)})

        if results:
            return results

        # 2. Handle streaming deltas
        if etype == "content_block_delta":
            delta = event.get("delta", {})
            if isinstance(delta, dict):
                if delta.get("type") == "text_delta":
                    return [{"type": "content", "text": delta.get("text", "")}]
                if delta.get("type") == "thinking_delta":
                    return [{"type": "thinking", "text": delta.get("thinking", "")}]

        # 3. Handle tool usage start
        if etype == "content_block_start":
            block = event.get("content_block", {})
            if isinstance(block, dict) and block.get("type") == "tool_use":
                if block.get("name") == "Task":
                    desc = block.get("input", {}).get("description", "Subagent")
                    return [{"type": "subagent_start", "tasks": [desc]}]
                return [{"type": "tool_start", "tools": [block]}]

        # 4. Handle errors and exit
        if etype == "error":
            err = event.get("error")
            msg = err.get("message") if isinstance(err, dict) else str(err)
            logger.info(f"CLI_PARSER: Parsed error event: {msg[:100]}")
            return [{"type": "error", "message": msg}]
        elif etype == "exit":
            code = event.get("code", 0)
            stderr = event.get("stderr")
            if code == 0:
                logger.debug(f"CLI_PARSER: Successful exit (code={code})")
                return [{"type": "complete", "status": "success"}]
            else:
                # Non-zero exit is an error
                error_msg = stderr if stderr else f"Process exited with code {code}"
                logger.warning(
                    f"CLI_PARSER: Error exit (code={code}): {error_msg[:100]}"
                )
                return [
                    {"type": "error", "message": error_msg},
                    {"type": "complete", "status": "failed"},
                ]

        # Log unrecognized events for debugging
        if etype:
            logger.debug(f"CLI_PARSER: Unrecognized event type: {etype}")
        return []

63 config/nim.py (new file)

@@ -0,0 +1,63 @@
"""NVIDIA NIM settings (strict validation)."""

from typing import Optional, Literal

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class NimSettings(BaseSettings):
    """Strictly validated NVIDIA NIM settings."""

    temperature: float = Field(1.0, ge=0.0)
    top_p: float = Field(1.0, ge=0.0, le=1.0)
    top_k: int = -1
    max_tokens: int = Field(81920, ge=1)
    presence_penalty: float = Field(0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(0.0, ge=-2.0, le=2.0)

    min_p: float = Field(0.0, ge=0.0, le=1.0)
    repetition_penalty: float = Field(1.0, ge=0.0)

    seed: Optional[int] = None
    stop: Optional[str] = None

    parallel_tool_calls: bool = True
    return_tokens_as_token_ids: bool = False
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False

    min_tokens: int = Field(0, ge=0)
    chat_template: Optional[str] = None
    request_id: Optional[str] = None

    reasoning_effort: Literal["low", "medium", "high"] = "high"
    include_reasoning: bool = True

    model_config = SettingsConfigDict(
        env_prefix="NVIDIA_NIM_",
        # Rely on global load_dotenv in config.settings to avoid
        # reading unrelated .env keys into this settings model.
        extra="forbid",
    )

    @field_validator("top_k")
    @classmethod
    def validate_top_k(cls, v):
        if v < -1:
            raise ValueError("top_k must be -1 or >= 0")
        return v

    @field_validator("seed", mode="before")
    @classmethod
    def parse_optional_int(cls, v):
        if v == "" or v is None:
            return None
        return int(v)

    @field_validator("stop", "chat_template", "request_id", mode="before")
    @classmethod
    def parse_optional_str(cls, v):
        if v == "":
            return None
        return v
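The `mode="before"` validators make blank env strings behave as "unset"; a minimal sketch (assuming no other `NVIDIA_NIM_*` variables are set in the environment):

```python
import os

os.environ["NVIDIA_NIM_SEED"] = ""    # blank means "unset"
os.environ["NVIDIA_NIM_STOP"] = ""
os.environ["NVIDIA_NIM_TOP_K"] = "40"

nim = NimSettings()
assert nim.seed is None and nim.stop is None and nim.top_k == 40
```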

@@ -2,10 +2,13 @@

from functools import lru_cache
from typing import Optional
from pydantic import field_validator

from pydantic import field_validator, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from dotenv import load_dotenv

from .nim import NimSettings

load_dotenv()

# Fixed base URL for NVIDIA NIM

@@ -41,33 +44,8 @@ class Settings(BaseSettings):
    enable_suggestion_mode_skip: bool = True
    enable_filepath_extraction_mock: bool = True

    # ==================== NIM Core Parameters ====================
    nvidia_nim_temperature: float = 1.0
    nvidia_nim_top_p: float = 1.0
    nvidia_nim_top_k: int = -1
    nvidia_nim_max_tokens: int = 81920
    nvidia_nim_presence_penalty: float = 0.0
    nvidia_nim_frequency_penalty: float = 0.0

    # ==================== NIM Advanced Parameters ====================
    nvidia_nim_min_p: float = 0.0
    nvidia_nim_repetition_penalty: float = 1.0
    nvidia_nim_seed: Optional[int] = None
    nvidia_nim_stop: Optional[str] = None

    # ==================== NIM Flag Parameters ====================
    nvidia_nim_parallel_tool_calls: bool = True
    nvidia_nim_return_tokens_as_token_ids: bool = False
    nvidia_nim_include_stop_str_in_output: bool = False
    nvidia_nim_ignore_eos: bool = False

    nvidia_nim_min_tokens: int = 0
    nvidia_nim_chat_template: str = ""
    nvidia_nim_request_id: str = ""

    # ==================== Thinking/Reasoning Parameters ====================
    nvidia_nim_reasoning_effort: str = "high"
    nvidia_nim_include_reasoning: bool = True
    # ==================== NIM Settings ====================
    nim: NimSettings = Field(default_factory=NimSettings)

    # ==================== Bot Wrapper Config ====================
    telegram_bot_token: Optional[str] = None

@@ -81,17 +59,8 @@ class Settings(BaseSettings):
    port: int = 8082
    log_file: str = "server.log"

    # Handle empty strings for optional int fields
    @field_validator("nvidia_nim_seed", mode="before")
    @classmethod
    def parse_optional_int(cls, v):
        if v == "" or v is None:
            return None
        return int(v)

    # Handle empty strings for optional string fields
    @field_validator(
        "nvidia_nim_stop",
        "telegram_bot_token",
        "allowed_telegram_user_id",
        mode="before",
@@ -45,5 +45,7 @@ def create_messaging_platform(
    # from .discord import DiscordPlatform
    # return DiscordPlatform(...)

    logger.warning(f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram'")
    logger.warning(
        f"Unknown messaging platform: '{platform_type}'. Supported: 'telegram'"
    )
    return None
@@ -30,5 +30,3 @@ class IncomingMessage:
    def is_reply(self) -> bool:
        """Check if this message is a reply to another message."""
        return self.reply_to_message_id is not None
@@ -1,22 +1,25 @@
"""Base provider interface - extend this to implement your own provider."""

from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict, Optional
from pydantic import BaseModel
from typing import Any, AsyncIterator, Optional

from pydantic import BaseModel, Field

from config.nim import NimSettings


class ProviderConfig(BaseModel):
    """Configuration for a provider.

    Base fields apply to all providers. Provider-specific parameters
    (e.g. NIM temperature, top_p) are passed via extra_params.
    (e.g. NIM temperature, top_p) are passed via nim_settings.
    """

    api_key: str
    base_url: Optional[str] = None
    rate_limit: Optional[int] = None
    rate_window: int = 60
    extra_params: Dict[str, Any] = {}
    nim_settings: NimSettings = Field(default_factory=NimSettings)


class BaseProvider(ABC):
@@ -9,21 +9,14 @@ import json
import logging
from typing import Any, Dict, List

from utils.text import extract_text_from_content

logger = logging.getLogger(__name__)


def _extract_text_from_content(content) -> str:
    """Extract concatenated text from message content (str or list of content blocks)."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            text = getattr(block, "text", "")
            if text and isinstance(text, str):
                parts.append(text)
        return "".join(parts)
    return ""
def _extract_text_from_content(content: Any) -> str:
    """Backward-compatible wrapper for tests and legacy imports."""
    return extract_text_from_content(content)


def generate_request_fingerprint(messages: List[Any]) -> str:

@@ -56,7 +49,7 @@ def get_last_user_message_preview(messages: List[Any], max_len: int = 100) -> str
    """Extract a preview of the last user message."""
    for msg in reversed(messages):
        if hasattr(msg, "role") and msg.role == "user":
            text = _extract_text_from_content(getattr(msg, "content", ""))
            text = extract_text_from_content(getattr(msg, "content", ""))
            if text:
                preview = text.replace("\n", " ").replace("\r", "")
                return preview[:max_len] + "..." if len(preview) > max_len else preview
@@ -1,228 +0,0 @@
"""Mixins for NVIDIA NIM provider - decoupling responsibilities.

This module contains focused mixins that handle specific aspects of the
NVIDIA NIM provider functionality:
- RequestBuilderMixin: Builds request bodies
- ErrorMapperMixin: Maps HTTP errors to provider exceptions
- ResponseConverterMixin: Converts responses between formats
"""

import json
import logging
from typing import Any, Dict

from .utils import AnthropicToOpenAIConverter, map_stop_reason, extract_think_content
from .exceptions import (
    AuthenticationError,
    InvalidRequestError,
    RateLimitError,
    OverloadedError,
    APIError,
)

logger = logging.getLogger(__name__)


class RequestBuilderMixin:
    """Mixin for building OpenAI-format request bodies.

    Handles conversion from Anthropic request format to OpenAI format,
    including system prompts, tools, thinking mode, and NIM-specific parameters.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._nim_params: Dict[str, Any] = {}

    def _build_request_body(self, request_data: Any, stream: bool = False) -> dict:
        """Build OpenAI-format request body from Anthropic request.

        Args:
            request_data: The incoming Anthropic-format request
            stream: Whether this is a streaming request

        Returns:
            OpenAI-format request body dictionary
        """
        messages = AnthropicToOpenAIConverter.convert_messages(request_data.messages)

        # Add system prompt
        if request_data.system:
            system_msg = AnthropicToOpenAIConverter.convert_system_prompt(
                request_data.system
            )
            if system_msg:
                messages.insert(0, system_msg)

        body = {
            "model": request_data.model,
            "messages": messages,
            "max_tokens": request_data.max_tokens,
        }

        if request_data.temperature is not None:
            body["temperature"] = request_data.temperature
        if request_data.top_p is not None:
            body["top_p"] = request_data.top_p
        if request_data.stop_sequences:
            body["stop"] = request_data.stop_sequences
        if request_data.tools:
            body["tools"] = AnthropicToOpenAIConverter.convert_tools(request_data.tools)

        # Handle non-standard parameters via extra_body
        extra_params = request_data.extra_body.copy() if request_data.extra_body else {}

        # Handle thinking/reasoning mode
        if request_data.thinking and getattr(request_data.thinking, "enabled", True):
            extra_params.setdefault("thinking", {"type": "enabled"})
            extra_params.setdefault("reasoning_split", True)
            extra_params.setdefault(
                "chat_template_kwargs",
                {"thinking": True, "reasoning_split": True, "clear_thinking": False},
            )

        if extra_params:
            body["extra_body"] = extra_params

        # Apply NIM defaults
        for key, val in self._nim_params.items():
            if key not in body and key not in extra_params:
                body[key] = val

        return body


class ErrorMapperMixin:
    """Mixin for mapping HTTP errors to provider exceptions.

    Converts HTTP status codes and error responses to appropriate
    ProviderError subclasses for standardized error handling.
    """

    def _map_error(self, e: Exception) -> Exception:
        """Map OpenAI exception to specific ProviderError.

        Args:
            e: The OpenAI exception to map

        Returns:
            Appropriate ProviderError subclass instance
        """
        import openai

        if isinstance(e, openai.AuthenticationError):
            return AuthenticationError(str(e), raw_error=str(e))
        if isinstance(e, openai.RateLimitError):
            # Trigger global rate limit block
            from .rate_limit import GlobalRateLimiter

            GlobalRateLimiter.get_instance().set_blocked(60)  # Default 60s cooldown
            return RateLimitError(str(e), raw_error=str(e))
        if isinstance(e, openai.BadRequestError):
            return InvalidRequestError(str(e), raw_error=str(e))
        if isinstance(e, openai.InternalServerError):
            message = str(e)
            if "overloaded" in message.lower() or "capacity" in message.lower():
                return OverloadedError(message, raw_error=str(e))
            return APIError(message, status_code=500, raw_error=str(e))
        if isinstance(e, openai.APIError):
            return APIError(
                str(e), status_code=getattr(e, "status_code", 500), raw_error=str(e)
            )

        return e


class ResponseConverterMixin:
    """Mixin for converting OpenAI responses to Anthropic format.

    Handles content extraction, reasoning/thinking blocks, tool calls,
    and response structure transformation.
    """

    def convert_response(self, response_json: dict, original_request: Any) -> dict:
        """Convert OpenAI response to Anthropic format.

        Args:
            response_json: OpenAI-format response JSON
            original_request: Original Anthropic-format request

        Returns:
            Anthropic-format response dictionary
        """
        import uuid

        choice = response_json["choices"][0]
        message = choice["message"]
        content = []

        # Extract reasoning from various sources
        reasoning = message.get("reasoning_content")
        if not reasoning:
            reasoning_details = message.get("reasoning_details")
            if reasoning_details and isinstance(reasoning_details, list):
                reasoning = "\n".join(
                    item.get("text", "")
                    for item in reasoning_details
                    if isinstance(item, dict)
                )

        if reasoning:
            content.append({"type": "thinking", "thinking": reasoning})

        # Extract text content (with think tag handling)
        if message.get("content"):
            raw_content = message["content"]
            if isinstance(raw_content, str):
                if not reasoning:
                    think_content, raw_content = extract_think_content(raw_content)
                    if think_content:
                        content.append({"type": "thinking", "thinking": think_content})
                if raw_content:
                    content.append({"type": "text", "text": raw_content})
            elif isinstance(raw_content, list):
                for item in raw_content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        content.append(item)

        # Extract tool calls
        if message.get("tool_calls"):
            for tc in message["tool_calls"]:
                try:
                    args = json.loads(tc["function"]["arguments"])
                except Exception:
                    args = tc["function"].get("arguments", {})
                content.append(
                    {
                        "type": "tool_use",
                        "id": tc["id"],
                        "name": tc["function"]["name"],
                        "input": args,
                    }
                )

        if not content:
            # NIM models (especially Mistral-based) often require non-empty content.
            # Adding a single space satisfies this requirement while avoiding
            # the "(no content)" display issue in Claude Code.
            content.append({"type": "text", "text": " "})

        usage = response_json.get("usage", {})

        return {
            "id": response_json.get("id", f"msg_{uuid.uuid4()}"),
            "type": "message",
            "role": "assistant",
            "model": original_request.model,
            "content": content,
            "stop_reason": map_stop_reason(choice.get("finish_reason")),
            "stop_sequence": None,
            "usage": {
                "input_tokens": usage.get("prompt_tokens", 0),
                "output_tokens": usage.get("completion_tokens", 0),
                "cache_creation_input_tokens": 0,
                "cache_read_input_tokens": 0,
            },
        }

5 providers/nvidia_nim/__init__.py (new file)

@@ -0,0 +1,5 @@
"""NVIDIA NIM provider package."""

from .client import NvidiaNimProvider

__all__ = ["NvidiaNimProvider"]
@@ -1,3 +1,5 @@
"""NVIDIA NIM provider implementation."""

import logging
import json
import uuid

@@ -5,41 +7,32 @@ from typing import Any, AsyncIterator

from openai import AsyncOpenAI

from .base import BaseProvider, ProviderConfig
from providers.base import BaseProvider, ProviderConfig
from providers.rate_limit import GlobalRateLimiter
from .request import build_request_body
from .response import convert_response
from .errors import map_error
from .utils import (
    SSEBuilder,
    map_stop_reason,
    ThinkTagParser,
    HeuristicToolParser,
    ContentType,
    extract_reasoning_from_delta,
)
from .exceptions import (
    APIError,
)
from .nvidia_mixins import (
    RequestBuilderMixin,
    ErrorMapperMixin,
    ResponseConverterMixin,
)
from .rate_limit import GlobalRateLimiter

logger = logging.getLogger(__name__)


class NvidiaNimProvider(
    RequestBuilderMixin,
    ErrorMapperMixin,
    ResponseConverterMixin,
    BaseProvider,
):
class NvidiaNimProvider(BaseProvider):
    """NVIDIA NIM provider using official OpenAI client."""

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._api_key = config.api_key
        self._base_url = (config.base_url or "https://integrate.api.nvidia.com/v1").rstrip("/")
        self._nim_params = config.extra_params.copy() if config.extra_params else {}
        self._base_url = (
            config.base_url or "https://integrate.api.nvidia.com/v1"
        ).rstrip("/")
        self._nim_settings = config.nim_settings
        self._global_rate_limiter = GlobalRateLimiter.get_instance(
            rate_limit=config.rate_limit,
            rate_window=config.rate_window,

@@ -51,6 +44,10 @@ class NvidiaNimProvider(
            timeout=300.0,
        )

    def _build_request_body(self, request: Any, stream: bool = False) -> dict:
        """Internal helper for tests and shared building."""
        return build_request_body(request, self._nim_settings, stream=stream)

    async def stream_response(
        self, request: Any, input_tokens: int = 0
    ) -> AsyncIterator[str]:

@@ -163,7 +160,7 @@ class NvidiaNimProvider(

        except Exception as e:
            logger.error(f"NIM_ERROR: {type(e).__name__}: {e}")
            mapped_e = self._map_error(e)
            mapped_e = map_error(e)
            error_occurred = True
            error_message = str(mapped_e)
            logger.info(f"NIM_STREAM: Emitting SSE error event for {type(e).__name__}")

@@ -235,11 +232,15 @@ class NvidiaNimProvider(

        try:
            response = await self._client.chat.completions.create(**body)
            # ResponseConverterMixin expects a dict
            # convert_response expects a dict
            return response.model_dump()
        except Exception as e:
            logger.error(f"NIM_ERROR: {type(e).__name__}: {e}")
            raise self._map_error(e)
            raise map_error(e)

    def convert_response(self, response_json: dict, original_request: Any) -> Any:
        """Convert provider response to Anthropic format."""
        return convert_response(response_json, original_request)

    def _process_tool_call(self, tc: dict, sse: Any):
        """Process a single tool call delta and yield SSE events.

35 providers/nvidia_nim/errors.py (new file)

@@ -0,0 +1,35 @@
"""Error mapping for NVIDIA NIM provider."""

import openai

from providers.exceptions import (
    AuthenticationError,
    InvalidRequestError,
    RateLimitError,
    OverloadedError,
    APIError,
)
from providers.rate_limit import GlobalRateLimiter


def map_error(e: Exception) -> Exception:
    """Map OpenAI exception to specific ProviderError."""
    if isinstance(e, openai.AuthenticationError):
        return AuthenticationError(str(e), raw_error=str(e))
    if isinstance(e, openai.RateLimitError):
        # Trigger global rate limit block
        GlobalRateLimiter.get_instance().set_blocked(60)  # Default 60s cooldown
        return RateLimitError(str(e), raw_error=str(e))
    if isinstance(e, openai.BadRequestError):
        return InvalidRequestError(str(e), raw_error=str(e))
    if isinstance(e, openai.InternalServerError):
        message = str(e)
        if "overloaded" in message.lower() or "capacity" in message.lower():
            return OverloadedError(message, raw_error=str(e))
        return APIError(message, status_code=500, raw_error=str(e))
    if isinstance(e, openai.APIError):
        return APIError(
            str(e), status_code=getattr(e, "status_code", 500), raw_error=str(e)
        )

    return e
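The call-site shape this enables, sketched after the pattern in `client.py` (the `client` argument stands in for the provider's `AsyncOpenAI` instance):

```python
async def create_completion(client, body: dict) -> dict:
    """Sketch: normalize SDK errors into ProviderError subclasses at the boundary."""
    try:
        response = await client.chat.completions.create(**body)
    except Exception as e:
        raise map_error(e)  # AuthenticationError, RateLimitError, ... or e itself
    return response.model_dump()
```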

117 providers/nvidia_nim/request.py (new file)

@@ -0,0 +1,117 @@
"""Request builder for NVIDIA NIM provider."""

from typing import Any, Dict

from config.nim import NimSettings
from .utils.message_converter import AnthropicToOpenAIConverter


def _set_if_not_none(body: Dict[str, Any], key: str, value: Any) -> None:
    if value is not None:
        body[key] = value


def _set_extra(
    extra_body: Dict[str, Any], key: str, value: Any, ignore_value: Any = None
) -> None:
    if key in extra_body:
        return
    if value is None:
        return
    if ignore_value is not None and value == ignore_value:
        return
    extra_body[key] = value


def build_request_body(
    request_data: Any, nim: NimSettings, stream: bool = False
) -> dict:
    """Build OpenAI-format request body from Anthropic request."""
    messages = AnthropicToOpenAIConverter.convert_messages(request_data.messages)

    # Add system prompt
    system = getattr(request_data, "system", None)
    if system:
        system_msg = AnthropicToOpenAIConverter.convert_system_prompt(system)
        if system_msg:
            messages.insert(0, system_msg)

    body: Dict[str, Any] = {
        "model": request_data.model,
        "messages": messages,
    }

    # max_tokens with optional cap
    max_tokens = getattr(request_data, "max_tokens", None)
    if max_tokens is None:
        max_tokens = nim.max_tokens
    elif nim.max_tokens:
        max_tokens = min(max_tokens, nim.max_tokens)
    _set_if_not_none(body, "max_tokens", max_tokens)

    req_temperature = getattr(request_data, "temperature", None)
    temperature = req_temperature if req_temperature is not None else nim.temperature
    _set_if_not_none(body, "temperature", temperature)

    req_top_p = getattr(request_data, "top_p", None)
    top_p = req_top_p if req_top_p is not None else nim.top_p
    _set_if_not_none(body, "top_p", top_p)

    stop_sequences = getattr(request_data, "stop_sequences", None)
    if stop_sequences:
        body["stop"] = stop_sequences
    elif nim.stop:
        body["stop"] = nim.stop

    tools = getattr(request_data, "tools", None)
    if tools:
        body["tools"] = AnthropicToOpenAIConverter.convert_tools(tools)
        tool_choice = getattr(request_data, "tool_choice", None)
        if tool_choice:
            body["tool_choice"] = tool_choice

    if nim.presence_penalty != 0.0:
        body["presence_penalty"] = nim.presence_penalty
    if nim.frequency_penalty != 0.0:
        body["frequency_penalty"] = nim.frequency_penalty
    if nim.seed is not None:
        body["seed"] = nim.seed

    body["parallel_tool_calls"] = nim.parallel_tool_calls

    # Handle non-standard parameters via extra_body
    extra_body: Dict[str, Any] = {}
    request_extra = getattr(request_data, "extra_body", None)
    if request_extra:
        extra_body.update(request_extra)

    # Handle thinking/reasoning mode
    thinking = getattr(request_data, "thinking", None)
    if thinking and getattr(thinking, "enabled", True):
        extra_body.setdefault("thinking", {"type": "enabled"})
        extra_body.setdefault("reasoning_split", True)
        extra_body.setdefault(
            "chat_template_kwargs",
            {"thinking": True, "reasoning_split": True, "clear_thinking": False},
        )

    req_top_k = getattr(request_data, "top_k", None)
    top_k = req_top_k if req_top_k is not None else nim.top_k
    _set_extra(extra_body, "top_k", top_k, ignore_value=-1)
    _set_extra(extra_body, "min_p", nim.min_p, ignore_value=0.0)
    _set_extra(
        extra_body, "repetition_penalty", nim.repetition_penalty, ignore_value=1.0
    )
    _set_extra(extra_body, "min_tokens", nim.min_tokens, ignore_value=0)
    _set_extra(extra_body, "chat_template", nim.chat_template)
    _set_extra(extra_body, "request_id", nim.request_id)
    _set_extra(extra_body, "return_tokens_as_token_ids", nim.return_tokens_as_token_ids)
    _set_extra(extra_body, "include_stop_str_in_output", nim.include_stop_str_in_output)
    _set_extra(extra_body, "ignore_eos", nim.ignore_eos)
    _set_extra(extra_body, "reasoning_effort", nim.reasoning_effort)
    _set_extra(extra_body, "include_reasoning", nim.include_reasoning)

    if extra_body:
        body["extra_body"] = extra_body

    return body
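A usage sketch showing the precedence rules (`SimpleNamespace` stands in for a real `MessagesRequest`; values illustrative, and an empty message list is assumed acceptable to the converter):

```python
from types import SimpleNamespace

from config.nim import NimSettings
from providers.nvidia_nim.request import build_request_body

req = SimpleNamespace(
    model="moonshotai/kimi-k2-thinking",
    messages=[],        # assumed: convert_messages() tolerates an empty list
    max_tokens=None,    # None falls back to nim.max_tokens (81920)
    temperature=0.2,    # an explicit request value wins over nim.temperature
)
body = build_request_body(req, NimSettings())
assert body["max_tokens"] == 81920 and body["temperature"] == 0.2
```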

83 providers/nvidia_nim/response.py (new file)

@@ -0,0 +1,83 @@
"""Response conversion for NVIDIA NIM provider."""

import json
import uuid
from typing import Any

from .utils import map_stop_reason, extract_think_content


def convert_response(response_json: dict, original_request: Any) -> dict:
    """Convert OpenAI response to Anthropic format."""
    choice = response_json["choices"][0]
    message = choice["message"]
    content = []

    # Extract reasoning from various sources
    reasoning = message.get("reasoning_content")
    if not reasoning:
        reasoning_details = message.get("reasoning_details")
        if reasoning_details and isinstance(reasoning_details, list):
            reasoning = "\n".join(
                item.get("text", "")
                for item in reasoning_details
                if isinstance(item, dict)
            )

    if reasoning:
        content.append({"type": "thinking", "thinking": reasoning})

    # Extract text content (with think tag handling)
    if message.get("content"):
        raw_content = message["content"]
        if isinstance(raw_content, str):
            if not reasoning:
                think_content, raw_content = extract_think_content(raw_content)
                if think_content:
                    content.append({"type": "thinking", "thinking": think_content})
            if raw_content:
                content.append({"type": "text", "text": raw_content})
        elif isinstance(raw_content, list):
            for item in raw_content:
                if isinstance(item, dict) and item.get("type") == "text":
                    content.append(item)

    # Extract tool calls
    if message.get("tool_calls"):
        for tc in message["tool_calls"]:
            try:
                args = json.loads(tc["function"]["arguments"])
            except Exception:
                args = tc["function"].get("arguments", {})
            content.append(
                {
                    "type": "tool_use",
                    "id": tc["id"],
                    "name": tc["function"]["name"],
                    "input": args,
                }
            )

    if not content:
        # NIM models (especially Mistral-based) often require non-empty content.
        # Adding a single space satisfies this requirement while avoiding
        # the "(no content)" display issue in Claude Code.
        content.append({"type": "text", "text": " "})

    usage = response_json.get("usage", {})

    return {
        "id": response_json.get("id", f"msg_{uuid.uuid4()}"),
        "type": "message",
        "role": "assistant",
        "model": original_request.model,
        "content": content,
        "stop_reason": map_stop_reason(choice.get("finish_reason")),
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
            "cache_creation_input_tokens": 0,
            "cache_read_input_tokens": 0,
        },
    }
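A conversion sketch with a hand-written OpenAI-style payload (assuming `extract_think_content` splits `<think>` tags and `map_stop_reason` handles `"stop"`):

```python
from types import SimpleNamespace

openai_resp = {
    "id": "chatcmpl-123",
    "choices": [{
        "message": {"role": "assistant", "content": "<think>plan</think>Hello"},
        "finish_reason": "stop",
    }],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3},
}
anthropic = convert_response(
    openai_resp, SimpleNamespace(model="moonshotai/kimi-k2-thinking")
)
# Expected: a "thinking" block ("plan") followed by a "text" block ("Hello"),
# with usage mapped from the prompt/completion token counts.
```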

@@ -14,6 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
from unittest.mock import AsyncMock, MagicMock
from providers.base import ProviderConfig
from providers.nvidia_nim import NvidiaNimProvider
from config.nim import NimSettings
from messaging.base import CLISession, SessionManagerInterface, MessagingPlatform
from messaging.models import IncomingMessage
from messaging.session import SessionStore

@@ -26,6 +27,7 @@ def provider_config():
        base_url="https://test.api.nvidia.com/v1",
        rate_limit=10,
        rate_window=60,
        nim_settings=NimSettings(),
    )
@@ -4,37 +4,34 @@ import pytest
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
from messaging.event_parser import parse_cli_event

# --- Existing Parser Tests ---


class TestCLIParser:
    """Test CLIParser event parsing."""
    """Test CLI event parsing."""

    def test_parse_text_content(self):
        """Test parsing text content from assistant message."""
        from cli.parser import CLIParser

        event = {
            "type": "assistant",
            "message": {"content": [{"type": "text", "text": "Hello, world!"}]},
        }
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert len(result) == 1
        assert result[0]["type"] == "content"
        assert result[0]["text"] == "Hello, world!"

    def test_parse_thinking_content(self):
        """Test parsing thinking content."""
        from cli.parser import CLIParser

        event = {
            "type": "assistant",
            "message": {
                "content": [{"type": "thinking", "thinking": "Let me think..."}]
            },
        }
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert len(result) == 1
        assert result[0]["type"] == "thinking"
        assert (

@@ -44,8 +41,6 @@ class TestCLIParser:

    def test_parse_multiple_content(self):
        """Test parsing mixed content (thinking + tools)."""
        from cli.parser import CLIParser

        event = {
            "type": "assistant",
            "message": {

@@ -55,7 +50,7 @@ class TestCLIParser:
                ]
            },
        }
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert len(result) == 2
        assert result[0]["type"] == "thinking"
        assert result[0]["text"] == "Thinking..."

@@ -63,8 +58,6 @@ class TestCLIParser:

    def test_parse_tool_use(self):
        """Test parsing tool use content."""
        from cli.parser import CLIParser

        event = {
            "type": "assistant",
            "message": {

@@ -77,7 +70,7 @@ class TestCLIParser:
                ]
            },
        }
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert len(result) == 1
        assert result[0]["type"] == "tool_start"
        assert len(result[0]["tools"]) == 1

@@ -85,54 +78,44 @@ class TestCLIParser:

    def test_parse_text_delta(self):
        """Test parsing streaming text delta."""
        from cli.parser import CLIParser

        event = {
            "type": "content_block_delta",
            "delta": {"type": "text_delta", "text": "streaming text"},
        }
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert len(result) == 1
        assert result[0]["type"] == "content"
        assert result[0]["text"] == "streaming text"

    def test_parse_thinking_delta(self):
        """Test parsing streaming thinking delta."""
        from cli.parser import CLIParser

        event = {
            "type": "content_block_delta",
            "delta": {"type": "thinking_delta", "thinking": "thinking..."},
        }
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert len(result) == 1
        assert result[0]["type"] == "thinking"
        assert result[0]["text"] == "thinking..."

    def test_parse_error(self):
        """Test parsing error event."""
        from cli.parser import CLIParser

        event = {"type": "error", "error": {"message": "Something went wrong"}}
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert result[0]["type"] == "error"
        assert result[0]["message"] == "Something went wrong"

    def test_parse_exit_success(self):
        """Test parsing exit event with success."""
        from cli.parser import CLIParser

        event = {"type": "exit", "code": 0}
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        assert result[0]["type"] == "complete"
        assert result[0]["status"] == "success"

    def test_parse_exit_failure(self):
        """Test parsing exit event with failure returns error then complete."""
        from cli.parser import CLIParser

        event = {"type": "exit", "code": 1}
        result = CLIParser.parse_event(event)
        result = parse_cli_event(event)
        # Non-zero exit now returns error first, then complete
        assert len(result) == 2
        assert result[0]["type"] == "error"

@@ -145,16 +128,12 @@ class TestCLIParser:

    def test_parse_invalid_event(self):
        """Test parsing returns empty list for unrecognized event."""
        from cli.parser import CLIParser

        result = CLIParser.parse_event({"type": "unknown"})
        result = parse_cli_event({"type": "unknown"})
        assert result == []

    def test_parse_non_dict(self):
        """Test parsing returns empty list for non-dict input."""
        from cli.parser import CLIParser

        result = CLIParser.parse_event("not a dict")
        result = parse_cli_event("not a dict")
        assert result == []
@@ -18,6 +18,7 @@ class TestSettings:
        settings = Settings()
        assert isinstance(settings.nvidia_nim_rate_limit, int)
        assert isinstance(settings.nvidia_nim_rate_window, int)
        assert isinstance(settings.nim.temperature, float)
        assert isinstance(settings.fast_prefix_detection, bool)
        assert isinstance(settings.max_cli_sessions, int)

@@ -35,9 +36,7 @@ class TestSettings:

        # Settings should handle NVIDIA_NIM_SEED="" gracefully
        settings = Settings()
        assert settings.nvidia_nim_seed is None or isinstance(
            settings.nvidia_nim_seed, int
        )
        assert settings.nim.seed is None or isinstance(settings.nim.seed, int)

    def test_model_setting(self):
        """Test model setting exists and is a string."""
@@ -1,5 +1,5 @@
import json
from providers.utils.message_converter import AnthropicToOpenAIConverter
from providers.nvidia_nim.utils.message_converter import AnthropicToOpenAIConverter

# --- Mock Classes ---

@@ -265,7 +265,7 @@ def test_convert_mixed_blocks_and_types_and_roles():

def test_get_block_attr_defaults():
    # Test helper directly
    from providers.utils.message_converter import get_block_attr
    from providers.nvidia_nim.utils.message_converter import get_block_attr

    assert get_block_attr({}, "missing", "default") == "default"
    assert get_block_attr(object(), "missing", "default") == "default"
@@ -2,6 +2,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from api.dependencies import get_provider, get_settings, cleanup_provider
from providers.nvidia_nim import NvidiaNimProvider
from config.nim import NimSettings


def _make_mock_settings(**overrides):

@@ -9,19 +10,9 @@ def _make_mock_settings(**overrides):
    mock = MagicMock()
    mock.provider_type = "nvidia_nim"
    mock.nvidia_nim_api_key = "test_key"
    mock.nvidia_nim_base_url = None
    mock.nvidia_nim_rate_limit = 40
    mock.nvidia_nim_rate_window = 60
    mock.nvidia_nim_temperature = 0.6
    mock.nvidia_nim_top_p = 0.95
    mock.nvidia_nim_max_tokens = 16000
    mock.nvidia_nim_top_k = -1
    mock.nvidia_nim_presence_penalty = 0.0
    mock.nvidia_nim_frequency_penalty = 0.0
    mock.nvidia_nim_min_p = 0.0
    mock.nvidia_nim_repetition_penalty = 1.0
    mock.nvidia_nim_seed = None
    mock.nvidia_nim_stop = None
    mock.nim = NimSettings()
    for key, value in overrides.items():
        setattr(mock, key, value)
    return mock
@@ -44,7 +44,6 @@ class TestMessagingModels:
        assert msg.reply_to_message_id == "100"


class TestMessagingBase:
    """Test MessagingPlatform ABC."""
@@ -13,7 +13,9 @@ class TestCreateMessagingPlatform:
        """Create Telegram platform when bot_token is provided."""
        mock_platform = MagicMock()
        with patch("messaging.telegram.TELEGRAM_AVAILABLE", True):
            with patch("messaging.telegram.TelegramPlatform", return_value=mock_platform):
            with patch(
                "messaging.telegram.TelegramPlatform", return_value=mock_platform
            ):
                result = create_messaging_platform(
                    "telegram",
                    bot_token="test_token",
|
|||
import pytest
|
||||
from unittest.mock import patch
|
||||
from api.models import MessagesRequest, TokenCountRequest, Message
|
||||
from api.models.anthropic import MessagesRequest, TokenCountRequest, Message
|
||||
from config.settings import Settings
|
||||
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ def mock_settings():
|
|||
|
||||
|
||||
def test_messages_request_map_model_claude_to_default(mock_settings):
|
||||
with patch("api.models.get_settings", return_value=mock_settings):
|
||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||
request = MessagesRequest(
|
||||
model="claude-3-opus",
|
||||
max_tokens=100,
|
||||
|
|
@ -24,7 +24,7 @@ def test_messages_request_map_model_claude_to_default(mock_settings):
|
|||
|
||||
|
||||
def test_messages_request_map_model_non_claude_unchanged(mock_settings):
|
||||
with patch("api.models.get_settings", return_value=mock_settings):
|
||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||
request = MessagesRequest(
|
||||
model="gpt-4",
|
||||
max_tokens=100,
|
||||
|
|
@ -36,7 +36,7 @@ def test_messages_request_map_model_non_claude_unchanged(mock_settings):
|
|||
|
||||
|
||||
def test_messages_request_map_model_with_provider_prefix(mock_settings):
|
||||
with patch("api.models.get_settings", return_value=mock_settings):
|
||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||
request = MessagesRequest(
|
||||
model="anthropic/claude-3-haiku",
|
||||
max_tokens=100,
|
||||
|
|
@ -47,7 +47,7 @@ def test_messages_request_map_model_with_provider_prefix(mock_settings):
|
|||
|
||||
|
||||
def test_token_count_request_model_validation(mock_settings):
|
||||
with patch("api.models.get_settings", return_value=mock_settings):
|
||||
with patch("api.models.anthropic.get_settings", return_value=mock_settings):
|
||||
request = TokenCountRequest(
|
||||
model="claude-3-sonnet", messages=[Message(role="user", content="hello")]
|
||||
)
|
||||
|
|
@ -57,8 +57,8 @@ def test_token_count_request_model_validation(mock_settings):
|
|||
|
||||
def test_messages_request_model_mapping_logs(mock_settings):
|
||||
with (
|
||||
patch("api.models.get_settings", return_value=mock_settings),
|
||||
patch("api.models.logger.debug") as mock_log,
|
||||
patch("api.models.anthropic.get_settings", return_value=mock_settings),
|
||||
patch("api.models.anthropic.logger.debug") as mock_log,
|
||||
):
|
||||
MessagesRequest(
|
||||
model="claude-2.1",
|
|||
import pytest
|
||||
import json
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from providers.nvidia_nim import (
|
||||
NvidiaNimProvider,
|
||||
APIError,
|
||||
)
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
from providers.exceptions import APIError
|
||||
|
||||
|
||||
# Mock data classes
|
||||
|
|
@ -41,7 +39,7 @@ class MockRequest:
|
|||
@pytest.fixture(autouse=True)
|
||||
def mock_rate_limiter():
|
||||
"""Mock the global rate limiter to prevent waiting."""
|
||||
with patch("providers.nvidia_nim.GlobalRateLimiter") as mock:
|
||||
with patch("providers.nvidia_nim.client.GlobalRateLimiter") as mock:
|
||||
instance = mock.get_instance.return_value
|
||||
instance.wait_if_blocked = AsyncMock(return_value=False)
|
||||
yield instance
|
||||
|
|
@ -50,7 +48,7 @@ def mock_rate_limiter():
|
|||
@pytest.mark.asyncio
|
||||
async def test_init(provider_config):
|
||||
"""Test provider initialization."""
|
||||
with patch("providers.nvidia_nim.AsyncOpenAI") as mock_openai:
|
||||
with patch("providers.nvidia_nim.client.AsyncOpenAI") as mock_openai:
|
||||
provider = NvidiaNimProvider(provider_config)
|
||||
assert provider._api_key == "test_key"
|
||||
assert provider._base_url == "https://test.api.nvidia.com/v1"
@@ -1,5 +1,5 @@
from providers.utils.think_parser import ThinkTagParser, ContentType
from providers.utils.heuristic_tool_parser import HeuristicToolParser
from providers.nvidia_nim.utils.think_parser import ThinkTagParser, ContentType
from providers.nvidia_nim.utils.heuristic_tool_parser import HeuristicToolParser


def test_think_tag_parser_basic():
@@ -9,7 +9,7 @@ from api.request_utils import (
    is_prefix_detection_request,
    get_token_count,
)
from api.models import MessagesRequest, Message
from api.models.anthropic import MessagesRequest, Message


class TestQuotaCheckRequest:

1 utils/__init__.py (new file)

@@ -0,0 +1 @@
"""Shared utilities."""

17 utils/text.py (new file)

@@ -0,0 +1,17 @@
"""Shared text extraction utilities."""

from typing import Any


def extract_text_from_content(content: Any) -> str:
    """Extract concatenated text from message content (str or list of content blocks)."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            text = getattr(block, "text", "")
            if text and isinstance(text, str):
                parts.append(text)
        return "".join(parts)
    return ""