diff --git a/.env.example b/.env.example index e1900d9..cd0169f 100644 --- a/.env.example +++ b/.env.example @@ -23,10 +23,10 @@ MODEL_HAIKU="open_router/stepfun/step-3.5-flash:free" MODEL="nvidia_nim/z-ai/glm4.7" -# NIM Settings -# Enable chat_template_kwargs + reasoning_budget for thinking models (kimi, nemotron). -# Leave false for models that don't support it (e.g. Mistral). -NIM_ENABLE_THINKING=false +# Thinking output +# Global switch for provider reasoning requests and Claude thinking blocks. +# Set false to suppress thinking across NIM, OpenRouter, LM Studio, and llama.cpp. +ENABLE_THINKING=true # Provider config @@ -82,4 +82,4 @@ FAST_PREFIX_DETECTION=true ENABLE_NETWORK_PROBE_MOCK=true ENABLE_TITLE_GENERATION_SKIP=true ENABLE_SUGGESTION_MODE_SKIP=true -ENABLE_FILEPATH_EXTRACTION_MOCK=true \ No newline at end of file +ENABLE_FILEPATH_EXTRACTION_MOCK=true diff --git a/README.md b/README.md index 67e7205..db09a91 100644 --- a/README.md +++ b/README.md @@ -80,8 +80,8 @@ MODEL_SONNET="nvidia_nim/moonshotai/kimi-k2-thinking" MODEL_HAIKU="nvidia_nim/stepfun-ai/step-3.5-flash" MODEL="nvidia_nim/z-ai/glm4.7" # fallback -# Enable for thinking models (kimi, nemotron). Leave false for others (e.g. Mistral). -NIM_ENABLE_THINKING=true +# Global switch for provider reasoning requests and Claude thinking blocks. +ENABLE_THINKING=true ``` @@ -143,6 +143,8 @@ MODEL="nvidia_nim/z-ai/glm4.7" # fallback +> Migration: `NIM_ENABLE_THINKING` was removed in this release. Rename it to `ENABLE_THINKING`. +
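The new guard in `config/settings.py` (further down in this diff) backs this migration note at startup. A minimal sketch of what a stale environment now produces, assuming the other required provider variables are already set:

```python
# Sketch: with the removed variable still set, Settings() fails fast.
import os

from pydantic import ValidationError

from config.settings import Settings

os.environ["NIM_ENABLE_THINKING"] = "false"
try:
    Settings()
except ValidationError as exc:
    # Expected message: "NIM_ENABLE_THINKING has been removed in this release.
    # Rename it to ENABLE_THINKING."
    print(exc)
```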
Optional Authentication (restrict access to your proxy)

@@ -184,6 +186,8 @@ uv run uvicorn server:app --host 0.0.0.0 --port 8082

 **Terminal 2:** Run Claude Code:

+Point `ANTHROPIC_BASE_URL` at the proxy root URL, not `http://localhost:8082/v1`.
+
 #### Powershell
 ```powershell
 $env:ANTHROPIC_AUTH_TOKEN="freecc"; $env:ANTHROPIC_BASE_URL="http://localhost:8082"; claude
 ```

@@ -277,7 +281,9 @@ free-claude-code  # starts the server
 - **Per-model routing**: Opus / Sonnet / Haiku requests resolve to their model-specific backend, with `MODEL` as fallback
 - **Request optimization**: 5 categories of trivial requests (quota probes, title generation, prefix detection, suggestions, filepath extraction) are intercepted and responded to locally without using API quota
 - **Format translation**: Requests are translated from Anthropic format to the provider's OpenAI-compatible format and streamed back
-- **Thinking tokens**: `<think>` tags and `reasoning_content` fields are converted into native Claude thinking blocks
+- **Thinking tokens**: `<think>` tags and `reasoning_content` fields are converted into native Claude thinking blocks when `ENABLE_THINKING=true`
+
+The proxy also answers Claude Code compatibility probes: `GET /v1/models`, plus `HEAD`/`OPTIONS` support on `/`, `/health`, `/v1/messages`, and `/v1/messages/count_tokens`.

 ---

@@ -447,7 +453,7 @@ Configure via `WHISPER_DEVICE` (`cpu` | `cuda` | `nvidia_nim`) and `WHISPER_MODE
 | `MODEL_SONNET` | Model for Claude Sonnet requests (falls back to `MODEL`) | `open_router/arcee-ai/trinity-large-preview:free` |
 | `MODEL_HAIKU` | Model for Claude Haiku requests (falls back to `MODEL`) | `open_router/stepfun/step-3.5-flash:free` |
 | `NVIDIA_NIM_API_KEY` | NVIDIA API key | required for NIM |
-| `NIM_ENABLE_THINKING` | Send `chat_template_kwargs` + `reasoning_budget` on NIM requests. Enable for thinking models (kimi, nemotron); leave `false` for others (e.g. Mistral) | `false` |
+| `ENABLE_THINKING` | Global switch for provider reasoning requests and Claude thinking blocks. Set `false` to hide thinking across all providers. 
| `true` | | `OPENROUTER_API_KEY` | OpenRouter API key | required for OpenRouter | | `LM_STUDIO_BASE_URL` | LM Studio server URL | `http://localhost:1234/v1` | | `LLAMACPP_BASE_URL` | llama.cpp server URL | `http://localhost:8080/v1` | diff --git a/api/dependencies.py b/api/dependencies.py index 0131ecc..40a962e 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -39,6 +39,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro http_read_timeout=settings.http_read_timeout, http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, + enable_thinking=settings.enable_thinking, ) return NvidiaNimProvider(config, nim_settings=settings.nim) if provider_type == "open_router": @@ -56,6 +57,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro http_read_timeout=settings.http_read_timeout, http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, + enable_thinking=settings.enable_thinking, ) return OpenRouterProvider(config) if provider_type == "lmstudio": @@ -68,6 +70,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro http_read_timeout=settings.http_read_timeout, http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, + enable_thinking=settings.enable_thinking, ) return LMStudioProvider(config) if provider_type == "llamacpp": @@ -80,6 +83,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro http_read_timeout=settings.http_read_timeout, http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, + enable_thinking=settings.enable_thinking, ) return LlamaCppProvider(config) logger.error( diff --git a/api/models/__init__.py b/api/models/__init__.py index 08428a9..e9a90e0 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -14,7 +14,13 @@ from .anthropic import ( TokenCountRequest, Tool, ) -from .responses import MessagesResponse, TokenCountResponse, Usage +from .responses import ( + MessagesResponse, + ModelResponse, + ModelsListResponse, + TokenCountResponse, + Usage, +) __all__ = [ "ContentBlockImage", @@ -25,6 +31,8 @@ __all__ = [ "Message", "MessagesRequest", "MessagesResponse", + "ModelResponse", + "ModelsListResponse", "Role", "SystemContent", "ThinkingConfig", diff --git a/api/models/responses.py b/api/models/responses.py index 40a7d7b..8a238ae 100644 --- a/api/models/responses.py +++ b/api/models/responses.py @@ -11,6 +11,20 @@ class TokenCountResponse(BaseModel): input_tokens: int +class ModelResponse(BaseModel): + created_at: str + display_name: str + id: str + type: Literal["model"] = "model" + + +class ModelsListResponse(BaseModel): + data: list[ModelResponse] + first_id: str | None + has_more: bool + last_id: str | None + + class Usage(BaseModel): input_tokens: int output_tokens: int diff --git a/api/routes.py b/api/routes.py index 67e1e7c..2efdb94 100644 --- a/api/routes.py +++ b/api/routes.py @@ -3,7 +3,7 @@ import traceback import uuid -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import StreamingResponse from loguru import logger @@ -13,13 +13,57 @@ from providers.exceptions import InvalidRequestError, ProviderError from .dependencies import get_provider_for_type, get_settings, require_api_key from .models.anthropic import MessagesRequest, 
TokenCountRequest -from .models.responses import TokenCountResponse +from .models.responses import ModelResponse, ModelsListResponse, TokenCountResponse from .optimization_handlers import try_optimizations from .request_utils import get_token_count router = APIRouter() +SUPPORTED_CLAUDE_MODELS = [ + ModelResponse( + id="claude-opus-4-20250514", + display_name="Claude Opus 4", + created_at="2025-05-14T00:00:00Z", + ), + ModelResponse( + id="claude-sonnet-4-20250514", + display_name="Claude Sonnet 4", + created_at="2025-05-14T00:00:00Z", + ), + ModelResponse( + id="claude-haiku-4-20250514", + display_name="Claude Haiku 4", + created_at="2025-05-14T00:00:00Z", + ), + ModelResponse( + id="claude-3-opus-20240229", + display_name="Claude 3 Opus", + created_at="2024-02-29T00:00:00Z", + ), + ModelResponse( + id="claude-3-5-sonnet-20241022", + display_name="Claude 3.5 Sonnet", + created_at="2024-10-22T00:00:00Z", + ), + ModelResponse( + id="claude-3-haiku-20240307", + display_name="Claude 3 Haiku", + created_at="2024-03-07T00:00:00Z", + ), + ModelResponse( + id="claude-3-5-haiku-20241022", + display_name="Claude 3.5 Haiku", + created_at="2024-10-22T00:00:00Z", + ), +] + + +def _probe_response(allow: str) -> Response: + """Return an empty success response for compatibility probes.""" + return Response(status_code=204, headers={"Allow": allow}) + + # ============================================================================= # Routes # ============================================================================= @@ -83,6 +127,12 @@ async def create_message( ) from e +@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"]) +async def probe_messages(_auth=Depends(require_api_key)): + """Respond to Claude compatibility probes for the messages endpoint.""" + return _probe_response("POST, HEAD, OPTIONS") + + @router.post("/v1/messages/count_tokens") async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_api_key)): """Count tokens for a request.""" @@ -112,6 +162,12 @@ async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_ap ) from e +@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"]) +async def probe_count_tokens(_auth=Depends(require_api_key)): + """Respond to Claude compatibility probes for the token count endpoint.""" + return _probe_response("POST, HEAD, OPTIONS") + + @router.get("/") async def root( settings: Settings = Depends(get_settings), _auth=Depends(require_api_key) @@ -124,12 +180,35 @@ async def root( } +@router.api_route("/", methods=["HEAD", "OPTIONS"]) +async def probe_root(_auth=Depends(require_api_key)): + """Respond to compatibility probes for the root endpoint.""" + return _probe_response("GET, HEAD, OPTIONS") + + @router.get("/health") async def health(): """Health check endpoint.""" return {"status": "healthy"} +@router.api_route("/health", methods=["HEAD", "OPTIONS"]) +async def probe_health(): + """Respond to compatibility probes for the health endpoint.""" + return _probe_response("GET, HEAD, OPTIONS") + + +@router.get("/v1/models", response_model=ModelsListResponse) +async def list_models(_auth=Depends(require_api_key)): + """List the Claude model ids this proxy advertises for compatibility.""" + return ModelsListResponse( + data=SUPPORTED_CLAUDE_MODELS, + first_id=SUPPORTED_CLAUDE_MODELS[0].id if SUPPORTED_CLAUDE_MODELS else None, + has_more=False, + last_id=SUPPORTED_CLAUDE_MODELS[-1].id if SUPPORTED_CLAUDE_MODELS else None, + ) + + @router.post("/stop") async def stop_cli(request: Request, 
_auth=Depends(require_api_key)): """Stop all CLI sessions and pending tasks.""" diff --git a/config/nim.py b/config/nim.py index ea2f9f9..8888a17 100644 --- a/config/nim.py +++ b/config/nim.py @@ -21,7 +21,6 @@ class NimSettings(BaseModel): parallel_tool_calls: bool = True ignore_eos: bool = False - enable_thinking: bool = False min_tokens: int = Field(0, ge=0) chat_template: str | None = None diff --git a/config/settings.py b/config/settings.py index b271f6a..f88b432 100644 --- a/config/settings.py +++ b/config/settings.py @@ -3,6 +3,8 @@ import os from functools import lru_cache from pathlib import Path +from collections.abc import Mapping +from typing import Any from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -21,6 +23,58 @@ def _env_files() -> tuple[Path, ...]: return tuple(files) +def _configured_env_files(model_config: Mapping[str, Any]) -> tuple[Path, ...]: + """Return the currently configured env files for Settings.""" + configured = model_config.get("env_file") + if configured is None: + return () + if isinstance(configured, (str, Path)): + return (Path(configured),) + return tuple(Path(item) for item in configured) + + +def _env_file_contains_key(path: Path, key: str) -> bool: + """Check whether a dotenv-style file defines the given key.""" + if not path.is_file(): + return False + + try: + for line in path.read_text(encoding="utf-8").splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if stripped.startswith("export "): + stripped = stripped[7:].lstrip() + name, sep, _value = stripped.partition("=") + if sep and name.strip() == key: + return True + except OSError: + return False + + return False + + +def _removed_env_var_message(model_config: Mapping[str, Any]) -> str | None: + """Return a migration error for removed env vars, if present.""" + removed_key = "NIM_ENABLE_THINKING" + replacement = "ENABLE_THINKING" + + if removed_key in os.environ: + return ( + f"{removed_key} has been removed in this release. " + f"Rename it to {replacement}." + ) + + for env_file in _configured_env_files(model_config): + if _env_file_contains_key(env_file, removed_key): + return ( + f"{removed_key} has been removed in this release. " + f"Rename it to {replacement}. Found in {env_file}." 
+ ) + + return None + + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -67,6 +121,7 @@ class Settings(BaseSettings): provider_max_concurrency: int = Field( default=5, validation_alias="PROVIDER_MAX_CONCURRENCY" ) + enable_thinking: bool = Field(default=True, validation_alias="ENABLE_THINKING") # ==================== HTTP Client Timeouts ==================== http_read_timeout: float = Field( @@ -90,9 +145,6 @@ class Settings(BaseSettings): # ==================== NIM Settings ==================== nim: NimSettings = Field(default_factory=NimSettings) - nim_enable_thinking: bool = Field( - default=False, validation_alias="NIM_ENABLE_THINKING" - ) # ==================== Voice Note Transcription ==================== voice_note_enabled: bool = Field( @@ -131,6 +183,14 @@ class Settings(BaseSettings): default="", validation_alias="ANTHROPIC_AUTH_TOKEN" ) + @model_validator(mode="before") + @classmethod + def reject_removed_env_vars(cls, data: Any) -> Any: + """Fail fast when removed environment variables are still configured.""" + if message := _removed_env_var_message(cls.model_config): + raise ValueError(message) + return data + # Handle empty strings for optional string fields @field_validator( "telegram_bot_token", @@ -140,7 +200,7 @@ class Settings(BaseSettings): mode="before", ) @classmethod - def parse_optional_str(cls, v): + def parse_optional_str(cls, v: Any) -> Any: if v == "": return None return v @@ -174,13 +234,6 @@ class Settings(BaseSettings): ) return v - @model_validator(mode="after") - def _inject_nim_thinking(self) -> Settings: - self.nim = self.nim.model_copy( - update={"enable_thinking": self.nim_enable_thinking} - ) - return self - @model_validator(mode="after") def check_nvidia_nim_api_key(self) -> Settings: if ( diff --git a/providers/base.py b/providers/base.py index 72ba86e..898f751 100644 --- a/providers/base.py +++ b/providers/base.py @@ -22,6 +22,7 @@ class ProviderConfig(BaseModel): http_read_timeout: float = 300.0 http_write_timeout: float = 10.0 http_connect_timeout: float = 2.0 + enable_thinking: bool = True class BaseProvider(ABC): @@ -30,6 +31,16 @@ class BaseProvider(ABC): def __init__(self, config: ProviderConfig): self._config = config + def _is_thinking_enabled(self, request: Any) -> bool: + """Return whether thinking should be enabled for this request.""" + thinking = getattr(request, "thinking", None) + request_enabled = ( + thinking.enabled + if thinking is not None and hasattr(thinking, "enabled") + else True + ) + return self._config.enable_thinking and request_enabled + @abstractmethod async def cleanup(self) -> None: """Release any resources held by this provider.""" diff --git a/providers/common/message_converter.py b/providers/common/message_converter.py index b07fa3b..d2b2881 100644 --- a/providers/common/message_converter.py +++ b/providers/common/message_converter.py @@ -25,6 +25,7 @@ class AnthropicToOpenAIConverter: def convert_messages( messages: list[Any], *, + include_thinking: bool = True, include_reasoning_for_openrouter: bool = False, ) -> list[dict[str, Any]]: """Convert a list of Anthropic messages to OpenAI format. 
@@ -46,6 +47,7 @@ class AnthropicToOpenAIConverter:
                 result.extend(
                     AnthropicToOpenAIConverter._convert_assistant_message(
                         content,
+                        include_thinking=include_thinking,
                         include_reasoning_for_openrouter=include_reasoning_for_openrouter,
                     )
                 )
@@ -62,6 +64,7 @@
     def _convert_assistant_message(
         content: list[Any],
         *,
+        include_thinking: bool = True,
         include_reasoning_for_openrouter: bool = False,
     ) -> list[dict[str, Any]]:
         """Convert assistant message blocks, preserving interleaved thinking+text order."""
@@ -75,6 +78,8 @@
             if block_type == "text":
                 content_parts.append(get_block_attr(block, "text", ""))
             elif block_type == "thinking":
+                if not include_thinking:
+                    continue
                 thinking = get_block_attr(block, "thinking", "")
                 content_parts.append(f"<think>\n{thinking}\n</think>")
                 if include_reasoning_for_openrouter:
@@ -184,6 +189,7 @@ def build_base_request_body(
     request_data: Any,
     *,
     default_max_tokens: int | None = None,
+    include_thinking: bool = True,
     include_reasoning_for_openrouter: bool = False,
 ) -> dict[str, Any]:
     """Build the common parts of an OpenAI-format request body.
@@ -196,6 +202,7 @@
     messages = AnthropicToOpenAIConverter.convert_messages(
         request_data.messages,
+        include_thinking=include_thinking,
         include_reasoning_for_openrouter=include_reasoning_for_openrouter,
     )

diff --git a/providers/llamacpp/client.py b/providers/llamacpp/client.py
index aa86832..8a4ee82 100644
--- a/providers/llamacpp/client.py
+++ b/providers/llamacpp/client.py
@@ -56,6 +56,7 @@ class LlamaCppProvider(BaseProvider):
         """Stream response natively via Llama.cpp's Anthropic-compatible endpoint."""
         tag = self._provider_name
         req_tag = f" request_id={request_id}" if request_id else ""
+        thinking_enabled = self._is_thinking_enabled(request)

         # Dump the Anthropic Pydantic model directly into a dict
         body = request.model_dump(exclude_none=True)
@@ -68,7 +69,11 @@
         # Translate internal ThinkingConfig to Anthropic API schema
         if "thinking" in body:
             thinking_cfg = body.pop("thinking")
-            if isinstance(thinking_cfg, dict) and thinking_cfg.get("enabled"):
+            if (
+                thinking_enabled
+                and isinstance(thinking_cfg, dict)
+                and thinking_cfg.get("enabled")
+            ):
                 # Anthropic API requires a budget_tokens value when enabled
                 body["thinking"] = {"type": "enabled"}
@@ -126,9 +131,15 @@
         except Exception as e:
             logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
             mapped_e = map_error(e)
-            error_message = get_user_facing_error_message(
-                mapped_e, read_timeout_s=self._config.http_read_timeout
-            )
+            if getattr(mapped_e, "status_code", None) == 405:
+                error_message = (
+                    f"Upstream provider {tag} rejected the request method "
+                    "or endpoint (HTTP 405)." 
+ ) + else: + error_message = get_user_facing_error_message( + mapped_e, read_timeout_s=self._config.http_read_timeout + ) if request_id: error_message += f"\nRequest ID: {request_id}" diff --git a/providers/lmstudio/client.py b/providers/lmstudio/client.py index 7b5dbdf..88ed2c2 100644 --- a/providers/lmstudio/client.py +++ b/providers/lmstudio/client.py @@ -56,6 +56,7 @@ class LMStudioProvider(BaseProvider): """Stream response natively via LM Studio's Anthropic-compatible endpoint.""" tag = self._provider_name req_tag = f" request_id={request_id}" if request_id else "" + thinking_enabled = self._is_thinking_enabled(request) # Dump the Anthropic Pydantic model directly into a dict body = request.model_dump(exclude_none=True) @@ -68,7 +69,11 @@ class LMStudioProvider(BaseProvider): # Translate internal ThinkingConfig to Anthropic API schema if "thinking" in body: thinking_cfg = body.pop("thinking") - if isinstance(thinking_cfg, dict) and thinking_cfg.get("enabled"): + if ( + thinking_enabled + and isinstance(thinking_cfg, dict) + and thinking_cfg.get("enabled") + ): # Anthropic API requires a budget_tokens value when enabled body["thinking"] = {"type": "enabled"} @@ -126,9 +131,15 @@ class LMStudioProvider(BaseProvider): except Exception as e: logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e) mapped_e = map_error(e) - error_message = get_user_facing_error_message( - mapped_e, read_timeout_s=self._config.http_read_timeout - ) + if getattr(mapped_e, "status_code", None) == 405: + error_message = ( + f"Upstream provider {tag} rejected the request method " + "or endpoint (HTTP 405)." + ) + else: + error_message = get_user_facing_error_message( + mapped_e, read_timeout_s=self._config.http_read_timeout + ) if request_id: error_message += f"\nRequest ID: {request_id}" diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index cd57863..47d0300 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -25,4 +25,8 @@ class NvidiaNimProvider(OpenAICompatibleProvider): def _build_request_body(self, request: Any) -> dict: """Internal helper for tests and shared building.""" - return build_request_body(request, self._nim_settings) + return build_request_body( + request, + self._nim_settings, + thinking_enabled=self._is_thinking_enabled(request), + ) diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index 067dfd9..c82719e 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -21,14 +21,19 @@ def _set_extra( extra_body[key] = value -def build_request_body(request_data: Any, nim: NimSettings) -> dict: +def build_request_body( + request_data: Any, nim: NimSettings, *, thinking_enabled: bool +) -> dict: """Build OpenAI-format request body from Anthropic request.""" logger.debug( "NIM_REQUEST: conversion start model={} msgs={}", getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - body = build_base_request_body(request_data) + body = build_base_request_body( + request_data, + include_thinking=thinking_enabled, + ) # NIM-specific max_tokens: cap against nim.max_tokens max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None) @@ -63,7 +68,7 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict: if request_extra: extra_body.update(request_extra) - if nim.enable_thinking: + if thinking_enabled: extra_body.setdefault( "chat_template_kwargs", {"thinking": True, "enable_thinking": True} ) diff --git 
a/providers/open_router/client.py b/providers/open_router/client.py index a55d18a..af04f22 100644 --- a/providers/open_router/client.py +++ b/providers/open_router/client.py @@ -25,10 +25,17 @@ class OpenRouterProvider(OpenAICompatibleProvider): def _build_request_body(self, request: Any) -> dict: """Internal helper for tests and shared building.""" - return build_request_body(request) + return build_request_body( + request, + thinking_enabled=self._is_thinking_enabled(request), + ) - def _handle_extra_reasoning(self, delta: Any, sse: SSEBuilder) -> Iterator[str]: + def _handle_extra_reasoning( + self, delta: Any, sse: SSEBuilder, *, thinking_enabled: bool + ) -> Iterator[str]: """Handle reasoning_details for StepFun models.""" + if not thinking_enabled: + return reasoning_details = getattr(delta, "reasoning_details", None) if reasoning_details and isinstance(reasoning_details, list): for item in reasoning_details: diff --git a/providers/open_router/request.py b/providers/open_router/request.py index bbd2ad0..c524a00 100644 --- a/providers/open_router/request.py +++ b/providers/open_router/request.py @@ -9,7 +9,7 @@ from providers.common.message_converter import build_base_request_body OPENROUTER_DEFAULT_MAX_TOKENS = 81920 -def build_request_body(request_data: Any) -> dict: +def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict: """Build OpenAI-format request body from Anthropic request for OpenRouter.""" logger.debug( "OPENROUTER_REQUEST: conversion start model={} msgs={}", @@ -18,8 +18,9 @@ def build_request_body(request_data: Any) -> dict: ) body = build_base_request_body( request_data, + include_thinking=thinking_enabled, default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS, - include_reasoning_for_openrouter=True, + include_reasoning_for_openrouter=thinking_enabled, ) # OpenRouter reasoning: extra_body={"reasoning": {"enabled": True}} @@ -28,10 +29,6 @@ def build_request_body(request_data: Any) -> dict: if request_extra: extra_body.update(request_extra) - thinking = getattr(request_data, "thinking", None) - thinking_enabled = ( - thinking.enabled if thinking and hasattr(thinking, "enabled") else True - ) if thinking_enabled: extra_body.setdefault("reasoning", {"enabled": True}) diff --git a/providers/openai_compat.py b/providers/openai_compat.py index fcc6ba9..8424481 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -66,7 +66,9 @@ class OpenAICompatibleProvider(BaseProvider): def _build_request_body(self, request: Any) -> dict: """Build request body. Must be implemented by subclasses.""" - def _handle_extra_reasoning(self, delta: Any, sse: SSEBuilder) -> Iterator[str]: + def _handle_extra_reasoning( + self, delta: Any, sse: SSEBuilder, *, thinking_enabled: bool + ) -> Iterator[str]: """Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details).""" return iter(()) @@ -151,6 +153,7 @@ class OpenAICompatibleProvider(BaseProvider): think_parser = ThinkTagParser() heuristic_parser = HeuristicToolParser() + thinking_enabled = self._is_thinking_enabled(request) finish_reason = None usage_info = None @@ -180,19 +183,25 @@ class OpenAICompatibleProvider(BaseProvider): # Handle reasoning_content (OpenAI extended format) reasoning = getattr(delta, "reasoning_content", None) - if reasoning: + if thinking_enabled and reasoning: for event in sse.ensure_thinking_block(): yield event yield sse.emit_thinking_delta(reasoning) # Provider-specific extra reasoning (e.g. 
OpenRouter reasoning_details) - for event in self._handle_extra_reasoning(delta, sse): + for event in self._handle_extra_reasoning( + delta, + sse, + thinking_enabled=thinking_enabled, + ): yield event # Handle text content if delta.content: for part in think_parser.feed(delta.content): if part.type == ContentType.THINKING: + if not thinking_enabled: + continue for event in sse.ensure_thinking_block(): yield event yield sse.emit_thinking_delta(part.content) @@ -248,12 +257,16 @@ class OpenAICompatibleProvider(BaseProvider): logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e) mapped_e = map_error(e) error_occurred = True - error_message = append_request_id( - get_user_facing_error_message( + if getattr(mapped_e, "status_code", None) == 405: + base_message = ( + f"Upstream provider {tag} rejected the request method " + "or endpoint (HTTP 405)." + ) + else: + base_message = get_user_facing_error_message( mapped_e, read_timeout_s=self._config.http_read_timeout - ), - request_id, - ) + ) + error_message = append_request_id(base_message, request_id) logger.info( "{}_STREAM: Emitting SSE error event for {}{}", tag, @@ -269,10 +282,13 @@ class OpenAICompatibleProvider(BaseProvider): remaining = think_parser.flush() if remaining: if remaining.type == ContentType.THINKING: - for event in sse.ensure_thinking_block(): - yield event - yield sse.emit_thinking_delta(remaining.content) - else: + if not thinking_enabled: + remaining = None + else: + for event in sse.ensure_thinking_block(): + yield event + yield sse.emit_thinking_delta(remaining.content) + if remaining and remaining.type == ContentType.TEXT: for event in sse.ensure_text_block(): yield event yield sse.emit_text_delta(remaining.content) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 3b3305a..9f83a64 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -40,6 +40,34 @@ def test_health(): assert response.json()["status"] == "healthy" +def test_models_list(): + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["has_more"] is False + ids = [item["id"] for item in data["data"]] + assert "claude-sonnet-4-20250514" in ids + assert data["first_id"] == ids[0] + assert data["last_id"] == ids[-1] + + +def test_probe_endpoints_return_204_with_allow_headers(): + responses = [ + client.head("/"), + client.options("/"), + client.head("/health"), + client.options("/health"), + client.head("/v1/messages"), + client.options("/v1/messages"), + client.head("/v1/messages/count_tokens"), + client.options("/v1/messages/count_tokens"), + ] + + for response in responses: + assert response.status_code == 204 + assert "Allow" in response.headers + + def test_create_message_stream(): """Create message returns streaming response.""" payload = { diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 02e6ca3..e661873 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -55,3 +55,19 @@ def test_anthropic_auth_token_accepts_bearer_authorization(): assert r.json()["input_tokens"] == 2 app.dependency_overrides.clear() + + +def test_anthropic_auth_token_applies_to_models_endpoint(): + client = TestClient(app) + settings = Settings() + settings.anthropic_auth_token = "models-token" + app.dependency_overrides[get_settings] = lambda: settings + + r = client.get("/v1/models") + assert r.status_code == 401 + + r = client.get("/v1/models", headers={"X-API-Key": "models-token"}) + assert r.status_code == 200 + assert "data" in r.json() + + 
app.dependency_overrides.clear() diff --git a/tests/config/test_config.py b/tests/config/test_config.py index cf9e751..35189c9 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -25,6 +25,7 @@ class TestSettings: assert isinstance(settings.provider_rate_window, int) assert isinstance(settings.nim.temperature, float) assert isinstance(settings.fast_prefix_detection, bool) + assert isinstance(settings.enable_thinking, bool) def test_get_settings_cached(self): """Test get_settings returns cached instance.""" @@ -104,6 +105,22 @@ class TestSettings: settings = Settings() assert settings.http_connect_timeout == 5.0 + def test_enable_thinking_from_env(self, monkeypatch): + """ENABLE_THINKING env var is loaded into settings.""" + from config.settings import Settings + + monkeypatch.setenv("ENABLE_THINKING", "false") + settings = Settings() + assert settings.enable_thinking is False + + def test_removed_nim_enable_thinking_raises(self, monkeypatch): + """NIM_ENABLE_THINKING now fails fast with a migration message.""" + from config.settings import Settings + + monkeypatch.setenv("NIM_ENABLE_THINKING", "false") + with pytest.raises(ValidationError, match="Rename it to ENABLE_THINKING"): + Settings() + # --- NimSettings Validation Tests --- class TestNimSettingsValidBounds: @@ -228,6 +245,13 @@ class TestNimSettingsValidators: with pytest.raises(ValidationError): NimSettings(**cast(Any, {"unknown_field": "value"})) + def test_enable_thinking_field_removed(self): + """NimSettings no longer accepts the removed thinking toggle.""" + from typing import Any, cast + + with pytest.raises(ValidationError): + NimSettings(**cast(Any, {"enable_thinking": True})) + class TestSettingsOptionalStr: """Test Settings parse_optional_str validator.""" diff --git a/tests/conftest.py b/tests/conftest.py index b0928f5..1967be2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,13 @@ import asyncio import contextlib import logging import os +from typing import Any +from unittest.mock import AsyncMock, MagicMock import pytest +from config.settings import Settings + # Set mock environment BEFORE any imports that use Settings os.environ.setdefault("NVIDIA_NIM_API_KEY", "test_key") os.environ.setdefault("MODEL", "nvidia_nim/test-model") @@ -13,26 +17,12 @@ os.environ["PTB_TIMEDELTA"] = "1" # (tests expect endpoints to be unauthenticated by default) os.environ["ANTHROPIC_AUTH_TOKEN"] = "" -from typing import Any -from unittest.mock import AsyncMock, MagicMock - -from config.nim import NimSettings -from messaging.models import IncomingMessage -from messaging.platforms.base import ( - CLISession, - MessagingPlatform, - SessionManagerInterface, -) -from messaging.session import SessionStore -from providers.base import ProviderConfig -from providers.nvidia_nim import NvidiaNimProvider +Settings.model_config = {**Settings.model_config, "env_file": None} @pytest.fixture(autouse=True) def _isolate_from_dotenv(monkeypatch): """Prevent Pydantic BaseSettings from reading the .env file during tests.""" - from config.settings import Settings - monkeypatch.setattr( Settings, "model_config", {**Settings.model_config, "env_file": None} ) @@ -40,6 +30,8 @@ def _isolate_from_dotenv(monkeypatch): @pytest.fixture def provider_config(): + from providers.base import ProviderConfig + return ProviderConfig( api_key="test_key", base_url="https://test.api.nvidia.com/v1", @@ -50,6 +42,9 @@ def provider_config(): @pytest.fixture def nim_provider(provider_config): + from config.nim import NimSettings + from 
providers.nvidia_nim import NvidiaNimProvider
+
     return NvidiaNimProvider(provider_config, nim_settings=NimSettings())


@@ -62,6 +57,7 @@ def open_router_provider(provider_config):

 @pytest.fixture
 def lmstudio_provider(provider_config):
+    from providers.base import ProviderConfig
     from providers.lmstudio import LMStudioProvider

     lmstudio_config = ProviderConfig(
@@ -75,6 +71,7 @@

 @pytest.fixture
 def llamacpp_provider(provider_config):
+    from providers.base import ProviderConfig
     from providers.llamacpp import LlamaCppProvider

     llamacpp_config = ProviderConfig(
@@ -88,6 +85,8 @@

 @pytest.fixture
 def mock_cli_session():
+    from messaging.platforms.base import CLISession
+
     session = MagicMock(spec=CLISession)
     session.start_task = MagicMock()  # This will return an async generator
     session.is_busy = False
@@ -96,6 +95,8 @@

 @pytest.fixture
 def mock_cli_manager():
+    from messaging.platforms.base import SessionManagerInterface
+
     manager = MagicMock(spec=SessionManagerInterface)
     manager.get_or_create_session = AsyncMock()
     manager.register_real_session_id = AsyncMock(return_value=True)
@@ -107,6 +108,8 @@

 @pytest.fixture
 def mock_platform():
+    from messaging.platforms.base import MessagingPlatform
+
     platform = MagicMock(spec=MessagingPlatform)
     platform.send_message = AsyncMock(return_value="msg_123")
     platform.edit_message = AsyncMock()
@@ -127,6 +130,8 @@

 @pytest.fixture
 def mock_session_store():
+    from messaging.session import SessionStore
+
     store = MagicMock(spec=SessionStore)
     store.save_tree = MagicMock()
     store.get_tree = MagicMock(return_value=None)
@@ -156,6 +161,8 @@
     )

     def _create(**kwargs):
+        from messaging.models import IncomingMessage
+
         defaults: dict[str, Any] = {
             "text": "hello",
             "chat_id": "chat_1",

diff --git a/tests/providers/test_converter.py b/tests/providers/test_converter.py
index 9fa3385..701fe3f 100644
--- a/tests/providers/test_converter.py
+++ b/tests/providers/test_converter.py
@@ -199,6 +199,24 @@ def test_convert_assistant_message_thinking_include_reasoning_for_openrouter():
     assert "<think>" in result[0]["content"]


+def test_convert_assistant_message_thinking_removed_when_disabled():
+    content = [
+        MockBlock(type="thinking", thinking="I need to calculate this."),
+        MockBlock(type="text", text="The answer is 4."),
+    ]
+    messages = [MockMessage("assistant", content)]
+    result = AnthropicToOpenAIConverter.convert_messages(
+        messages,
+        include_thinking=False,
+        include_reasoning_for_openrouter=True,
+    )
+
+    assert len(result) == 1
+    assert "reasoning_content" not in result[0]
+    assert "<think>" not in result[0]["content"]
+    assert result[0]["content"] == "The answer is 4." 
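For reference, the flag can be exercised outside the test suite too; a minimal sketch where `SimpleNamespace` objects stand in for the Anthropic Pydantic models (the test's `MockBlock`/`MockMessage` play the same role):

```python
from types import SimpleNamespace as NS

from providers.common.message_converter import AnthropicToOpenAIConverter

msg = NS(
    role="assistant",
    content=[
        NS(type="thinking", thinking="work it out step by step"),
        NS(type="text", text="The answer is 4."),
    ],
)

out = AnthropicToOpenAIConverter.convert_messages([msg], include_thinking=False)
assert "<think>" not in out[0]["content"]  # thinking block dropped entirely
assert "The answer is 4." in out[0]["content"]
```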
+ + def test_convert_assistant_message_tool_use(): content = [ MockBlock(type="text", text="I will call the tool."), diff --git a/tests/providers/test_llamacpp.py b/tests/providers/test_llamacpp.py index 97463b8..e1a9f05 100644 --- a/tests/providers/test_llamacpp.py +++ b/tests/providers/test_llamacpp.py @@ -110,6 +110,37 @@ def test_init_base_url_strips_trailing_slash(): assert provider._base_url == "http://localhost:8080/v1" +@pytest.mark.asyncio +async def test_stream_response_omits_thinking_when_globally_disabled(llamacpp_config): + provider = LlamaCppProvider( + llamacpp_config.model_copy(update={"enable_thinking": False}) + ) + req = MockRequest() + + mock_response = MagicMock() + mock_response.status_code = 200 + + async def empty_aiter(): + if False: + yield "" + + mock_response.aiter_lines = empty_aiter + + with ( + patch.object(provider._client, "build_request") as mock_build, + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + [e async for e in provider.stream_response(req)] + + _, kwargs = mock_build.call_args + assert "thinking" not in kwargs["json"] + + @pytest.mark.asyncio async def test_stream_response(llamacpp_provider): """Test streaming native Anthropic response.""" @@ -254,3 +285,38 @@ async def test_stream_network_error(llamacpp_provider): assert events[0].startswith("event: error\ndata: {") assert "Connection refused" in events[0] assert "TEST_ID2" in events[0] + + +@pytest.mark.asyncio +async def test_stream_error_405_mentions_upstream_provider(llamacpp_provider): + req = MockRequest() + + mock_response = MagicMock() + mock_response.status_code = 405 + mock_response.aread = AsyncMock(return_value=b"Method Not Allowed") + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "Method Not Allowed", request=MagicMock(), response=mock_response + ) + ) + + with ( + patch.object( + llamacpp_provider._client, "build_request", return_value=MagicMock() + ), + patch.object( + llamacpp_provider._client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + events = [ + e async for e in llamacpp_provider.stream_response(req, request_id="REQ405") + ] + + assert ( + "Upstream provider LLAMACPP rejected the request method or endpoint (HTTP 405)." 
+ in events[0] + ) + assert "REQ405" in events[0] diff --git a/tests/providers/test_lmstudio.py b/tests/providers/test_lmstudio.py index 201e591..1173846 100644 --- a/tests/providers/test_lmstudio.py +++ b/tests/providers/test_lmstudio.py @@ -110,6 +110,68 @@ def test_init_base_url_strips_trailing_slash(): assert provider._base_url == "http://localhost:1234/v1" +@pytest.mark.asyncio +async def test_stream_response_omits_thinking_when_globally_disabled(lmstudio_config): + provider = LMStudioProvider( + lmstudio_config.model_copy(update={"enable_thinking": False}) + ) + req = MockRequest() + + mock_response = MagicMock() + mock_response.status_code = 200 + + async def empty_aiter(): + if False: + yield "" + + mock_response.aiter_lines = empty_aiter + + with ( + patch.object(provider._client, "build_request") as mock_build, + patch.object( + provider._client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + [e async for e in provider.stream_response(req)] + + _, kwargs = mock_build.call_args + assert "thinking" not in kwargs["json"] + + +@pytest.mark.asyncio +async def test_stream_response_omits_thinking_when_request_disables_it( + lmstudio_provider, +): + req = MockRequest() + req.thinking.enabled = False + + mock_response = MagicMock() + mock_response.status_code = 200 + + async def empty_aiter(): + if False: + yield "" + + mock_response.aiter_lines = empty_aiter + + with ( + patch.object(lmstudio_provider._client, "build_request") as mock_build, + patch.object( + lmstudio_provider._client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + [e async for e in lmstudio_provider.stream_response(req)] + + _, kwargs = mock_build.call_args + assert "thinking" not in kwargs["json"] + + @pytest.mark.asyncio async def test_stream_response(lmstudio_provider): """Test streaming native Anthropic response.""" @@ -254,3 +316,38 @@ async def test_stream_network_error(lmstudio_provider): assert events[0].startswith("event: error\ndata: {") assert "Connection refused" in events[0] assert "TEST_ID2" in events[0] + + +@pytest.mark.asyncio +async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider): + req = MockRequest() + + mock_response = MagicMock() + mock_response.status_code = 405 + mock_response.aread = AsyncMock(return_value=b"Method Not Allowed") + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "Method Not Allowed", request=MagicMock(), response=mock_response + ) + ) + + with ( + patch.object( + lmstudio_provider._client, "build_request", return_value=MagicMock() + ), + patch.object( + lmstudio_provider._client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + events = [ + e async for e in lmstudio_provider.stream_response(req, request_id="REQ405") + ] + + assert ( + "Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)." 
+        in events[0]
+    )
+    assert "REQ405" in events[0]

diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py
index a4a136e..3f15095 100644
--- a/tests/providers/test_nvidia_nim.py
+++ b/tests/providers/test_nvidia_nim.py
@@ -91,9 +91,7 @@ async def test_build_request_body(provider_config):
     """Test request body construction."""
     from config.nim import NimSettings

-    provider = NvidiaNimProvider(
-        provider_config, nim_settings=NimSettings(enable_thinking=True)
-    )
+    provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
     req = MockRequest()
     body = provider._build_request_body(req)

@@ -110,6 +108,40 @@
     assert body["extra_body"]["reasoning_budget"] == body["max_tokens"]


+@pytest.mark.asyncio
+async def test_build_request_body_omits_reasoning_when_globally_disabled(
+    provider_config,
+):
+    from config.nim import NimSettings
+
+    provider = NvidiaNimProvider(
+        provider_config.model_copy(update={"enable_thinking": False}),
+        nim_settings=NimSettings(),
+    )
+    req = MockRequest()
+    body = provider._build_request_body(req)
+
+    extra = body.get("extra_body", {})
+    assert "chat_template_kwargs" not in extra
+    assert "reasoning_budget" not in extra
+
+
+@pytest.mark.asyncio
+async def test_build_request_body_omits_reasoning_when_request_disables_thinking(
+    provider_config,
+):
+    from config.nim import NimSettings
+
+    provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
+    req = MockRequest()
+    req.thinking.enabled = False
+    body = provider._build_request_body(req)
+
+    extra = body.get("extra_body", {})
+    assert "chat_template_kwargs" not in extra
+    assert "reasoning_budget" not in extra
+
+
 @pytest.mark.asyncio
 async def test_stream_response_text(nim_provider):
     """Test streaming text response."""
@@ -195,6 +227,44 @@ async def test_stream_response_thinking_reasoning_content(nim_provider):
     assert found_thinking


+@pytest.mark.asyncio
+async def test_stream_response_suppresses_thinking_when_disabled(provider_config):
+    from config.nim import NimSettings
+
+    provider = NvidiaNimProvider(
+        provider_config.model_copy(update={"enable_thinking": False}),
+        nim_settings=NimSettings(),
+    )
+    req = MockRequest()
+
+    mock_chunk = MagicMock()
+    mock_chunk.choices = [
+        MagicMock(
+            delta=MagicMock(
+                content="<think>secret</think>Answer", reasoning_content="Thinking..."
+            ),
+            finish_reason="stop",
+        )
+    ]
+    mock_chunk.usage = None
+
+    async def mock_stream():
+        yield mock_chunk
+
+    with patch.object(
+        provider._client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_stream()
+
+        events = [e async for e in provider.stream_response(req)]
+
+    event_text = "".join(events)
+    assert "thinking_delta" not in event_text
+    assert "Thinking..." not in event_text
+    assert "secret" not in event_text
+    assert "Answer" in event_text
+
+
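The request-body effect of the new parameter can also be seen in isolation; a sketch where `req` is a hypothetical minimal request mirroring the fields `build_base_request_body` reads (the exact field list is an assumption based on the mocks above):

```python
from types import SimpleNamespace as NS

from config.nim import NimSettings
from providers.nvidia_nim.request import build_request_body

# Hypothetical minimal request object; field names assumed from the test mocks.
req = NS(
    model="test-model",
    messages=[NS(role="user", content="hi")],
    max_tokens=128,
    system=None, temperature=None, top_p=None, top_k=None,
    stop_sequences=None, tools=None, tool_choice=None,
    extra_body=None, thinking=None,
)

# The toggle decides whether the NIM thinking knobs are attached at all.
assert "chat_template_kwargs" in build_request_body(
    req, NimSettings(), thinking_enabled=True
)["extra_body"]
assert "chat_template_kwargs" not in build_request_body(
    req, NimSettings(), thinking_enabled=False
).get("extra_body", {})
```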
 @pytest.mark.asyncio
 async def test_tool_call_stream(nim_provider):
     """Test streaming tool calls."""

diff --git a/tests/providers/test_nvidia_nim_request.py b/tests/providers/test_nvidia_nim_request.py
index 8545423..b7508d6 100644
--- a/tests/providers/test_nvidia_nim_request.py
+++ b/tests/providers/test_nvidia_nim_request.py
@@ -67,21 +67,21 @@ class TestBuildRequestBody:
     def test_max_tokens_capped_by_nim(self, req):
         req.max_tokens = 100000
         nim = NimSettings(max_tokens=4096)
-        body = build_request_body(req, nim)
+        body = build_request_body(req, nim, thinking_enabled=True)
         assert body["max_tokens"] == 4096

     def test_presence_penalty_included_when_nonzero(self, req):
         nim = NimSettings(presence_penalty=0.5)
-        body = build_request_body(req, nim)
+        body = build_request_body(req, nim, thinking_enabled=True)
         assert body["presence_penalty"] == 0.5

     def test_include_stop_str_in_output_not_sent(self, req):
-        body = build_request_body(req, NimSettings())
+        body = build_request_body(req, NimSettings(), thinking_enabled=True)
         assert "include_stop_str_in_output" not in body.get("extra_body", {})

     def test_parallel_tool_calls_included(self, req):
         nim = NimSettings(parallel_tool_calls=False)
-        body = build_request_body(req, nim)
+        body = build_request_body(req, nim, thinking_enabled=True)
         assert body["parallel_tool_calls"] is False

     def test_reasoning_params_in_extra_body(self):
@@ -98,8 +98,8 @@
         req.extra_body = None
         req.top_k = None

-        nim = NimSettings(enable_thinking=True)
-        body = build_request_body(req, nim)
+        nim = NimSettings()
+        body = build_request_body(req, nim, thinking_enabled=True)
         extra = body["extra_body"]
         assert extra["chat_template_kwargs"] == {
             "thinking": True,
@@ -121,8 +121,8 @@
         req.extra_body = None
         req.top_k = None

-        nim = NimSettings(enable_thinking=False)
-        body = build_request_body(req, nim)
+        nim = NimSettings()
+        body = build_request_body(req, nim, thinking_enabled=False)
         extra = body.get("extra_body", {})
         assert "chat_template_kwargs" not in extra
         assert "reasoning_budget" not in extra
@@ -142,7 +142,7 @@
         req.top_k = None

         nim = NimSettings()
-        body = build_request_body(req, nim)
+        body = build_request_body(req, nim, thinking_enabled=False)
         extra = body.get("extra_body", {})
         for param in (
             "thinking",
             "reasoning_budget",
             "reasoning_effort",
         ):
             assert param not in extra
+
+    def test_assistant_thinking_blocks_removed_when_disabled(self):
+        req = MagicMock()
+        req.model = "test"
+        req.messages = [
+            MagicMock(
+                role="assistant",
+                content=[
+                    MagicMock(type="thinking", thinking="secret"),
+                    MagicMock(type="text", text="answer"),
+                ],
+            )
+        ]
+        req.max_tokens = 100
+        req.system = None
+        req.temperature = None
+        req.top_p = None
+        req.stop_sequences = None
+        req.tools = None
+        req.tool_choice = None
+        req.extra_body = None
+        req.top_k = None
+
+        body = build_request_body(req, NimSettings(), thinking_enabled=False)
+        assert "<think>" not in body["messages"][0]["content"]
+        assert "answer" in body["messages"][0]["content"]

diff --git a/tests/providers/test_open_router.py b/tests/providers/test_open_router.py
index 0f801a7..fbcfd47 100644
--- a/tests/providers/test_open_router.py
+++ b/tests/providers/test_open_router.py
@@ -105,6 +105,26 @@ def test_build_request_body_has_reasoning_extra(open_router_provider):
     assert body["extra_body"]["reasoning"]["enabled"] is True
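Same shape for OpenRouter, where the toggle now also gates the `reasoning` extra; a sketch under the same hypothetical-request assumption as the NIM sketch above:

```python
from providers.open_router.request import build_request_body

# `req` as in the NIM sketch above (hypothetical minimal request object).
body = build_request_body(req, thinking_enabled=True)
assert body["extra_body"]["reasoning"] == {"enabled": True}

body = build_request_body(req, thinking_enabled=False)
assert "reasoning" not in body.get("extra_body", {})
```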
+
+def test_build_request_body_omits_reasoning_when_globally_disabled(open_router_config):
+    provider = OpenRouterProvider(
+        open_router_config.model_copy(update={"enable_thinking": False})
+    )
+    req = MockRequest()
+    body = provider._build_request_body(req)
+
+    assert "extra_body" not in body or "reasoning" not in body["extra_body"]
+
+
+def test_build_request_body_omits_reasoning_when_request_disables_thinking(
+    open_router_provider,
+):
+    req = MockRequest()
+    req.thinking.enabled = False
+    body = open_router_provider._build_request_body(req)
+
+    assert "extra_body" not in body or "reasoning" not in body["extra_body"]
+
+
 def test_build_request_body_base_url_and_model(open_router_provider):
     """Base URL and model are correct in provider config."""
     assert open_router_provider._base_url == "https://openrouter.ai/api/v1"
@@ -205,6 +225,44 @@ async def test_stream_response_reasoning_content(open_router_provider):
     assert found_thinking


+@pytest.mark.asyncio
+async def test_stream_response_suppresses_reasoning_when_disabled(open_router_config):
+    provider = OpenRouterProvider(
+        open_router_config.model_copy(update={"enable_thinking": False})
+    )
+    req = MockRequest()
+
+    mock_chunk = MagicMock()
+    mock_chunk.choices = [
+        MagicMock(
+            delta=MagicMock(
+                content="<think>secret</think>Answer",
+                reasoning_content="Thinking...",
+                reasoning_details=[{"text": "Step 1"}],
+            ),
+            finish_reason="stop",
+        )
+    ]
+    mock_chunk.usage = None
+
+    async def mock_stream():
+        yield mock_chunk
+
+    with patch.object(
+        provider._client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_stream()
+
+        events = [e async for e in provider.stream_response(req)]
+
+    event_text = "".join(events)
+    assert "thinking_delta" not in event_text
+    assert "Thinking..." not in event_text
+    assert "Step 1" not in event_text
+    assert "secret" not in event_text
+    assert "Answer" in event_text
+
+
 @pytest.mark.asyncio
 async def test_stream_response_empty_choices_skipped(open_router_provider):
     """Chunks with empty choices are skipped."""

diff --git a/tests/providers/test_streaming_errors.py b/tests/providers/test_streaming_errors.py
index 21fce7c..04e9ee9 100644
--- a/tests/providers/test_streaming_errors.py
+++ b/tests/providers/test_streaming_errors.py
@@ -39,6 +39,18 @@ def _make_provider():
     return NvidiaNimProvider(config, nim_settings=NimSettings())


+def _make_provider_with_thinking_enabled(enabled: bool):
+    """Create a provider instance with thinking explicitly enabled or disabled."""
+    config = ProviderConfig(
+        api_key="test_key",
+        base_url="https://test.api.nvidia.com/v1",
+        rate_limit=10,
+        rate_window=60,
+        enable_thinking=enabled,
+    )
+    return NvidiaNimProvider(config, nim_settings=NimSettings())
+
+
 def _make_request(model="test-model", stream=True):
     """Create a mock request with all fields build_request_body needs."""
     req = MagicMock()
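The helper pins the config side of the switch; the per-request side still participates, since `_is_thinking_enabled` ANDs the two. A sketch of the precedence, reusing the helper above with stand-in request objects:

```python
from types import SimpleNamespace as NS

on = NS(thinking=NS(enabled=True))
off = NS(thinking=NS(enabled=False))

provider = _make_provider_with_thinking_enabled(True)
assert provider._is_thinking_enabled(on) is True
assert provider._is_thinking_enabled(off) is False

# The global switch wins regardless of what the request asks for.
assert _make_provider_with_thinking_enabled(False)._is_thinking_enabled(on) is False
```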
@@ -272,6 +284,76 @@ class TestStreamingExceptionHandling:
         assert "I think..." in event_text
         assert "The answer" in event_text

+    @pytest.mark.asyncio
+    async def test_stream_with_reasoning_content_suppressed_when_disabled(self):
+        """Reasoning deltas are stripped while normal text still streams."""
+        provider = _make_provider_with_thinking_enabled(False)
+        request = _make_request()
+
+        chunk1 = _make_chunk(reasoning_content="I think...")
+        chunk2 = _make_chunk(content="<think>secret</think>The answer")
+        chunk3 = _make_chunk(finish_reason="stop")
+        stream_mock = AsyncStreamMock([chunk1, chunk2, chunk3])
+
+        with (
+            patch.object(
+                provider._client.chat.completions,
+                "create",
+                new_callable=AsyncMock,
+                return_value=stream_mock,
+            ),
+            patch.object(
+                provider._global_rate_limiter,
+                "wait_if_blocked",
+                new_callable=AsyncMock,
+                return_value=False,
+            ),
+        ):
+            events = await _collect_stream(provider, request)
+
+        event_text = "".join(events)
+        assert "thinking_delta" not in event_text
+        assert "I think..." not in event_text
+        assert "secret" not in event_text
+        assert "The answer" in event_text
+
+    @pytest.mark.asyncio
+    async def test_stream_with_upstream_405_mentions_provider_name(self):
+        """HTTP 405s are surfaced as upstream method/endpoint rejections."""
+        provider = _make_provider()
+        request = _make_request()
+
+        response = httpx.Response(
+            status_code=405,
+            request=httpx.Request("POST", "https://example.com/v1/chat/completions"),
+        )
+        error = httpx.HTTPStatusError(
+            "Method Not Allowed",
+            request=response.request,
+            response=response,
+        )
+
+        with patch.object(
+            provider._client.chat.completions,
+            "create",
+            new_callable=AsyncMock,
+            side_effect=error,
+        ):
+            events = [
+                e
+                async for e in provider.stream_response(
+                    request,
+                    request_id="REQ405",
+                )
+            ]
+
+        event_text = "".join(events)
+        assert (
+            "Upstream provider NIM rejected the request method or endpoint (HTTP 405)."
+            in event_text
+        )
+        assert "request_id=REQ405" in event_text
+
     @pytest.mark.asyncio
     async def test_stream_rate_limited_retries_via_execute_with_retry(self):
         """When rate limited, execute_with_retry handles retries transparently."""
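Taken together, the new compatibility surface can be smoke-tested the same way `tests/api/test_api.py` does. A minimal sketch, assuming the FastAPI app is importable as `server:app` (the module path the README's uvicorn command uses):

```python
from fastapi.testclient import TestClient

from server import app  # assumption: app module per `uvicorn server:app`

client = TestClient(app)

models = client.get("/v1/models").json()
assert models["has_more"] is False
assert any(m["id"] == "claude-sonnet-4-20250514" for m in models["data"])

probe = client.options("/v1/messages")
assert probe.status_code == 204
assert probe.headers["Allow"] == "POST, HEAD, OPTIONS"
```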