mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Fixes for issue 113 and 116
This commit is contained in:
parent
7468f53ab7
commit
835d0454e8
28 changed files with 807 additions and 83 deletions
10
.env.example
10
.env.example
|
|
@ -23,10 +23,10 @@ MODEL_HAIKU="open_router/stepfun/step-3.5-flash:free"
|
|||
MODEL="nvidia_nim/z-ai/glm4.7"
|
||||
|
||||
|
||||
# NIM Settings
|
||||
# Enable chat_template_kwargs + reasoning_budget for thinking models (kimi, nemotron).
|
||||
# Leave false for models that don't support it (e.g. Mistral).
|
||||
NIM_ENABLE_THINKING=false
|
||||
# Thinking output
|
||||
# Global switch for provider reasoning requests and Claude thinking blocks.
|
||||
# Set false to suppress thinking across NIM, OpenRouter, LM Studio, and llama.cpp.
|
||||
ENABLE_THINKING=true
|
||||
|
||||
|
||||
# Provider config
|
||||
|
|
@ -82,4 +82,4 @@ FAST_PREFIX_DETECTION=true
|
|||
ENABLE_NETWORK_PROBE_MOCK=true
|
||||
ENABLE_TITLE_GENERATION_SKIP=true
|
||||
ENABLE_SUGGESTION_MODE_SKIP=true
|
||||
ENABLE_FILEPATH_EXTRACTION_MOCK=true
|
||||
ENABLE_FILEPATH_EXTRACTION_MOCK=true
|
||||
|
|
|
|||
14
README.md
14
README.md
|
|
@ -80,8 +80,8 @@ MODEL_SONNET="nvidia_nim/moonshotai/kimi-k2-thinking"
|
|||
MODEL_HAIKU="nvidia_nim/stepfun-ai/step-3.5-flash"
|
||||
MODEL="nvidia_nim/z-ai/glm4.7" # fallback
|
||||
|
||||
# Enable for thinking models (kimi, nemotron). Leave false for others (e.g. Mistral).
|
||||
NIM_ENABLE_THINKING=true
|
||||
# Global switch for provider reasoning requests and Claude thinking blocks.
|
||||
ENABLE_THINKING=true
|
||||
```
|
||||
|
||||
</details>
|
||||
|
|
@ -143,6 +143,8 @@ MODEL="nvidia_nim/z-ai/glm4.7" # fallback
|
|||
|
||||
</details>
|
||||
|
||||
> Migration: `NIM_ENABLE_THINKING` was removed in this release. Rename it to `ENABLE_THINKING`.
|
||||
|
||||
<details>
|
||||
<summary><b>Optional Authentication</b> (restrict access to your proxy)</summary>
|
||||
|
||||
|
|
@ -184,6 +186,8 @@ uv run uvicorn server:app --host 0.0.0.0 --port 8082
|
|||
|
||||
**Terminal 2:** Run Claude Code:
|
||||
|
||||
Point `ANTHROPIC_BASE_URL` at the proxy root URL, not `http://localhost:8082/v1`.
|
||||
|
||||
#### PowerShell
|
||||
```powershell
|
||||
$env:ANTHROPIC_AUTH_TOKEN="freecc"; $env:ANTHROPIC_BASE_URL="http://localhost:8082"; claude
|
||||
|
|
@ -277,7 +281,9 @@ free-claude-code # starts the server
|
|||
- **Per-model routing**: Opus / Sonnet / Haiku requests resolve to their model-specific backend, with `MODEL` as fallback
|
||||
- **Request optimization**: 5 categories of trivial requests (quota probes, title generation, prefix detection, suggestions, filepath extraction) are intercepted and responded to locally without using API quota
|
||||
- **Format translation**: Requests are translated from Anthropic format to the provider's OpenAI-compatible format and streamed back
|
||||
- **Thinking tokens**: `<think>` tags and `reasoning_content` fields are converted into native Claude thinking blocks
|
||||
- **Thinking tokens**: `<think>` tags and `reasoning_content` fields are converted into native Claude thinking blocks when `ENABLE_THINKING=true`
|
||||
|
||||
The proxy also exposes Claude-compatible routes: `GET /v1/models`, `POST /v1/messages`, and `POST /v1/messages/count_tokens`. `HEAD` and `OPTIONS` probes are answered on `/`, `/health`, `/v1/messages`, and `/v1/messages/count_tokens`.
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -447,7 +453,7 @@ Configure via `WHISPER_DEVICE` (`cpu` | `cuda` | `nvidia_nim`) and `WHISPER_MODE
|
|||
| `MODEL_SONNET` | Model for Claude Sonnet requests (falls back to `MODEL`) | `open_router/arcee-ai/trinity-large-preview:free` |
|
||||
| `MODEL_HAIKU` | Model for Claude Haiku requests (falls back to `MODEL`) | `open_router/stepfun/step-3.5-flash:free` |
|
||||
| `NVIDIA_NIM_API_KEY` | NVIDIA API key | required for NIM |
|
||||
| `NIM_ENABLE_THINKING` | Send `chat_template_kwargs` + `reasoning_budget` on NIM requests. Enable for thinking models (kimi, nemotron); leave `false` for others (e.g. Mistral) | `false` |
|
||||
| `ENABLE_THINKING` | Global switch for provider reasoning requests and Claude thinking blocks. Set `false` to hide thinking across all providers. | `true` |
|
||||
| `OPENROUTER_API_KEY` | OpenRouter API key | required for OpenRouter |
|
||||
| `LM_STUDIO_BASE_URL` | LM Studio server URL | `http://localhost:1234/v1` |
|
||||
| `LLAMACPP_BASE_URL` | llama.cpp server URL | `http://localhost:8080/v1` |
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
|
|||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
)
|
||||
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
if provider_type == "open_router":
|
||||
|
|
@ -56,6 +57,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
|
|||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
)
|
||||
return OpenRouterProvider(config)
|
||||
if provider_type == "lmstudio":
|
||||
|
|
@ -68,6 +70,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
|
|||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
)
|
||||
return LMStudioProvider(config)
|
||||
if provider_type == "llamacpp":
|
||||
|
|
@ -80,6 +83,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
|
|||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
)
|
||||
return LlamaCppProvider(config)
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,13 @@ from .anthropic import (
|
|||
TokenCountRequest,
|
||||
Tool,
|
||||
)
|
||||
from .responses import MessagesResponse, TokenCountResponse, Usage
|
||||
from .responses import (
|
||||
MessagesResponse,
|
||||
ModelResponse,
|
||||
ModelsListResponse,
|
||||
TokenCountResponse,
|
||||
Usage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContentBlockImage",
|
||||
|
|
@ -25,6 +31,8 @@ __all__ = [
|
|||
"Message",
|
||||
"MessagesRequest",
|
||||
"MessagesResponse",
|
||||
"ModelResponse",
|
||||
"ModelsListResponse",
|
||||
"Role",
|
||||
"SystemContent",
|
||||
"ThinkingConfig",
|
||||
|
|
|
|||
|
|
@ -11,6 +11,20 @@ class TokenCountResponse(BaseModel):
|
|||
input_tokens: int
|
||||
|
||||
|
||||
class ModelResponse(BaseModel):
    """Metadata describing one model entry in a models listing."""

    # Creation timestamp, carried as a plain string (not parsed or validated here).
    created_at: str
    # Human-readable model name.
    display_name: str
    # Model identifier.
    id: str
    # Object discriminator; fixed to the literal "model".
    type: Literal["model"] = "model"
|
||||
|
||||
|
||||
class ModelsListResponse(BaseModel):
    """Envelope for a page of ``ModelResponse`` entries plus paging markers."""

    # The model entries on this page.
    data: list[ModelResponse]
    # Id of the first entry in ``data``; None is allowed (required field, no default).
    first_id: str | None
    # Whether further entries exist beyond this page.
    has_more: bool
    # Id of the last entry in ``data``; None is allowed (required field, no default).
    last_id: str | None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import traceback
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -13,13 +13,57 @@ from providers.exceptions import InvalidRequestError, ProviderError
|
|||
|
||||
from .dependencies import get_provider_for_type, get_settings, require_api_key
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import TokenCountResponse
|
||||
from .models.responses import ModelResponse, ModelsListResponse, TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
from .request_utils import get_token_count
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Claude model entries this module advertises, declared as
# (id, display_name, created_at) rows and expanded into ModelResponse objects.
SUPPORTED_CLAUDE_MODELS = [
    ModelResponse(id=model_id, display_name=label, created_at=released)
    for model_id, label, released in (
        ("claude-opus-4-20250514", "Claude Opus 4", "2025-05-14T00:00:00Z"),
        ("claude-sonnet-4-20250514", "Claude Sonnet 4", "2025-05-14T00:00:00Z"),
        ("claude-haiku-4-20250514", "Claude Haiku 4", "2025-05-14T00:00:00Z"),
        ("claude-3-opus-20240229", "Claude 3 Opus", "2024-02-29T00:00:00Z"),
        ("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet", "2024-10-22T00:00:00Z"),
        ("claude-3-haiku-20240307", "Claude 3 Haiku", "2024-03-07T00:00:00Z"),
        ("claude-3-5-haiku-20241022", "Claude 3.5 Haiku", "2024-10-22T00:00:00Z"),
    )
]
|
||||
|
||||
|
||||
def _probe_response(allow: str) -> Response:
    """Build the bodiless 204 reply used to answer compatibility probes.

    Args:
        allow: Value sent verbatim in the ``Allow`` response header.

    Returns:
        A ``Response`` with status 204 and the ``Allow`` header set.
    """
    allow_header = {"Allow": allow}
    return Response(headers=allow_header, status_code=204)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Routes
|
||||
# =============================================================================
|
||||
|
|
@ -83,6 +127,12 @@ async def create_message(
|
|||
) from e
|
||||
|
||||
|
||||
@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
async def probe_messages(_auth=Depends(require_api_key)):
    """Answer HEAD/OPTIONS compatibility probes for the messages endpoint."""
    allowed_methods = "POST, HEAD, OPTIONS"
    return _probe_response(allowed_methods)
|
||||
|
||||
|
||||
@router.post("/v1/messages/count_tokens")
|
||||
async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_api_key)):
|
||||
"""Count tokens for a request."""
|
||||
|
|
@ -112,6 +162,12 @@ async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_ap
|
|||
) from e
|
||||
|
||||
|
||||
@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
async def probe_count_tokens(_auth=Depends(require_api_key)):
    """Answer HEAD/OPTIONS compatibility probes for the token count endpoint."""
    allowed_methods = "POST, HEAD, OPTIONS"
    return _probe_response(allowed_methods)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root(
|
||||
settings: Settings = Depends(get_settings), _auth=Depends(require_api_key)
|
||||
|
|
@ -124,12 +180,35 @@ async def root(
|
|||
}
|
||||
|
||||
|
||||
@router.api_route("/", methods=["HEAD", "OPTIONS"])
async def probe_root(_auth=Depends(require_api_key)):
    """Answer HEAD/OPTIONS compatibility probes for the root endpoint."""
    allowed_methods = "GET, HEAD, OPTIONS"
    return _probe_response(allowed_methods)
|
||||
|
||||
|
||||
@router.get("/health")
async def health():
    """Report liveness; always returns a fixed "healthy" status payload."""
    payload = {"status": "healthy"}
    return payload
|
||||
|
||||
|
||||
@router.api_route("/health", methods=["HEAD", "OPTIONS"])
async def probe_health():
    """Answer HEAD/OPTIONS compatibility probes for the health endpoint."""
    allowed_methods = "GET, HEAD, OPTIONS"
    return _probe_response(allowed_methods)
|
||||
|
||||
|
||||
@router.get("/v1/models", response_model=ModelsListResponse)
async def list_models(_auth=Depends(require_api_key)):
    """Return the advertised Claude models as a single, non-paginated page."""
    models = SUPPORTED_CLAUDE_MODELS

    # An empty catalogue has no boundary ids to report.
    first_id = last_id = None
    if models:
        first_id, last_id = models[0].id, models[-1].id

    return ModelsListResponse(
        data=models,
        first_id=first_id,
        has_more=False,
        last_id=last_id,
    )
|
||||
|
||||
|
||||
@router.post("/stop")
|
||||
async def stop_cli(request: Request, _auth=Depends(require_api_key)):
|
||||
"""Stop all CLI sessions and pending tasks."""
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ class NimSettings(BaseModel):
|
|||
|
||||
parallel_tool_calls: bool = True
|
||||
ignore_eos: bool = False
|
||||
enable_thinking: bool = False
|
||||
|
||||
min_tokens: int = Field(0, ge=0)
|
||||
chat_template: str | None = None
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
|
@ -21,6 +23,58 @@ def _env_files() -> tuple[Path, ...]:
|
|||
return tuple(files)
|
||||
|
||||
|
||||
def _configured_env_files(model_config: Mapping[str, Any]) -> tuple[Path, ...]:
|
||||
"""Return the currently configured env files for Settings."""
|
||||
configured = model_config.get("env_file")
|
||||
if configured is None:
|
||||
return ()
|
||||
if isinstance(configured, (str, Path)):
|
||||
return (Path(configured),)
|
||||
return tuple(Path(item) for item in configured)
|
||||
|
||||
|
||||
def _env_file_contains_key(path: Path, key: str) -> bool:
|
||||
"""Check whether a dotenv-style file defines the given key."""
|
||||
if not path.is_file():
|
||||
return False
|
||||
|
||||
try:
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
if stripped.startswith("export "):
|
||||
stripped = stripped[7:].lstrip()
|
||||
name, sep, _value = stripped.partition("=")
|
||||
if sep and name.strip() == key:
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _removed_env_var_message(model_config: Mapping[str, Any]) -> str | None:
    """Build the migration error for a removed env var that is still set.

    The process environment is checked first, then each configured env file.

    Args:
        model_config: Settings configuration mapping used to locate env files.

    Returns:
        The migration message (naming the offending env file when the key was
        found in one), or None when the removed variable is not configured.
    """
    removed_key = "NIM_ENABLE_THINKING"
    replacement = "ENABLE_THINKING"
    notice = (
        f"{removed_key} has been removed in this release. "
        f"Rename it to {replacement}."
    )

    if removed_key in os.environ:
        return notice

    # First configured env file that still defines the removed key, if any.
    offending_file = next(
        (
            candidate
            for candidate in _configured_env_files(model_config)
            if _env_file_contains_key(candidate, removed_key)
        ),
        None,
    )
    if offending_file is None:
        return None
    return f"{notice} Found in {offending_file}."
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
|
|
@ -67,6 +121,7 @@ class Settings(BaseSettings):
|
|||
provider_max_concurrency: int = Field(
|
||||
default=5, validation_alias="PROVIDER_MAX_CONCURRENCY"
|
||||
)
|
||||
enable_thinking: bool = Field(default=True, validation_alias="ENABLE_THINKING")
|
||||
|
||||
# ==================== HTTP Client Timeouts ====================
|
||||
http_read_timeout: float = Field(
|
||||
|
|
@ -90,9 +145,6 @@ class Settings(BaseSettings):
|
|||
|
||||
# ==================== NIM Settings ====================
|
||||
nim: NimSettings = Field(default_factory=NimSettings)
|
||||
nim_enable_thinking: bool = Field(
|
||||
default=False, validation_alias="NIM_ENABLE_THINKING"
|
||||
)
|
||||
|
||||
# ==================== Voice Note Transcription ====================
|
||||
voice_note_enabled: bool = Field(
|
||||
|
|
@ -131,6 +183,14 @@ class Settings(BaseSettings):
|
|||
default="", validation_alias="ANTHROPIC_AUTH_TOKEN"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def reject_removed_env_vars(cls, data: Any) -> Any:
    """Abort settings construction if a removed env var is still configured.

    Raises:
        ValueError: With the migration message when a removed variable is found.
    """
    migration_error = _removed_env_var_message(cls.model_config)
    if not migration_error:
        return data
    raise ValueError(migration_error)
|
||||
|
||||
# Handle empty strings for optional string fields
|
||||
@field_validator(
|
||||
"telegram_bot_token",
|
||||
|
|
@ -140,7 +200,7 @@ class Settings(BaseSettings):
|
|||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def parse_optional_str(cls, v):
|
||||
def parse_optional_str(cls, v: Any) -> Any:
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
|
@ -174,13 +234,6 @@ class Settings(BaseSettings):
|
|||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _inject_nim_thinking(self) -> Settings:
|
||||
self.nim = self.nim.model_copy(
|
||||
update={"enable_thinking": self.nim_enable_thinking}
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_nvidia_nim_api_key(self) -> Settings:
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class ProviderConfig(BaseModel):
|
|||
http_read_timeout: float = 300.0
|
||||
http_write_timeout: float = 10.0
|
||||
http_connect_timeout: float = 2.0
|
||||
enable_thinking: bool = True
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
|
|
@ -30,6 +31,16 @@ class BaseProvider(ABC):
|
|||
def __init__(self, config: ProviderConfig):
|
||||
self._config = config
|
||||
|
||||
def _is_thinking_enabled(self, request: Any) -> bool:
|
||||
"""Return whether thinking should be enabled for this request."""
|
||||
thinking = getattr(request, "thinking", None)
|
||||
request_enabled = (
|
||||
thinking.enabled
|
||||
if thinking is not None and hasattr(thinking, "enabled")
|
||||
else True
|
||||
)
|
||||
return self._config.enable_thinking and request_enabled
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Release any resources held by this provider."""
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class AnthropicToOpenAIConverter:
|
|||
def convert_messages(
|
||||
messages: list[Any],
|
||||
*,
|
||||
include_thinking: bool = True,
|
||||
include_reasoning_for_openrouter: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert a list of Anthropic messages to OpenAI format.
|
||||
|
|
@ -46,6 +47,7 @@ class AnthropicToOpenAIConverter:
|
|||
result.extend(
|
||||
AnthropicToOpenAIConverter._convert_assistant_message(
|
||||
content,
|
||||
include_thinking=include_thinking,
|
||||
include_reasoning_for_openrouter=include_reasoning_for_openrouter,
|
||||
)
|
||||
)
|
||||
|
|
@ -62,6 +64,7 @@ class AnthropicToOpenAIConverter:
|
|||
def _convert_assistant_message(
|
||||
content: list[Any],
|
||||
*,
|
||||
include_thinking: bool = True,
|
||||
include_reasoning_for_openrouter: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert assistant message blocks, preserving interleaved thinking+text order."""
|
||||
|
|
@ -75,6 +78,8 @@ class AnthropicToOpenAIConverter:
|
|||
if block_type == "text":
|
||||
content_parts.append(get_block_attr(block, "text", ""))
|
||||
elif block_type == "thinking":
|
||||
if not include_thinking:
|
||||
continue
|
||||
thinking = get_block_attr(block, "thinking", "")
|
||||
content_parts.append(f"<think>\n{thinking}\n</think>")
|
||||
if include_reasoning_for_openrouter:
|
||||
|
|
@ -184,6 +189,7 @@ def build_base_request_body(
|
|||
request_data: Any,
|
||||
*,
|
||||
default_max_tokens: int | None = None,
|
||||
include_thinking: bool = True,
|
||||
include_reasoning_for_openrouter: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the common parts of an OpenAI-format request body.
|
||||
|
|
@ -196,6 +202,7 @@ def build_base_request_body(
|
|||
|
||||
messages = AnthropicToOpenAIConverter.convert_messages(
|
||||
request_data.messages,
|
||||
include_thinking=include_thinking,
|
||||
include_reasoning_for_openrouter=include_reasoning_for_openrouter,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class LlamaCppProvider(BaseProvider):
|
|||
"""Stream response natively via Llama.cpp's Anthropic-compatible endpoint."""
|
||||
tag = self._provider_name
|
||||
req_tag = f" request_id={request_id}" if request_id else ""
|
||||
thinking_enabled = self._is_thinking_enabled(request)
|
||||
|
||||
# Dump the Anthropic Pydantic model directly into a dict
|
||||
body = request.model_dump(exclude_none=True)
|
||||
|
|
@ -68,7 +69,11 @@ class LlamaCppProvider(BaseProvider):
|
|||
# Translate internal ThinkingConfig to Anthropic API schema
|
||||
if "thinking" in body:
|
||||
thinking_cfg = body.pop("thinking")
|
||||
if isinstance(thinking_cfg, dict) and thinking_cfg.get("enabled"):
|
||||
if (
|
||||
thinking_enabled
|
||||
and isinstance(thinking_cfg, dict)
|
||||
and thinking_cfg.get("enabled")
|
||||
):
|
||||
# Anthropic API requires a budget_tokens value when enabled
|
||||
body["thinking"] = {"type": "enabled"}
|
||||
|
||||
|
|
@ -126,9 +131,15 @@ class LlamaCppProvider(BaseProvider):
|
|||
except Exception as e:
|
||||
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
||||
mapped_e = map_error(e)
|
||||
error_message = get_user_facing_error_message(
|
||||
mapped_e, read_timeout_s=self._config.http_read_timeout
|
||||
)
|
||||
if getattr(mapped_e, "status_code", None) == 405:
|
||||
error_message = (
|
||||
f"Upstream provider {tag} rejected the request method "
|
||||
"or endpoint (HTTP 405)."
|
||||
)
|
||||
else:
|
||||
error_message = get_user_facing_error_message(
|
||||
mapped_e, read_timeout_s=self._config.http_read_timeout
|
||||
)
|
||||
if request_id:
|
||||
error_message += f"\nRequest ID: {request_id}"
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class LMStudioProvider(BaseProvider):
|
|||
"""Stream response natively via LM Studio's Anthropic-compatible endpoint."""
|
||||
tag = self._provider_name
|
||||
req_tag = f" request_id={request_id}" if request_id else ""
|
||||
thinking_enabled = self._is_thinking_enabled(request)
|
||||
|
||||
# Dump the Anthropic Pydantic model directly into a dict
|
||||
body = request.model_dump(exclude_none=True)
|
||||
|
|
@ -68,7 +69,11 @@ class LMStudioProvider(BaseProvider):
|
|||
# Translate internal ThinkingConfig to Anthropic API schema
|
||||
if "thinking" in body:
|
||||
thinking_cfg = body.pop("thinking")
|
||||
if isinstance(thinking_cfg, dict) and thinking_cfg.get("enabled"):
|
||||
if (
|
||||
thinking_enabled
|
||||
and isinstance(thinking_cfg, dict)
|
||||
and thinking_cfg.get("enabled")
|
||||
):
|
||||
# Anthropic API requires a budget_tokens value when enabled
|
||||
body["thinking"] = {"type": "enabled"}
|
||||
|
||||
|
|
@ -126,9 +131,15 @@ class LMStudioProvider(BaseProvider):
|
|||
except Exception as e:
|
||||
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
||||
mapped_e = map_error(e)
|
||||
error_message = get_user_facing_error_message(
|
||||
mapped_e, read_timeout_s=self._config.http_read_timeout
|
||||
)
|
||||
if getattr(mapped_e, "status_code", None) == 405:
|
||||
error_message = (
|
||||
f"Upstream provider {tag} rejected the request method "
|
||||
"or endpoint (HTTP 405)."
|
||||
)
|
||||
else:
|
||||
error_message = get_user_facing_error_message(
|
||||
mapped_e, read_timeout_s=self._config.http_read_timeout
|
||||
)
|
||||
if request_id:
|
||||
error_message += f"\nRequest ID: {request_id}"
|
||||
|
||||
|
|
|
|||
|
|
@ -25,4 +25,8 @@ class NvidiaNimProvider(OpenAICompatibleProvider):
|
|||
|
||||
def _build_request_body(self, request: Any) -> dict:
|
||||
"""Internal helper for tests and shared building."""
|
||||
return build_request_body(request, self._nim_settings)
|
||||
return build_request_body(
|
||||
request,
|
||||
self._nim_settings,
|
||||
thinking_enabled=self._is_thinking_enabled(request),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,14 +21,19 @@ def _set_extra(
|
|||
extra_body[key] = value
|
||||
|
||||
|
||||
def build_request_body(request_data: Any, nim: NimSettings) -> dict:
|
||||
def build_request_body(
|
||||
request_data: Any, nim: NimSettings, *, thinking_enabled: bool
|
||||
) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request."""
|
||||
logger.debug(
|
||||
"NIM_REQUEST: conversion start model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
body = build_base_request_body(request_data)
|
||||
body = build_base_request_body(
|
||||
request_data,
|
||||
include_thinking=thinking_enabled,
|
||||
)
|
||||
|
||||
# NIM-specific max_tokens: cap against nim.max_tokens
|
||||
max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None)
|
||||
|
|
@ -63,7 +68,7 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
|
|||
if request_extra:
|
||||
extra_body.update(request_extra)
|
||||
|
||||
if nim.enable_thinking:
|
||||
if thinking_enabled:
|
||||
extra_body.setdefault(
|
||||
"chat_template_kwargs", {"thinking": True, "enable_thinking": True}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,10 +25,17 @@ class OpenRouterProvider(OpenAICompatibleProvider):
|
|||
|
||||
def _build_request_body(self, request: Any) -> dict:
|
||||
"""Internal helper for tests and shared building."""
|
||||
return build_request_body(request)
|
||||
return build_request_body(
|
||||
request,
|
||||
thinking_enabled=self._is_thinking_enabled(request),
|
||||
)
|
||||
|
||||
def _handle_extra_reasoning(self, delta: Any, sse: SSEBuilder) -> Iterator[str]:
|
||||
def _handle_extra_reasoning(
|
||||
self, delta: Any, sse: SSEBuilder, *, thinking_enabled: bool
|
||||
) -> Iterator[str]:
|
||||
"""Handle reasoning_details for StepFun models."""
|
||||
if not thinking_enabled:
|
||||
return
|
||||
reasoning_details = getattr(delta, "reasoning_details", None)
|
||||
if reasoning_details and isinstance(reasoning_details, list):
|
||||
for item in reasoning_details:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from providers.common.message_converter import build_base_request_body
|
|||
OPENROUTER_DEFAULT_MAX_TOKENS = 81920
|
||||
|
||||
|
||||
def build_request_body(request_data: Any) -> dict:
|
||||
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
|
||||
logger.debug(
|
||||
"OPENROUTER_REQUEST: conversion start model={} msgs={}",
|
||||
|
|
@ -18,8 +18,9 @@ def build_request_body(request_data: Any) -> dict:
|
|||
)
|
||||
body = build_base_request_body(
|
||||
request_data,
|
||||
include_thinking=thinking_enabled,
|
||||
default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS,
|
||||
include_reasoning_for_openrouter=True,
|
||||
include_reasoning_for_openrouter=thinking_enabled,
|
||||
)
|
||||
|
||||
# OpenRouter reasoning: extra_body={"reasoning": {"enabled": True}}
|
||||
|
|
@ -28,10 +29,6 @@ def build_request_body(request_data: Any) -> dict:
|
|||
if request_extra:
|
||||
extra_body.update(request_extra)
|
||||
|
||||
thinking = getattr(request_data, "thinking", None)
|
||||
thinking_enabled = (
|
||||
thinking.enabled if thinking and hasattr(thinking, "enabled") else True
|
||||
)
|
||||
if thinking_enabled:
|
||||
extra_body.setdefault("reasoning", {"enabled": True})
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,9 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
def _build_request_body(self, request: Any) -> dict:
|
||||
"""Build request body. Must be implemented by subclasses."""
|
||||
|
||||
def _handle_extra_reasoning(self, delta: Any, sse: SSEBuilder) -> Iterator[str]:
|
||||
def _handle_extra_reasoning(
|
||||
self, delta: Any, sse: SSEBuilder, *, thinking_enabled: bool
|
||||
) -> Iterator[str]:
|
||||
"""Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details)."""
|
||||
return iter(())
|
||||
|
||||
|
|
@ -151,6 +153,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
|
||||
think_parser = ThinkTagParser()
|
||||
heuristic_parser = HeuristicToolParser()
|
||||
thinking_enabled = self._is_thinking_enabled(request)
|
||||
|
||||
finish_reason = None
|
||||
usage_info = None
|
||||
|
|
@ -180,19 +183,25 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
|
||||
# Handle reasoning_content (OpenAI extended format)
|
||||
reasoning = getattr(delta, "reasoning_content", None)
|
||||
if reasoning:
|
||||
if thinking_enabled and reasoning:
|
||||
for event in sse.ensure_thinking_block():
|
||||
yield event
|
||||
yield sse.emit_thinking_delta(reasoning)
|
||||
|
||||
# Provider-specific extra reasoning (e.g. OpenRouter reasoning_details)
|
||||
for event in self._handle_extra_reasoning(delta, sse):
|
||||
for event in self._handle_extra_reasoning(
|
||||
delta,
|
||||
sse,
|
||||
thinking_enabled=thinking_enabled,
|
||||
):
|
||||
yield event
|
||||
|
||||
# Handle text content
|
||||
if delta.content:
|
||||
for part in think_parser.feed(delta.content):
|
||||
if part.type == ContentType.THINKING:
|
||||
if not thinking_enabled:
|
||||
continue
|
||||
for event in sse.ensure_thinking_block():
|
||||
yield event
|
||||
yield sse.emit_thinking_delta(part.content)
|
||||
|
|
@ -248,12 +257,16 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
||||
mapped_e = map_error(e)
|
||||
error_occurred = True
|
||||
error_message = append_request_id(
|
||||
get_user_facing_error_message(
|
||||
if getattr(mapped_e, "status_code", None) == 405:
|
||||
base_message = (
|
||||
f"Upstream provider {tag} rejected the request method "
|
||||
"or endpoint (HTTP 405)."
|
||||
)
|
||||
else:
|
||||
base_message = get_user_facing_error_message(
|
||||
mapped_e, read_timeout_s=self._config.http_read_timeout
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
)
|
||||
error_message = append_request_id(base_message, request_id)
|
||||
logger.info(
|
||||
"{}_STREAM: Emitting SSE error event for {}{}",
|
||||
tag,
|
||||
|
|
@ -269,10 +282,13 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
remaining = think_parser.flush()
|
||||
if remaining:
|
||||
if remaining.type == ContentType.THINKING:
|
||||
for event in sse.ensure_thinking_block():
|
||||
yield event
|
||||
yield sse.emit_thinking_delta(remaining.content)
|
||||
else:
|
||||
if not thinking_enabled:
|
||||
remaining = None
|
||||
else:
|
||||
for event in sse.ensure_thinking_block():
|
||||
yield event
|
||||
yield sse.emit_thinking_delta(remaining.content)
|
||||
if remaining and remaining.type == ContentType.TEXT:
|
||||
for event in sse.ensure_text_block():
|
||||
yield event
|
||||
yield sse.emit_text_delta(remaining.content)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,34 @@ def test_health():
|
|||
assert response.json()["status"] == "healthy"
|
||||
|
||||
|
||||
def test_models_list():
    """GET /v1/models returns one page whose boundary ids match its data."""
    response = client.get("/v1/models")
    assert response.status_code == 200

    payload = response.json()
    assert payload["has_more"] is False

    model_ids = [entry["id"] for entry in payload["data"]]
    assert "claude-sonnet-4-20250514" in model_ids
    assert payload["first_id"] == model_ids[0]
    assert payload["last_id"] == model_ids[-1]
|
||||
|
||||
|
||||
def test_probe_endpoints_return_204_with_allow_headers():
    """Each probe path answers HEAD and OPTIONS with 204 and an Allow header."""
    probe_paths = ("/", "/health", "/v1/messages", "/v1/messages/count_tokens")

    # Issue every request up front (HEAD then OPTIONS per path), then verify.
    probes = [
        send(path)
        for path in probe_paths
        for send in (client.head, client.options)
    ]

    for probe in probes:
        assert probe.status_code == 204
        assert "Allow" in probe.headers
|
||||
|
||||
|
||||
def test_create_message_stream():
|
||||
"""Create message returns streaming response."""
|
||||
payload = {
|
||||
|
|
|
|||
|
|
@ -55,3 +55,19 @@ def test_anthropic_auth_token_accepts_bearer_authorization():
|
|||
assert r.json()["input_tokens"] == 2
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_anthropic_auth_token_applies_to_models_endpoint():
    """The configured auth token also guards GET /v1/models.

    Without credentials the endpoint answers 401; with the token supplied in
    the X-API-Key header it answers 200 and returns a body containing "data".
    """
    client = TestClient(app)
    settings = Settings()
    settings.anthropic_auth_token = "models-token"
    app.dependency_overrides[get_settings] = lambda: settings

    # Always drop the override, even when an assertion fails, so the patched
    # settings cannot leak into tests that run afterwards.
    try:
        r = client.get("/v1/models")
        assert r.status_code == 401

        r = client.get("/v1/models", headers={"X-API-Key": "models-token"})
        assert r.status_code == 200
        assert "data" in r.json()
    finally:
        app.dependency_overrides.clear()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class TestSettings:
|
|||
assert isinstance(settings.provider_rate_window, int)
|
||||
assert isinstance(settings.nim.temperature, float)
|
||||
assert isinstance(settings.fast_prefix_detection, bool)
|
||||
assert isinstance(settings.enable_thinking, bool)
|
||||
|
||||
def test_get_settings_cached(self):
|
||||
"""Test get_settings returns cached instance."""
|
||||
|
|
@ -104,6 +105,22 @@ class TestSettings:
|
|||
settings = Settings()
|
||||
assert settings.http_connect_timeout == 5.0
|
||||
|
||||
def test_enable_thinking_from_env(self, monkeypatch):
    """ENABLE_THINKING=false in the environment yields enable_thinking False."""
    from config.settings import Settings

    monkeypatch.setenv("ENABLE_THINKING", "false")
    assert Settings().enable_thinking is False
|
||||
|
||||
def test_removed_nim_enable_thinking_raises(self, monkeypatch):
    """Setting the removed NIM_ENABLE_THINKING variable fails with a rename hint."""
    from config.settings import Settings

    monkeypatch.setenv("NIM_ENABLE_THINKING", "false")

    migration_hint = "Rename it to ENABLE_THINKING"
    with pytest.raises(ValidationError, match=migration_hint):
        Settings()
|
||||
|
||||
|
||||
# --- NimSettings Validation Tests ---
|
||||
class TestNimSettingsValidBounds:
|
||||
|
|
@ -228,6 +245,13 @@ class TestNimSettingsValidators:
|
|||
with pytest.raises(ValidationError):
|
||||
NimSettings(**cast(Any, {"unknown_field": "value"}))
|
||||
|
||||
def test_enable_thinking_field_removed(self):
    """Passing enable_thinking to NimSettings raises a ValidationError."""
    from typing import Any, cast

    # cast() only silences the type checker about the unknown keyword.
    removed_kwargs = cast(Any, {"enable_thinking": True})
    with pytest.raises(ValidationError):
        NimSettings(**removed_kwargs)
|
||||
|
||||
|
||||
class TestSettingsOptionalStr:
|
||||
"""Test Settings parse_optional_str validator."""
|
||||
|
|
|
|||
|
|
@ -2,9 +2,13 @@ import asyncio
|
|||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from config.settings import Settings
|
||||
|
||||
# Set mock environment BEFORE any imports that use Settings
|
||||
os.environ.setdefault("NVIDIA_NIM_API_KEY", "test_key")
|
||||
os.environ.setdefault("MODEL", "nvidia_nim/test-model")
|
||||
|
|
@ -13,26 +17,12 @@ os.environ["PTB_TIMEDELTA"] = "1"
|
|||
# (tests expect endpoints to be unauthenticated by default)
|
||||
os.environ["ANTHROPIC_AUTH_TOKEN"] = ""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from config.nim import NimSettings
|
||||
from messaging.models import IncomingMessage
|
||||
from messaging.platforms.base import (
|
||||
CLISession,
|
||||
MessagingPlatform,
|
||||
SessionManagerInterface,
|
||||
)
|
||||
from messaging.session import SessionStore
|
||||
from providers.base import ProviderConfig
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
Settings.model_config = {**Settings.model_config, "env_file": None}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate_from_dotenv(monkeypatch):
    """Keep Settings from picking up a local .env file while a test runs."""
    from config.settings import Settings

    # Same model_config as before, with only the env_file entry forced to None.
    dotenv_free_config = {**Settings.model_config, "env_file": None}
    monkeypatch.setattr(Settings, "model_config", dotenv_free_config)
|
||||
|
|
@ -40,6 +30,8 @@ def _isolate_from_dotenv(monkeypatch):
|
|||
|
||||
@pytest.fixture
|
||||
def provider_config():
|
||||
from providers.base import ProviderConfig
|
||||
|
||||
return ProviderConfig(
|
||||
api_key="test_key",
|
||||
base_url="https://test.api.nvidia.com/v1",
|
||||
|
|
@ -50,6 +42,9 @@ def provider_config():
|
|||
|
||||
@pytest.fixture
def nim_provider(provider_config):
    """Build an NvidiaNimProvider from the shared config and default NimSettings."""
    # Imported inside the fixture rather than at module import time.
    from config.nim import NimSettings
    from providers.nvidia_nim import NvidiaNimProvider

    default_nim_settings = NimSettings()
    return NvidiaNimProvider(provider_config, nim_settings=default_nim_settings)
|
||||
|
||||
|
||||
|
|
@ -62,6 +57,7 @@ def open_router_provider(provider_config):
|
|||
|
||||
@pytest.fixture
|
||||
def lmstudio_provider(provider_config):
|
||||
from providers.base import ProviderConfig
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
|
||||
lmstudio_config = ProviderConfig(
|
||||
|
|
@ -75,6 +71,7 @@ def lmstudio_provider(provider_config):
|
|||
|
||||
@pytest.fixture
|
||||
def llamacpp_provider(provider_config):
|
||||
from providers.base import ProviderConfig
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
|
||||
llamacpp_config = ProviderConfig(
|
||||
|
|
@ -88,6 +85,8 @@ def llamacpp_provider(provider_config):
|
|||
|
||||
@pytest.fixture
|
||||
def mock_cli_session():
|
||||
from messaging.platforms.base import CLISession
|
||||
|
||||
session = MagicMock(spec=CLISession)
|
||||
session.start_task = MagicMock() # This will return an async generator
|
||||
session.is_busy = False
|
||||
|
|
@ -96,6 +95,8 @@ def mock_cli_session():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_cli_manager():
|
||||
from messaging.platforms.base import SessionManagerInterface
|
||||
|
||||
manager = MagicMock(spec=SessionManagerInterface)
|
||||
manager.get_or_create_session = AsyncMock()
|
||||
manager.register_real_session_id = AsyncMock(return_value=True)
|
||||
|
|
@ -107,6 +108,8 @@ def mock_cli_manager():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_platform():
|
||||
from messaging.platforms.base import MessagingPlatform
|
||||
|
||||
platform = MagicMock(spec=MessagingPlatform)
|
||||
platform.send_message = AsyncMock(return_value="msg_123")
|
||||
platform.edit_message = AsyncMock()
|
||||
|
|
@ -127,6 +130,8 @@ def mock_platform():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_session_store():
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = MagicMock(spec=SessionStore)
|
||||
store.save_tree = MagicMock()
|
||||
store.get_tree = MagicMock(return_value=None)
|
||||
|
|
@ -156,6 +161,8 @@ def incoming_message_factory():
|
|||
)
|
||||
|
||||
def _create(**kwargs):
|
||||
from messaging.models import IncomingMessage
|
||||
|
||||
defaults: dict[str, Any] = {
|
||||
"text": "hello",
|
||||
"chat_id": "chat_1",
|
||||
|
|
|
|||
|
|
@ -199,6 +199,24 @@ def test_convert_assistant_message_thinking_include_reasoning_for_openrouter():
|
|||
assert "<think>" in result[0]["content"]
|
||||
|
||||
|
||||
def test_convert_assistant_message_thinking_removed_when_disabled():
    """With include_thinking=False only the text block survives conversion."""
    assistant_blocks = [
        MockBlock(type="thinking", thinking="I need to calculate this."),
        MockBlock(type="text", text="The answer is 4."),
    ]

    converted = AnthropicToOpenAIConverter.convert_messages(
        [MockMessage("assistant", assistant_blocks)],
        include_thinking=False,
        include_reasoning_for_openrouter=True,
    )

    assert len(converted) == 1
    message = converted[0]
    # Neither the reasoning field nor inline <think> markup may appear.
    assert "reasoning_content" not in message
    assert "<think>" not in message["content"]
    assert message["content"] == "The answer is 4."
|
||||
|
||||
|
||||
def test_convert_assistant_message_tool_use():
|
||||
content = [
|
||||
MockBlock(type="text", text="I will call the tool."),
|
||||
|
|
|
|||
|
|
@ -110,6 +110,37 @@ def test_init_base_url_strips_trailing_slash():
|
|||
assert provider._base_url == "http://localhost:8080/v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_globally_disabled(llamacpp_config):
    """With enable_thinking=False the outgoing JSON body has no "thinking" key."""
    disabled_config = llamacpp_config.model_copy(update={"enable_thinking": False})
    provider = LlamaCppProvider(disabled_config)
    req = MockRequest()

    async def no_lines():
        # Async generator that never yields: an empty upstream stream.
        if False:
            yield ""

    upstream = MagicMock()
    upstream.status_code = 200
    upstream.aiter_lines = no_lines

    with (
        patch.object(provider._client, "build_request") as build_request,
        patch.object(
            provider._client,
            "send",
            new_callable=AsyncMock,
            return_value=upstream,
        ),
    ):
        # Drain the stream; only the captured request matters here.
        async for _ in provider.stream_response(req):
            pass

    sent_json = build_request.call_args.kwargs["json"]
    assert "thinking" not in sent_json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response(llamacpp_provider):
|
||||
"""Test streaming native Anthropic response."""
|
||||
|
|
@ -254,3 +285,38 @@ async def test_stream_network_error(llamacpp_provider):
|
|||
assert events[0].startswith("event: error\ndata: {")
|
||||
assert "Connection refused" in events[0]
|
||||
assert "TEST_ID2" in events[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_error_405_mentions_upstream_provider(llamacpp_provider):
    """An upstream HTTP 405 yields an error event naming LLAMACPP and the request id."""
    req = MockRequest()

    upstream = MagicMock()
    upstream.status_code = 405
    upstream.aread = AsyncMock(return_value=b"Method Not Allowed")
    status_error = httpx.HTTPStatusError(
        "Method Not Allowed", request=MagicMock(), response=upstream
    )
    upstream.raise_for_status = MagicMock(side_effect=status_error)

    http_client = llamacpp_provider._client
    with (
        patch.object(http_client, "build_request", return_value=MagicMock()),
        patch.object(
            http_client,
            "send",
            new_callable=AsyncMock,
            return_value=upstream,
        ),
    ):
        events = []
        async for event in llamacpp_provider.stream_response(
            req, request_id="REQ405"
        ):
            events.append(event)

    first_event = events[0]
    expected_message = (
        "Upstream provider LLAMACPP rejected the request method or endpoint (HTTP 405)."
    )
    assert expected_message in first_event
    assert "REQ405" in first_event
|
||||
|
|
|
|||
|
|
@ -110,6 +110,68 @@ def test_init_base_url_strips_trailing_slash():
|
|||
assert provider._base_url == "http://localhost:1234/v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_globally_disabled(lmstudio_config):
    """With enable_thinking=False the outgoing JSON body has no "thinking" key."""
    disabled_config = lmstudio_config.model_copy(update={"enable_thinking": False})
    provider = LMStudioProvider(disabled_config)
    req = MockRequest()

    async def no_lines():
        # Async generator that never yields: an empty upstream stream.
        if False:
            yield ""

    upstream = MagicMock()
    upstream.status_code = 200
    upstream.aiter_lines = no_lines

    with (
        patch.object(provider._client, "build_request") as build_request,
        patch.object(
            provider._client,
            "send",
            new_callable=AsyncMock,
            return_value=upstream,
        ),
    ):
        # Drain the stream; only the captured request matters here.
        async for _ in provider.stream_response(req):
            pass

    sent_json = build_request.call_args.kwargs["json"]
    assert "thinking" not in sent_json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_request_disables_it(
    lmstudio_provider,
):
    """A request with thinking.enabled=False sends no "thinking" key upstream."""
    req = MockRequest()
    req.thinking.enabled = False

    async def no_lines():
        # Async generator that never yields: an empty upstream stream.
        if False:
            yield ""

    upstream = MagicMock()
    upstream.status_code = 200
    upstream.aiter_lines = no_lines

    http_client = lmstudio_provider._client
    with (
        patch.object(http_client, "build_request") as build_request,
        patch.object(
            http_client,
            "send",
            new_callable=AsyncMock,
            return_value=upstream,
        ),
    ):
        # Drain the stream; only the captured request matters here.
        async for _ in lmstudio_provider.stream_response(req):
            pass

    sent_json = build_request.call_args.kwargs["json"]
    assert "thinking" not in sent_json
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response(lmstudio_provider):
|
||||
"""Test streaming native Anthropic response."""
|
||||
|
|
@ -254,3 +316,38 @@ async def test_stream_network_error(lmstudio_provider):
|
|||
assert events[0].startswith("event: error\ndata: {")
|
||||
assert "Connection refused" in events[0]
|
||||
assert "TEST_ID2" in events[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider):
    """An upstream HTTP 405 yields an error event naming LMSTUDIO and the request id."""
    req = MockRequest()

    upstream = MagicMock()
    upstream.status_code = 405
    upstream.aread = AsyncMock(return_value=b"Method Not Allowed")
    status_error = httpx.HTTPStatusError(
        "Method Not Allowed", request=MagicMock(), response=upstream
    )
    upstream.raise_for_status = MagicMock(side_effect=status_error)

    http_client = lmstudio_provider._client
    with (
        patch.object(http_client, "build_request", return_value=MagicMock()),
        patch.object(
            http_client,
            "send",
            new_callable=AsyncMock,
            return_value=upstream,
        ),
    ):
        events = []
        async for event in lmstudio_provider.stream_response(
            req, request_id="REQ405"
        ):
            events.append(event)

    first_event = events[0]
    expected_message = (
        "Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)."
    )
    assert expected_message in first_event
    assert "REQ405" in first_event
|
||||
|
|
|
|||
|
|
@ -91,9 +91,7 @@ async def test_build_request_body(provider_config):
|
|||
"""Test request body construction."""
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config, nim_settings=NimSettings(enable_thinking=True)
|
||||
)
|
||||
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
|
||||
req = MockRequest()
|
||||
body = provider._build_request_body(req)
|
||||
|
||||
|
|
@ -110,6 +108,40 @@ async def test_build_request_body(provider_config):
|
|||
assert body["extra_body"]["reasoning_budget"] == body["max_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_request_body_omits_reasoning_when_globally_disabled(
    provider_config,
):
    """A config with enable_thinking=False produces no reasoning keys in extra_body."""
    from config.nim import NimSettings

    disabled_config = provider_config.model_copy(update={"enable_thinking": False})
    provider = NvidiaNimProvider(disabled_config, nim_settings=NimSettings())

    body = provider._build_request_body(MockRequest())

    # extra_body may be absent altogether; treat that as an empty mapping.
    extra = body.get("extra_body", {})
    for reasoning_key in ("chat_template_kwargs", "reasoning_budget"):
        assert reasoning_key not in extra
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_request_body_omits_reasoning_when_request_disables_thinking(
    provider_config,
):
    """A request with thinking.enabled=False produces no reasoning keys in extra_body."""
    from config.nim import NimSettings

    provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())

    req = MockRequest()
    req.thinking.enabled = False
    body = provider._build_request_body(req)

    # extra_body may be absent altogether; treat that as an empty mapping.
    extra = body.get("extra_body", {})
    for reasoning_key in ("chat_template_kwargs", "reasoning_budget"):
        assert reasoning_key not in extra
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_text(nim_provider):
|
||||
"""Test streaming text response."""
|
||||
|
|
@ -195,6 +227,44 @@ async def test_stream_response_thinking_reasoning_content(nim_provider):
|
|||
assert found_thinking
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_suppresses_thinking_when_disabled(provider_config):
    """With thinking disabled, reasoning and <think> text are dropped but the answer streams."""
    from config.nim import NimSettings

    disabled_config = provider_config.model_copy(update={"enable_thinking": False})
    provider = NvidiaNimProvider(disabled_config, nim_settings=NimSettings())
    req = MockRequest()

    # One chunk carrying both an inline <think> section and a reasoning field.
    delta = MagicMock(
        content="<think>secret</think>Answer", reasoning_content="Thinking..."
    )
    chunk = MagicMock()
    chunk.choices = [MagicMock(delta=delta, finish_reason="stop")]
    chunk.usage = None

    async def single_chunk_stream():
        yield chunk

    with patch.object(
        provider._client.chat.completions, "create", new_callable=AsyncMock
    ) as create_mock:
        create_mock.return_value = single_chunk_stream()

        events = []
        async for event in provider.stream_response(req):
            events.append(event)

    event_text = "".join(events)
    for hidden in ("thinking_delta", "Thinking...", "secret"):
        assert hidden not in event_text
    assert "Answer" in event_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_stream(nim_provider):
|
||||
"""Test streaming tool calls."""
|
||||
|
|
|
|||
|
|
@ -67,21 +67,21 @@ class TestBuildRequestBody:
|
|||
def test_max_tokens_capped_by_nim(self, req):
    """A request max_tokens above the NimSettings limit is reduced to that limit."""
    req.max_tokens = 100000
    nim = NimSettings(max_tokens=4096)
    # Build the body exactly once, passing thinking_enabled explicitly.
    body = build_request_body(req, nim, thinking_enabled=True)
    assert body["max_tokens"] == 4096
|
||||
|
||||
def test_presence_penalty_included_when_nonzero(self, req):
    """A non-zero presence_penalty from NimSettings is copied into the body."""
    nim = NimSettings(presence_penalty=0.5)
    # Build the body exactly once, passing thinking_enabled explicitly.
    body = build_request_body(req, nim, thinking_enabled=True)
    assert body["presence_penalty"] == 0.5
|
||||
|
||||
def test_include_stop_str_in_output_not_sent(self, req):
    """include_stop_str_in_output never appears in the body's extra_body."""
    # Build the body exactly once, passing thinking_enabled explicitly.
    body = build_request_body(req, NimSettings(), thinking_enabled=True)
    # extra_body may be absent altogether; treat that as an empty mapping.
    assert "include_stop_str_in_output" not in body.get("extra_body", {})
|
||||
|
||||
def test_parallel_tool_calls_included(self, req):
    """parallel_tool_calls=False from NimSettings is passed through to the body."""
    nim = NimSettings(parallel_tool_calls=False)
    # Build the body exactly once, passing thinking_enabled explicitly.
    body = build_request_body(req, nim, thinking_enabled=True)
    assert body["parallel_tool_calls"] is False
|
||||
|
||||
def test_reasoning_params_in_extra_body(self):
|
||||
|
|
@ -98,8 +98,8 @@ class TestBuildRequestBody:
|
|||
req.extra_body = None
|
||||
req.top_k = None
|
||||
|
||||
nim = NimSettings(enable_thinking=True)
|
||||
body = build_request_body(req, nim)
|
||||
nim = NimSettings()
|
||||
body = build_request_body(req, nim, thinking_enabled=True)
|
||||
extra = body["extra_body"]
|
||||
assert extra["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
|
|
@ -121,8 +121,8 @@ class TestBuildRequestBody:
|
|||
req.extra_body = None
|
||||
req.top_k = None
|
||||
|
||||
nim = NimSettings(enable_thinking=False)
|
||||
body = build_request_body(req, nim)
|
||||
nim = NimSettings()
|
||||
body = build_request_body(req, nim, thinking_enabled=False)
|
||||
extra = body.get("extra_body", {})
|
||||
assert "chat_template_kwargs" not in extra
|
||||
assert "reasoning_budget" not in extra
|
||||
|
|
@ -142,7 +142,7 @@ class TestBuildRequestBody:
|
|||
req.top_k = None
|
||||
|
||||
nim = NimSettings()
|
||||
body = build_request_body(req, nim)
|
||||
body = build_request_body(req, nim, thinking_enabled=False)
|
||||
extra = body.get("extra_body", {})
|
||||
for param in (
|
||||
"thinking",
|
||||
|
|
@ -152,3 +152,29 @@ class TestBuildRequestBody:
|
|||
"reasoning_effort",
|
||||
):
|
||||
assert param not in extra
|
||||
|
||||
def test_assistant_thinking_blocks_removed_when_disabled(self):
    """With thinking_enabled=False the assistant content keeps the text, not <think>."""
    assistant_turn = MagicMock(
        role="assistant",
        content=[
            MagicMock(type="thinking", thinking="secret"),
            MagicMock(type="text", text="answer"),
        ],
    )

    req = MagicMock()
    req.model = "test"
    req.messages = [assistant_turn]
    req.max_tokens = 100
    # Every remaining request field is explicitly unset.
    for unset_field in (
        "system",
        "temperature",
        "top_p",
        "stop_sequences",
        "tools",
        "tool_choice",
        "extra_body",
        "top_k",
    ):
        setattr(req, unset_field, None)

    body = build_request_body(req, NimSettings(), thinking_enabled=False)
    assistant_content = body["messages"][0]["content"]
    assert "<think>" not in assistant_content
    assert "answer" in assistant_content
|
||||
|
|
|
|||
|
|
@ -105,6 +105,26 @@ def test_build_request_body_has_reasoning_extra(open_router_provider):
|
|||
assert body["extra_body"]["reasoning"]["enabled"] is True
|
||||
|
||||
|
||||
def test_build_request_body_omits_reasoning_when_globally_disabled(open_router_config):
    """A config with enable_thinking=False sends no "reasoning" entry in extra_body."""
    disabled_config = open_router_config.model_copy(update={"enable_thinking": False})
    provider = OpenRouterProvider(disabled_config)

    body = provider._build_request_body(MockRequest())

    # Passing means extra_body is missing, or present without "reasoning".
    assert "extra_body" not in body or "reasoning" not in body["extra_body"]
|
||||
|
||||
|
||||
def test_build_request_body_omits_reasoning_when_request_disables_thinking(
    open_router_provider,
):
    """A request with thinking.enabled=False sends no "reasoning" entry in extra_body."""
    req = MockRequest()
    req.thinking.enabled = False

    body = open_router_provider._build_request_body(req)

    # Passing means extra_body is missing, or present without "reasoning".
    assert "extra_body" not in body or "reasoning" not in body["extra_body"]
|
||||
|
||||
|
||||
def test_build_request_body_base_url_and_model(open_router_provider):
|
||||
"""Base URL and model are correct in provider config."""
|
||||
assert open_router_provider._base_url == "https://openrouter.ai/api/v1"
|
||||
|
|
@ -205,6 +225,44 @@ async def test_stream_response_reasoning_content(open_router_provider):
|
|||
assert found_thinking
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_suppresses_reasoning_when_disabled(open_router_config):
|
||||
provider = OpenRouterProvider(
|
||||
open_router_config.model_copy(update={"enable_thinking": False})
|
||||
)
|
||||
req = MockRequest()
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(
|
||||
content="<think>secret</think>Answer",
|
||||
reasoning_content="Thinking...",
|
||||
reasoning_details=[{"text": "Step 1"}],
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = None
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
with patch.object(
|
||||
provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.return_value = mock_stream()
|
||||
|
||||
events = [e async for e in provider.stream_response(req)]
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "thinking_delta" not in event_text
|
||||
assert "Thinking..." not in event_text
|
||||
assert "Step 1" not in event_text
|
||||
assert "secret" not in event_text
|
||||
assert "Answer" in event_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_empty_choices_skipped(open_router_provider):
|
||||
"""Chunks with empty choices are skipped."""
|
||||
|
|
|
|||
|
|
@ -39,6 +39,18 @@ def _make_provider():
|
|||
return NvidiaNimProvider(config, nim_settings=NimSettings())
|
||||
|
||||
|
||||
def _make_provider_with_thinking_enabled(enabled: bool):
    """Return an NvidiaNimProvider whose config sets enable_thinking to *enabled*."""
    config_kwargs = {
        "api_key": "test_key",
        "base_url": "https://test.api.nvidia.com/v1",
        "rate_limit": 10,
        "rate_window": 60,
        "enable_thinking": enabled,
    }
    provider_config = ProviderConfig(**config_kwargs)
    return NvidiaNimProvider(provider_config, nim_settings=NimSettings())
|
||||
|
||||
|
||||
def _make_request(model="test-model", stream=True):
|
||||
"""Create a mock request with all fields build_request_body needs."""
|
||||
req = MagicMock()
|
||||
|
|
@ -272,6 +284,76 @@ class TestStreamingExceptionHandling:
|
|||
assert "I think..." in event_text
|
||||
assert "The answer" in event_text
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_with_reasoning_content_suppressed_when_disabled(self):
    """With thinking disabled, reasoning deltas are dropped and plain text still streams."""
    provider = _make_provider_with_thinking_enabled(False)
    request = _make_request()

    # A reasoning-only chunk, a chunk mixing <think> markup with text, then stop.
    chunks = [
        _make_chunk(reasoning_content="I think..."),
        _make_chunk(content="<think>secret</think>The answer"),
        _make_chunk(finish_reason="stop"),
    ]
    stream_mock = AsyncStreamMock(chunks)

    patched_create = patch.object(
        provider._client.chat.completions,
        "create",
        new_callable=AsyncMock,
        return_value=stream_mock,
    )
    patched_limiter = patch.object(
        provider._global_rate_limiter,
        "wait_if_blocked",
        new_callable=AsyncMock,
        return_value=False,
    )
    with patched_create, patched_limiter:
        events = await _collect_stream(provider, request)

    event_text = "".join(events)
    for hidden in ("thinking_delta", "I think...", "secret"):
        assert hidden not in event_text
    assert "The answer" in event_text
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_with_upstream_405_mentions_provider_name(self):
    """An HTTP 405 raised by the client surfaces as an event naming NIM and the request id."""
    provider = _make_provider()
    request = _make_request()

    upstream_request = httpx.Request(
        "POST", "https://example.com/v1/chat/completions"
    )
    response = httpx.Response(status_code=405, request=upstream_request)
    error = httpx.HTTPStatusError(
        "Method Not Allowed",
        request=response.request,
        response=response,
    )

    with patch.object(
        provider._client.chat.completions,
        "create",
        new_callable=AsyncMock,
        side_effect=error,
    ):
        events = []
        async for event in provider.stream_response(request, request_id="REQ405"):
            events.append(event)

    event_text = "".join(events)
    expected_message = (
        "Upstream provider NIM rejected the request method or endpoint (HTTP 405)."
    )
    assert expected_message in event_text
    assert "request_id=REQ405" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_rate_limited_retries_via_execute_with_retry(self):
|
||||
"""When rate limited, execute_with_retry handles retries transparently."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue