Add per-model thinking toggles

This commit is contained in:
Alishahryar1 2026-04-24 00:14:49 -07:00
parent 462a9430bb
commit 1f12a33dd7
14 changed files with 271 additions and 88 deletions

View file

@ -32,7 +32,10 @@ def _make_mock_settings(**overrides):
mock.http_read_timeout = 300.0
mock.http_write_timeout = 10.0
mock.http_connect_timeout = 2.0
mock.enable_thinking = True
mock.opus_enable_thinking = True
mock.sonnet_enable_thinking = True
mock.haiku_enable_thinking = True
mock.model_enable_thinking = True
for key, value in overrides.items():
setattr(mock, key, value)
return mock
@ -134,7 +137,7 @@ async def test_get_provider_deepseek():
assert isinstance(provider, DeepSeekProvider)
assert provider._base_url == "https://api.deepseek.com"
assert provider._api_key == "test_deepseek_key"
assert provider._config.enable_thinking is True
assert provider._config.model_enable_thinking is True
@pytest.mark.asyncio
@ -152,18 +155,38 @@ async def test_get_provider_deepseek_uses_fixed_base_url():
@pytest.mark.asyncio
async def test_get_provider_deepseek_passes_enable_thinking():
"""DeepSeek provider receives the global thinking toggle."""
async def test_get_provider_deepseek_passes_model_enable_thinking():
"""DeepSeek provider receives the fallback thinking toggle."""
with patch("api.dependencies.get_settings") as mock_settings:
mock_settings.return_value = _make_mock_settings(
provider_type="deepseek",
enable_thinking=False,
model_enable_thinking=False,
)
provider = get_provider()
assert isinstance(provider, DeepSeekProvider)
assert provider._config.enable_thinking is False
assert provider._config.model_enable_thinking is False
@pytest.mark.asyncio
async def test_get_provider_passes_per_model_thinking_flags():
"""Provider config receives every per-model thinking toggle."""
with patch("api.dependencies.get_settings") as mock_settings:
mock_settings.return_value = _make_mock_settings(
opus_enable_thinking=False,
sonnet_enable_thinking=True,
haiku_enable_thinking=False,
model_enable_thinking=True,
)
provider = get_provider()
assert isinstance(provider, NvidiaNimProvider)
assert provider._config.opus_enable_thinking is False
assert provider._config.sonnet_enable_thinking is True
assert provider._config.haiku_enable_thinking is False
assert provider._config.model_enable_thinking is True
@pytest.mark.asyncio

View file

@ -29,7 +29,10 @@ class TestSettings:
assert isinstance(settings.provider_rate_window, int)
assert isinstance(settings.nim.temperature, float)
assert isinstance(settings.fast_prefix_detection, bool)
assert isinstance(settings.enable_thinking, bool)
assert isinstance(settings.opus_enable_thinking, bool)
assert isinstance(settings.sonnet_enable_thinking, bool)
assert isinstance(settings.haiku_enable_thinking, bool)
assert isinstance(settings.model_enable_thinking, bool)
assert settings.http_read_timeout == 120.0
def test_get_settings_cached(self):
@ -110,20 +113,34 @@ class TestSettings:
settings = Settings()
assert settings.http_connect_timeout == 5.0
def test_enable_thinking_from_env(self, monkeypatch):
"""ENABLE_THINKING env var is loaded into settings."""
def test_per_model_thinking_from_env(self, monkeypatch):
"""Per-model thinking env vars are loaded into settings."""
from config.settings import Settings
monkeypatch.setenv("ENABLE_THINKING", "false")
monkeypatch.setenv("OPUS_ENABLE_THINKING", "false")
monkeypatch.setenv("SONNET_ENABLE_THINKING", "true")
monkeypatch.setenv("HAIKU_ENABLE_THINKING", "false")
monkeypatch.setenv("MODEL_ENABLE_THINKING", "true")
settings = Settings()
assert settings.enable_thinking is False
assert settings.opus_enable_thinking is False
assert settings.sonnet_enable_thinking is True
assert settings.haiku_enable_thinking is False
assert settings.model_enable_thinking is True
def test_removed_nim_enable_thinking_raises(self, monkeypatch):
"""NIM_ENABLE_THINKING now fails fast with a migration message."""
from config.settings import Settings
monkeypatch.setenv("NIM_ENABLE_THINKING", "false")
with pytest.raises(ValidationError, match="Rename it to ENABLE_THINKING"):
with pytest.raises(ValidationError, match="MODEL_ENABLE_THINKING"):
Settings()
def test_removed_enable_thinking_raises(self, monkeypatch):
"""ENABLE_THINKING now fails fast with a migration message."""
from config.settings import Settings
monkeypatch.setenv("ENABLE_THINKING", "false")
with pytest.raises(ValidationError, match="MODEL_ENABLE_THINKING"):
Settings()
@ -494,6 +511,29 @@ class TestPerModelMapping:
s.model_opus = "open_router/opus-model"
assert s.resolve_model("Claude-OPUS-4") == "open_router/opus-model"
def test_resolve_thinking_enabled_per_model_family(self):
"""resolve_thinking_enabled returns the matching model-family flag."""
from config.settings import Settings
s = Settings()
s.opus_enable_thinking = False
s.sonnet_enable_thinking = True
s.haiku_enable_thinking = False
s.model_enable_thinking = True
assert s.resolve_thinking_enabled("claude-opus-4-20250514") is False
assert s.resolve_thinking_enabled("claude-sonnet-4-20250514") is True
assert s.resolve_thinking_enabled("claude-haiku-4-20250514") is False
assert s.resolve_thinking_enabled("claude-2.1") is True
def test_resolve_thinking_enabled_case_insensitive(self):
"""Thinking model-family classification is case-insensitive."""
from config.settings import Settings
s = Settings()
s.opus_enable_thinking = False
assert s.resolve_thinking_enabled("Claude-OPUS-4") is False
def test_parse_provider_type(self):
"""parse_provider_type extracts provider from model string."""
from config.settings import Settings

View file

@ -0,0 +1,78 @@
from providers.base import BaseProvider, ProviderConfig
class DummyProvider(BaseProvider):
    """Minimal concrete BaseProvider used to exercise thinking-flag resolution.

    Implements only the abstract surface: a no-op cleanup and a stream that
    terminates immediately without producing any chunks.
    """

    async def cleanup(self) -> None:
        """Nothing to release for the dummy provider."""

    async def stream_response(self, request, input_tokens=0, *, request_id=None):
        """Empty async generator: returns before the (unreachable) yield."""
        return
        yield ""  # marks this coroutine as an async generator
class DummyThinking:
    """Stand-in for a request-level thinking config exposing one flag."""

    def __init__(self, enabled: bool):
        # Only the `enabled` attribute is inspected by the provider.
        self.enabled = enabled
class DummyRequest:
    """Lightweight request object carrying just the fields the provider reads.

    Mirrors the attribute shape (`model`, `original_model`, `thinking`) that
    `_is_thinking_enabled` consults on a real request.
    """

    def __init__(
        self,
        *,
        model: str,
        original_model: str | None = None,
        thinking: DummyThinking | None = None,
    ):
        self.thinking = thinking
        self.original_model = original_model
        self.model = model
def test_is_thinking_enabled_uses_original_model_family():
    """The opus family flag (via original_model) overrides the request's enabled=True."""
    config = ProviderConfig(
        api_key="test",
        opus_enable_thinking=False,
        model_enable_thinking=True,
    )
    req = DummyRequest(
        model="provider-model",
        original_model="claude-opus-4-20250514",
        thinking=DummyThinking(True),
    )
    assert DummyProvider(config)._is_thinking_enabled(req) is False
def test_is_thinking_enabled_falls_back_to_request_model():
    """Without original_model, the family flag is derived from request.model."""
    cfg = ProviderConfig(
        api_key="test",
        sonnet_enable_thinking=False,
        model_enable_thinking=True,
    )
    provider = DummyProvider(cfg)
    req = DummyRequest(model="claude-sonnet-4-20250514")
    assert provider._is_thinking_enabled(req) is False
def test_is_thinking_enabled_unknown_model_uses_fallback_flag():
    """A model matching no known family falls back to model_enable_thinking."""
    cfg = ProviderConfig(api_key="test", model_enable_thinking=False)
    provider = DummyProvider(cfg)
    assert provider._is_thinking_enabled(DummyRequest(model="provider-model")) is False
def test_is_thinking_enabled_respects_request_disable():
    """An explicit request-level thinking.enabled=False disables thinking even when the family flag is on."""
    cfg = ProviderConfig(api_key="test", opus_enable_thinking=True)
    req = DummyRequest(
        model="provider-model",
        original_model="claude-opus-4-20250514",
        thinking=DummyThinking(False),
    )
    assert DummyProvider(cfg)._is_thinking_enabled(req) is False

View file

@ -44,7 +44,6 @@ def deepseek_config():
base_url=DEEPSEEK_BASE_URL,
rate_limit=10,
rate_window=60,
enable_thinking=True,
)
@ -86,15 +85,15 @@ def test_build_request_body_enables_thinking_for_chat_model(deepseek_provider):
assert body["messages"][0]["role"] == "system"
def test_build_request_body_global_disable_blocks_request_thinking():
"""Global disable suppresses provider-side thinking even if the request enables it."""
def test_build_request_body_model_disable_blocks_request_thinking():
"""Model disable suppresses provider-side thinking even if the request enables it."""
provider = DeepSeekProvider(
ProviderConfig(
api_key="test_deepseek_key",
base_url=DEEPSEEK_BASE_URL,
rate_limit=10,
rate_window=60,
enable_thinking=False,
model_enable_thinking=False,
)
)
req = MockRequest(model="deepseek-chat")
@ -103,8 +102,8 @@ def test_build_request_body_global_disable_blocks_request_thinking():
assert "extra_body" not in body or "thinking" not in body["extra_body"]
def test_build_request_body_request_disable_blocks_global_thinking(deepseek_provider):
"""Request-level disable suppresses provider-side thinking when global is enabled."""
def test_build_request_body_request_disable_blocks_model_thinking(deepseek_provider):
"""Request-level disable suppresses provider-side thinking when model is enabled."""
req = MockRequest(model="deepseek-chat")
req.thinking.enabled = False
body = deepseek_provider._build_request_body(req)

View file

@ -111,9 +111,9 @@ def test_init_base_url_strips_trailing_slash():
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_globally_disabled(llamacpp_config):
async def test_stream_response_omits_thinking_when_model_disabled(llamacpp_config):
provider = LlamaCppProvider(
llamacpp_config.model_copy(update={"enable_thinking": False})
llamacpp_config.model_copy(update={"model_enable_thinking": False})
)
req = MockRequest()

View file

@ -111,9 +111,9 @@ def test_init_base_url_strips_trailing_slash():
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_globally_disabled(lmstudio_config):
async def test_stream_response_omits_thinking_when_model_disabled(lmstudio_config):
provider = LMStudioProvider(
lmstudio_config.model_copy(update={"enable_thinking": False})
lmstudio_config.model_copy(update={"model_enable_thinking": False})
)
req = MockRequest()

View file

@ -121,13 +121,13 @@ async def test_build_request_body(provider_config):
@pytest.mark.asyncio
async def test_build_request_body_omits_reasoning_when_globally_disabled(
async def test_build_request_body_omits_reasoning_when_model_disabled(
provider_config,
):
from config.nim import NimSettings
provider = NvidiaNimProvider(
provider_config.model_copy(update={"enable_thinking": False}),
provider_config.model_copy(update={"model_enable_thinking": False}),
nim_settings=NimSettings(),
)
req = MockRequest()
@ -244,7 +244,7 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
from config.nim import NimSettings
provider = NvidiaNimProvider(
provider_config.model_copy(update={"enable_thinking": False}),
provider_config.model_copy(update={"model_enable_thinking": False}),
nim_settings=NimSettings(),
)
req = MockRequest()

View file

@ -105,9 +105,9 @@ def test_build_request_body_has_reasoning_extra(open_router_provider):
assert body["extra_body"]["reasoning"]["enabled"] is True
def test_build_request_body_omits_reasoning_when_globally_disabled(open_router_config):
def test_build_request_body_omits_reasoning_when_model_disabled(open_router_config):
provider = OpenRouterProvider(
open_router_config.model_copy(update={"enable_thinking": False})
open_router_config.model_copy(update={"model_enable_thinking": False})
)
req = MockRequest()
body = provider._build_request_body(req)
@ -228,7 +228,7 @@ async def test_stream_response_reasoning_content(open_router_provider):
@pytest.mark.asyncio
async def test_stream_response_suppresses_reasoning_when_disabled(open_router_config):
provider = OpenRouterProvider(
open_router_config.model_copy(update={"enable_thinking": False})
open_router_config.model_copy(update={"model_enable_thinking": False})
)
req = MockRequest()

View file

@ -46,7 +46,7 @@ def _make_provider_with_thinking_enabled(enabled: bool):
base_url="https://test.api.nvidia.com/v1",
rate_limit=10,
rate_window=60,
enable_thinking=enabled,
model_enable_thinking=enabled,
)
return NvidiaNimProvider(config, nim_settings=NimSettings())