mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Fixes for issue 113 and 116
This commit is contained in:
parent
7468f53ab7
commit
835d0454e8
28 changed files with 807 additions and 83 deletions
|
|
@ -40,6 +40,34 @@ def test_health():
|
|||
assert response.json()["status"] == "healthy"
|
||||
|
||||
|
||||
def test_models_list():
|
||||
response = client.get("/v1/models")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["has_more"] is False
|
||||
ids = [item["id"] for item in data["data"]]
|
||||
assert "claude-sonnet-4-20250514" in ids
|
||||
assert data["first_id"] == ids[0]
|
||||
assert data["last_id"] == ids[-1]
|
||||
|
||||
|
||||
def test_probe_endpoints_return_204_with_allow_headers():
|
||||
responses = [
|
||||
client.head("/"),
|
||||
client.options("/"),
|
||||
client.head("/health"),
|
||||
client.options("/health"),
|
||||
client.head("/v1/messages"),
|
||||
client.options("/v1/messages"),
|
||||
client.head("/v1/messages/count_tokens"),
|
||||
client.options("/v1/messages/count_tokens"),
|
||||
]
|
||||
|
||||
for response in responses:
|
||||
assert response.status_code == 204
|
||||
assert "Allow" in response.headers
|
||||
|
||||
|
||||
def test_create_message_stream():
|
||||
"""Create message returns streaming response."""
|
||||
payload = {
|
||||
|
|
|
|||
|
|
@ -55,3 +55,19 @@ def test_anthropic_auth_token_accepts_bearer_authorization():
|
|||
assert r.json()["input_tokens"] == 2
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_anthropic_auth_token_applies_to_models_endpoint():
|
||||
client = TestClient(app)
|
||||
settings = Settings()
|
||||
settings.anthropic_auth_token = "models-token"
|
||||
app.dependency_overrides[get_settings] = lambda: settings
|
||||
|
||||
r = client.get("/v1/models")
|
||||
assert r.status_code == 401
|
||||
|
||||
r = client.get("/v1/models", headers={"X-API-Key": "models-token"})
|
||||
assert r.status_code == 200
|
||||
assert "data" in r.json()
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class TestSettings:
|
|||
assert isinstance(settings.provider_rate_window, int)
|
||||
assert isinstance(settings.nim.temperature, float)
|
||||
assert isinstance(settings.fast_prefix_detection, bool)
|
||||
assert isinstance(settings.enable_thinking, bool)
|
||||
|
||||
def test_get_settings_cached(self):
|
||||
"""Test get_settings returns cached instance."""
|
||||
|
|
@ -104,6 +105,22 @@ class TestSettings:
|
|||
settings = Settings()
|
||||
assert settings.http_connect_timeout == 5.0
|
||||
|
||||
def test_enable_thinking_from_env(self, monkeypatch):
|
||||
"""ENABLE_THINKING env var is loaded into settings."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setenv("ENABLE_THINKING", "false")
|
||||
settings = Settings()
|
||||
assert settings.enable_thinking is False
|
||||
|
||||
def test_removed_nim_enable_thinking_raises(self, monkeypatch):
|
||||
"""NIM_ENABLE_THINKING now fails fast with a migration message."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setenv("NIM_ENABLE_THINKING", "false")
|
||||
with pytest.raises(ValidationError, match="Rename it to ENABLE_THINKING"):
|
||||
Settings()
|
||||
|
||||
|
||||
# --- NimSettings Validation Tests ---
|
||||
class TestNimSettingsValidBounds:
|
||||
|
|
@ -228,6 +245,13 @@ class TestNimSettingsValidators:
|
|||
with pytest.raises(ValidationError):
|
||||
NimSettings(**cast(Any, {"unknown_field": "value"}))
|
||||
|
||||
def test_enable_thinking_field_removed(self):
|
||||
"""NimSettings no longer accepts the removed thinking toggle."""
|
||||
from typing import Any, cast
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
NimSettings(**cast(Any, {"enable_thinking": True}))
|
||||
|
||||
|
||||
class TestSettingsOptionalStr:
|
||||
"""Test Settings parse_optional_str validator."""
|
||||
|
|
|
|||
|
|
@ -2,9 +2,13 @@ import asyncio
|
|||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from config.settings import Settings
|
||||
|
||||
# Set mock environment BEFORE any imports that use Settings
|
||||
os.environ.setdefault("NVIDIA_NIM_API_KEY", "test_key")
|
||||
os.environ.setdefault("MODEL", "nvidia_nim/test-model")
|
||||
|
|
@ -13,26 +17,12 @@ os.environ["PTB_TIMEDELTA"] = "1"
|
|||
# (tests expect endpoints to be unauthenticated by default)
|
||||
os.environ["ANTHROPIC_AUTH_TOKEN"] = ""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from config.nim import NimSettings
|
||||
from messaging.models import IncomingMessage
|
||||
from messaging.platforms.base import (
|
||||
CLISession,
|
||||
MessagingPlatform,
|
||||
SessionManagerInterface,
|
||||
)
|
||||
from messaging.session import SessionStore
|
||||
from providers.base import ProviderConfig
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
Settings.model_config = {**Settings.model_config, "env_file": None}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_from_dotenv(monkeypatch):
|
||||
"""Prevent Pydantic BaseSettings from reading the .env file during tests."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setattr(
|
||||
Settings, "model_config", {**Settings.model_config, "env_file": None}
|
||||
)
|
||||
|
|
@ -40,6 +30,8 @@ def _isolate_from_dotenv(monkeypatch):
|
|||
|
||||
@pytest.fixture
|
||||
def provider_config():
|
||||
from providers.base import ProviderConfig
|
||||
|
||||
return ProviderConfig(
|
||||
api_key="test_key",
|
||||
base_url="https://test.api.nvidia.com/v1",
|
||||
|
|
@ -50,6 +42,9 @@ def provider_config():
|
|||
|
||||
@pytest.fixture
|
||||
def nim_provider(provider_config):
|
||||
from config.nim import NimSettings
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
return NvidiaNimProvider(provider_config, nim_settings=NimSettings())
|
||||
|
||||
|
||||
|
|
@ -62,6 +57,7 @@ def open_router_provider(provider_config):
|
|||
|
||||
@pytest.fixture
|
||||
def lmstudio_provider(provider_config):
|
||||
from providers.base import ProviderConfig
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
|
||||
lmstudio_config = ProviderConfig(
|
||||
|
|
@ -75,6 +71,7 @@ def lmstudio_provider(provider_config):
|
|||
|
||||
@pytest.fixture
|
||||
def llamacpp_provider(provider_config):
|
||||
from providers.base import ProviderConfig
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
|
||||
llamacpp_config = ProviderConfig(
|
||||
|
|
@ -88,6 +85,8 @@ def llamacpp_provider(provider_config):
|
|||
|
||||
@pytest.fixture
|
||||
def mock_cli_session():
|
||||
from messaging.platforms.base import CLISession
|
||||
|
||||
session = MagicMock(spec=CLISession)
|
||||
session.start_task = MagicMock() # This will return an async generator
|
||||
session.is_busy = False
|
||||
|
|
@ -96,6 +95,8 @@ def mock_cli_session():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_cli_manager():
|
||||
from messaging.platforms.base import SessionManagerInterface
|
||||
|
||||
manager = MagicMock(spec=SessionManagerInterface)
|
||||
manager.get_or_create_session = AsyncMock()
|
||||
manager.register_real_session_id = AsyncMock(return_value=True)
|
||||
|
|
@ -107,6 +108,8 @@ def mock_cli_manager():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_platform():
|
||||
from messaging.platforms.base import MessagingPlatform
|
||||
|
||||
platform = MagicMock(spec=MessagingPlatform)
|
||||
platform.send_message = AsyncMock(return_value="msg_123")
|
||||
platform.edit_message = AsyncMock()
|
||||
|
|
@ -127,6 +130,8 @@ def mock_platform():
|
|||
|
||||
@pytest.fixture
|
||||
def mock_session_store():
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = MagicMock(spec=SessionStore)
|
||||
store.save_tree = MagicMock()
|
||||
store.get_tree = MagicMock(return_value=None)
|
||||
|
|
@ -156,6 +161,8 @@ def incoming_message_factory():
|
|||
)
|
||||
|
||||
def _create(**kwargs):
|
||||
from messaging.models import IncomingMessage
|
||||
|
||||
defaults: dict[str, Any] = {
|
||||
"text": "hello",
|
||||
"chat_id": "chat_1",
|
||||
|
|
|
|||
|
|
@ -199,6 +199,24 @@ def test_convert_assistant_message_thinking_include_reasoning_for_openrouter():
|
|||
assert "<think>" in result[0]["content"]
|
||||
|
||||
|
||||
def test_convert_assistant_message_thinking_removed_when_disabled():
|
||||
content = [
|
||||
MockBlock(type="thinking", thinking="I need to calculate this."),
|
||||
MockBlock(type="text", text="The answer is 4."),
|
||||
]
|
||||
messages = [MockMessage("assistant", content)]
|
||||
result = AnthropicToOpenAIConverter.convert_messages(
|
||||
messages,
|
||||
include_thinking=False,
|
||||
include_reasoning_for_openrouter=True,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "reasoning_content" not in result[0]
|
||||
assert "<think>" not in result[0]["content"]
|
||||
assert result[0]["content"] == "The answer is 4."
|
||||
|
||||
|
||||
def test_convert_assistant_message_tool_use():
|
||||
content = [
|
||||
MockBlock(type="text", text="I will call the tool."),
|
||||
|
|
|
|||
|
|
@ -110,6 +110,37 @@ def test_init_base_url_strips_trailing_slash():
|
|||
assert provider._base_url == "http://localhost:8080/v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_omits_thinking_when_globally_disabled(llamacpp_config):
|
||||
provider = LlamaCppProvider(
|
||||
llamacpp_config.model_copy(update={"enable_thinking": False})
|
||||
)
|
||||
req = MockRequest()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
async def empty_aiter():
|
||||
if False:
|
||||
yield ""
|
||||
|
||||
mock_response.aiter_lines = empty_aiter
|
||||
|
||||
with (
|
||||
patch.object(provider._client, "build_request") as mock_build,
|
||||
patch.object(
|
||||
provider._client,
|
||||
"send",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
[e async for e in provider.stream_response(req)]
|
||||
|
||||
_, kwargs = mock_build.call_args
|
||||
assert "thinking" not in kwargs["json"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response(llamacpp_provider):
|
||||
"""Test streaming native Anthropic response."""
|
||||
|
|
@ -254,3 +285,38 @@ async def test_stream_network_error(llamacpp_provider):
|
|||
assert events[0].startswith("event: error\ndata: {")
|
||||
assert "Connection refused" in events[0]
|
||||
assert "TEST_ID2" in events[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_error_405_mentions_upstream_provider(llamacpp_provider):
|
||||
req = MockRequest()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 405
|
||||
mock_response.aread = AsyncMock(return_value=b"Method Not Allowed")
|
||||
mock_response.raise_for_status = MagicMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Method Not Allowed", request=MagicMock(), response=mock_response
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
llamacpp_provider._client, "build_request", return_value=MagicMock()
|
||||
),
|
||||
patch.object(
|
||||
llamacpp_provider._client,
|
||||
"send",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
events = [
|
||||
e async for e in llamacpp_provider.stream_response(req, request_id="REQ405")
|
||||
]
|
||||
|
||||
assert (
|
||||
"Upstream provider LLAMACPP rejected the request method or endpoint (HTTP 405)."
|
||||
in events[0]
|
||||
)
|
||||
assert "REQ405" in events[0]
|
||||
|
|
|
|||
|
|
@ -110,6 +110,68 @@ def test_init_base_url_strips_trailing_slash():
|
|||
assert provider._base_url == "http://localhost:1234/v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_omits_thinking_when_globally_disabled(lmstudio_config):
|
||||
provider = LMStudioProvider(
|
||||
lmstudio_config.model_copy(update={"enable_thinking": False})
|
||||
)
|
||||
req = MockRequest()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
async def empty_aiter():
|
||||
if False:
|
||||
yield ""
|
||||
|
||||
mock_response.aiter_lines = empty_aiter
|
||||
|
||||
with (
|
||||
patch.object(provider._client, "build_request") as mock_build,
|
||||
patch.object(
|
||||
provider._client,
|
||||
"send",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
[e async for e in provider.stream_response(req)]
|
||||
|
||||
_, kwargs = mock_build.call_args
|
||||
assert "thinking" not in kwargs["json"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_omits_thinking_when_request_disables_it(
|
||||
lmstudio_provider,
|
||||
):
|
||||
req = MockRequest()
|
||||
req.thinking.enabled = False
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
async def empty_aiter():
|
||||
if False:
|
||||
yield ""
|
||||
|
||||
mock_response.aiter_lines = empty_aiter
|
||||
|
||||
with (
|
||||
patch.object(lmstudio_provider._client, "build_request") as mock_build,
|
||||
patch.object(
|
||||
lmstudio_provider._client,
|
||||
"send",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
[e async for e in lmstudio_provider.stream_response(req)]
|
||||
|
||||
_, kwargs = mock_build.call_args
|
||||
assert "thinking" not in kwargs["json"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response(lmstudio_provider):
|
||||
"""Test streaming native Anthropic response."""
|
||||
|
|
@ -254,3 +316,38 @@ async def test_stream_network_error(lmstudio_provider):
|
|||
assert events[0].startswith("event: error\ndata: {")
|
||||
assert "Connection refused" in events[0]
|
||||
assert "TEST_ID2" in events[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider):
|
||||
req = MockRequest()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 405
|
||||
mock_response.aread = AsyncMock(return_value=b"Method Not Allowed")
|
||||
mock_response.raise_for_status = MagicMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Method Not Allowed", request=MagicMock(), response=mock_response
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
lmstudio_provider._client, "build_request", return_value=MagicMock()
|
||||
),
|
||||
patch.object(
|
||||
lmstudio_provider._client,
|
||||
"send",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
events = [
|
||||
e async for e in lmstudio_provider.stream_response(req, request_id="REQ405")
|
||||
]
|
||||
|
||||
assert (
|
||||
"Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)."
|
||||
in events[0]
|
||||
)
|
||||
assert "REQ405" in events[0]
|
||||
|
|
|
|||
|
|
@ -91,9 +91,7 @@ async def test_build_request_body(provider_config):
|
|||
"""Test request body construction."""
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config, nim_settings=NimSettings(enable_thinking=True)
|
||||
)
|
||||
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
|
||||
req = MockRequest()
|
||||
body = provider._build_request_body(req)
|
||||
|
||||
|
|
@ -110,6 +108,40 @@ async def test_build_request_body(provider_config):
|
|||
assert body["extra_body"]["reasoning_budget"] == body["max_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_request_body_omits_reasoning_when_globally_disabled(
|
||||
provider_config,
|
||||
):
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config.model_copy(update={"enable_thinking": False}),
|
||||
nim_settings=NimSettings(),
|
||||
)
|
||||
req = MockRequest()
|
||||
body = provider._build_request_body(req)
|
||||
|
||||
extra = body.get("extra_body", {})
|
||||
assert "chat_template_kwargs" not in extra
|
||||
assert "reasoning_budget" not in extra
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_request_body_omits_reasoning_when_request_disables_thinking(
|
||||
provider_config,
|
||||
):
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(provider_config, nim_settings=NimSettings())
|
||||
req = MockRequest()
|
||||
req.thinking.enabled = False
|
||||
body = provider._build_request_body(req)
|
||||
|
||||
extra = body.get("extra_body", {})
|
||||
assert "chat_template_kwargs" not in extra
|
||||
assert "reasoning_budget" not in extra
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_text(nim_provider):
|
||||
"""Test streaming text response."""
|
||||
|
|
@ -195,6 +227,44 @@ async def test_stream_response_thinking_reasoning_content(nim_provider):
|
|||
assert found_thinking
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_suppresses_thinking_when_disabled(provider_config):
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config.model_copy(update={"enable_thinking": False}),
|
||||
nim_settings=NimSettings(),
|
||||
)
|
||||
req = MockRequest()
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(
|
||||
content="<think>secret</think>Answer", reasoning_content="Thinking..."
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = None
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
with patch.object(
|
||||
provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.return_value = mock_stream()
|
||||
|
||||
events = [e async for e in provider.stream_response(req)]
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "thinking_delta" not in event_text
|
||||
assert "Thinking..." not in event_text
|
||||
assert "secret" not in event_text
|
||||
assert "Answer" in event_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_stream(nim_provider):
|
||||
"""Test streaming tool calls."""
|
||||
|
|
|
|||
|
|
@ -67,21 +67,21 @@ class TestBuildRequestBody:
|
|||
def test_max_tokens_capped_by_nim(self, req):
|
||||
req.max_tokens = 100000
|
||||
nim = NimSettings(max_tokens=4096)
|
||||
body = build_request_body(req, nim)
|
||||
body = build_request_body(req, nim, thinking_enabled=True)
|
||||
assert body["max_tokens"] == 4096
|
||||
|
||||
def test_presence_penalty_included_when_nonzero(self, req):
|
||||
nim = NimSettings(presence_penalty=0.5)
|
||||
body = build_request_body(req, nim)
|
||||
body = build_request_body(req, nim, thinking_enabled=True)
|
||||
assert body["presence_penalty"] == 0.5
|
||||
|
||||
def test_include_stop_str_in_output_not_sent(self, req):
|
||||
body = build_request_body(req, NimSettings())
|
||||
body = build_request_body(req, NimSettings(), thinking_enabled=True)
|
||||
assert "include_stop_str_in_output" not in body.get("extra_body", {})
|
||||
|
||||
def test_parallel_tool_calls_included(self, req):
|
||||
nim = NimSettings(parallel_tool_calls=False)
|
||||
body = build_request_body(req, nim)
|
||||
body = build_request_body(req, nim, thinking_enabled=True)
|
||||
assert body["parallel_tool_calls"] is False
|
||||
|
||||
def test_reasoning_params_in_extra_body(self):
|
||||
|
|
@ -98,8 +98,8 @@ class TestBuildRequestBody:
|
|||
req.extra_body = None
|
||||
req.top_k = None
|
||||
|
||||
nim = NimSettings(enable_thinking=True)
|
||||
body = build_request_body(req, nim)
|
||||
nim = NimSettings()
|
||||
body = build_request_body(req, nim, thinking_enabled=True)
|
||||
extra = body["extra_body"]
|
||||
assert extra["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
|
|
@ -121,8 +121,8 @@ class TestBuildRequestBody:
|
|||
req.extra_body = None
|
||||
req.top_k = None
|
||||
|
||||
nim = NimSettings(enable_thinking=False)
|
||||
body = build_request_body(req, nim)
|
||||
nim = NimSettings()
|
||||
body = build_request_body(req, nim, thinking_enabled=False)
|
||||
extra = body.get("extra_body", {})
|
||||
assert "chat_template_kwargs" not in extra
|
||||
assert "reasoning_budget" not in extra
|
||||
|
|
@ -142,7 +142,7 @@ class TestBuildRequestBody:
|
|||
req.top_k = None
|
||||
|
||||
nim = NimSettings()
|
||||
body = build_request_body(req, nim)
|
||||
body = build_request_body(req, nim, thinking_enabled=False)
|
||||
extra = body.get("extra_body", {})
|
||||
for param in (
|
||||
"thinking",
|
||||
|
|
@ -152,3 +152,29 @@ class TestBuildRequestBody:
|
|||
"reasoning_effort",
|
||||
):
|
||||
assert param not in extra
|
||||
|
||||
def test_assistant_thinking_blocks_removed_when_disabled(self):
|
||||
req = MagicMock()
|
||||
req.model = "test"
|
||||
req.messages = [
|
||||
MagicMock(
|
||||
role="assistant",
|
||||
content=[
|
||||
MagicMock(type="thinking", thinking="secret"),
|
||||
MagicMock(type="text", text="answer"),
|
||||
],
|
||||
)
|
||||
]
|
||||
req.max_tokens = 100
|
||||
req.system = None
|
||||
req.temperature = None
|
||||
req.top_p = None
|
||||
req.stop_sequences = None
|
||||
req.tools = None
|
||||
req.tool_choice = None
|
||||
req.extra_body = None
|
||||
req.top_k = None
|
||||
|
||||
body = build_request_body(req, NimSettings(), thinking_enabled=False)
|
||||
assert "<think>" not in body["messages"][0]["content"]
|
||||
assert "answer" in body["messages"][0]["content"]
|
||||
|
|
|
|||
|
|
@ -105,6 +105,26 @@ def test_build_request_body_has_reasoning_extra(open_router_provider):
|
|||
assert body["extra_body"]["reasoning"]["enabled"] is True
|
||||
|
||||
|
||||
def test_build_request_body_omits_reasoning_when_globally_disabled(open_router_config):
|
||||
provider = OpenRouterProvider(
|
||||
open_router_config.model_copy(update={"enable_thinking": False})
|
||||
)
|
||||
req = MockRequest()
|
||||
body = provider._build_request_body(req)
|
||||
|
||||
assert "extra_body" not in body or "reasoning" not in body["extra_body"]
|
||||
|
||||
|
||||
def test_build_request_body_omits_reasoning_when_request_disables_thinking(
|
||||
open_router_provider,
|
||||
):
|
||||
req = MockRequest()
|
||||
req.thinking.enabled = False
|
||||
body = open_router_provider._build_request_body(req)
|
||||
|
||||
assert "extra_body" not in body or "reasoning" not in body["extra_body"]
|
||||
|
||||
|
||||
def test_build_request_body_base_url_and_model(open_router_provider):
|
||||
"""Base URL and model are correct in provider config."""
|
||||
assert open_router_provider._base_url == "https://openrouter.ai/api/v1"
|
||||
|
|
@ -205,6 +225,44 @@ async def test_stream_response_reasoning_content(open_router_provider):
|
|||
assert found_thinking
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_suppresses_reasoning_when_disabled(open_router_config):
|
||||
provider = OpenRouterProvider(
|
||||
open_router_config.model_copy(update={"enable_thinking": False})
|
||||
)
|
||||
req = MockRequest()
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(
|
||||
content="<think>secret</think>Answer",
|
||||
reasoning_content="Thinking...",
|
||||
reasoning_details=[{"text": "Step 1"}],
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = None
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
with patch.object(
|
||||
provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.return_value = mock_stream()
|
||||
|
||||
events = [e async for e in provider.stream_response(req)]
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "thinking_delta" not in event_text
|
||||
assert "Thinking..." not in event_text
|
||||
assert "Step 1" not in event_text
|
||||
assert "secret" not in event_text
|
||||
assert "Answer" in event_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_empty_choices_skipped(open_router_provider):
|
||||
"""Chunks with empty choices are skipped."""
|
||||
|
|
|
|||
|
|
@ -39,6 +39,18 @@ def _make_provider():
|
|||
return NvidiaNimProvider(config, nim_settings=NimSettings())
|
||||
|
||||
|
||||
def _make_provider_with_thinking_enabled(enabled: bool):
|
||||
"""Create a provider instance with thinking explicitly enabled or disabled."""
|
||||
config = ProviderConfig(
|
||||
api_key="test_key",
|
||||
base_url="https://test.api.nvidia.com/v1",
|
||||
rate_limit=10,
|
||||
rate_window=60,
|
||||
enable_thinking=enabled,
|
||||
)
|
||||
return NvidiaNimProvider(config, nim_settings=NimSettings())
|
||||
|
||||
|
||||
def _make_request(model="test-model", stream=True):
|
||||
"""Create a mock request with all fields build_request_body needs."""
|
||||
req = MagicMock()
|
||||
|
|
@ -272,6 +284,76 @@ class TestStreamingExceptionHandling:
|
|||
assert "I think..." in event_text
|
||||
assert "The answer" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_with_reasoning_content_suppressed_when_disabled(self):
|
||||
"""reasoning deltas are stripped while normal text still streams."""
|
||||
provider = _make_provider_with_thinking_enabled(False)
|
||||
request = _make_request()
|
||||
|
||||
chunk1 = _make_chunk(reasoning_content="I think...")
|
||||
chunk2 = _make_chunk(content="<think>secret</think>The answer")
|
||||
chunk3 = _make_chunk(finish_reason="stop")
|
||||
stream_mock = AsyncStreamMock([chunk1, chunk2, chunk3])
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_mock,
|
||||
),
|
||||
patch.object(
|
||||
provider._global_rate_limiter,
|
||||
"wait_if_blocked",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
events = await _collect_stream(provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "thinking_delta" not in event_text
|
||||
assert "I think..." not in event_text
|
||||
assert "secret" not in event_text
|
||||
assert "The answer" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_with_upstream_405_mentions_provider_name(self):
|
||||
"""HTTP 405s are surfaced as upstream method/endpoint rejections."""
|
||||
provider = _make_provider()
|
||||
request = _make_request()
|
||||
|
||||
response = httpx.Response(
|
||||
status_code=405,
|
||||
request=httpx.Request("POST", "https://example.com/v1/chat/completions"),
|
||||
)
|
||||
error = httpx.HTTPStatusError(
|
||||
"Method Not Allowed",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=error,
|
||||
):
|
||||
events = [
|
||||
e
|
||||
async for e in provider.stream_response(
|
||||
request,
|
||||
request_id="REQ405",
|
||||
)
|
||||
]
|
||||
|
||||
event_text = "".join(events)
|
||||
assert (
|
||||
"Upstream provider NIM rejected the request method or endpoint (HTTP 405)."
|
||||
in event_text
|
||||
)
|
||||
assert "request_id=REQ405" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_rate_limited_retries_via_execute_with_retry(self):
|
||||
"""When rate limited, execute_with_retry handles retries transparently."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue