free-claude-code/tests/api/test_model_router.py
2026-04-25 20:51:07 -07:00

117 lines
3.9 KiB
Python

from unittest.mock import patch
import pytest
from api.model_router import ModelRouter
from api.models.anthropic import Message, MessagesRequest, TokenCountRequest
from config.settings import Settings
@pytest.fixture
def settings():
settings = Settings()
settings.model = "nvidia_nim/fallback-model"
settings.model_opus = None
settings.model_sonnet = None
settings.model_haiku = None
settings.enable_model_thinking = True
settings.enable_opus_thinking = None
settings.enable_sonnet_thinking = None
settings.enable_haiku_thinking = None
return settings
def test_model_router_resolves_default_model(settings):
resolved = ModelRouter(settings).resolve("claude-3-opus")
assert resolved.original_model == "claude-3-opus"
assert resolved.provider_id == "nvidia_nim"
assert resolved.provider_model == "fallback-model"
assert resolved.provider_model_ref == "nvidia_nim/fallback-model"
assert resolved.thinking_enabled is True
def test_model_router_applies_opus_override(settings):
settings.model_opus = "open_router/deepseek/deepseek-r1"
request = MessagesRequest(
model="claude-opus-4-20250514",
max_tokens=100,
messages=[Message(role="user", content="hello")],
)
routed = ModelRouter(settings).resolve_messages_request(request)
assert routed.request.model == "deepseek/deepseek-r1"
assert routed.resolved.provider_model_ref == "open_router/deepseek/deepseek-r1"
assert routed.resolved.original_model == "claude-opus-4-20250514"
assert routed.resolved.thinking_enabled is True
assert request.model == "claude-opus-4-20250514"
def test_model_router_resolves_per_model_thinking(settings):
settings.enable_model_thinking = False
settings.enable_opus_thinking = True
settings.enable_haiku_thinking = False
router = ModelRouter(settings)
assert router.resolve("claude-opus-4-20250514").thinking_enabled is True
assert router.resolve("claude-sonnet-4-20250514").thinking_enabled is False
assert router.resolve("claude-3-haiku-20240307").thinking_enabled is False
assert router.resolve("claude-2.1").thinking_enabled is False
def test_model_router_applies_haiku_override(settings):
settings.model_haiku = "lmstudio/qwen2.5-7b"
routed = ModelRouter(settings).resolve_messages_request(
MessagesRequest(
model="claude-3-haiku-20240307",
max_tokens=100,
messages=[Message(role="user", content="hello")],
)
)
assert routed.request.model == "qwen2.5-7b"
assert routed.resolved.provider_model_ref == "lmstudio/qwen2.5-7b"
def test_model_router_applies_sonnet_override(settings):
settings.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct"
routed = ModelRouter(settings).resolve_messages_request(
MessagesRequest(
model="claude-sonnet-4-20250514",
max_tokens=100,
messages=[Message(role="user", content="hello")],
)
)
assert routed.request.model == "meta/llama-3.3-70b-instruct"
assert (
routed.resolved.provider_model_ref == "nvidia_nim/meta/llama-3.3-70b-instruct"
)
def test_model_router_routes_token_count_request(settings):
settings.model_haiku = "lmstudio/qwen2.5-7b"
request = TokenCountRequest(
model="claude-3-haiku-20240307",
messages=[Message(role="user", content="hello")],
)
routed = ModelRouter(settings).resolve_token_count_request(request)
assert routed.request.model == "qwen2.5-7b"
assert request.model == "claude-3-haiku-20240307"
def test_model_router_logs_mapping(settings):
with patch("api.model_router.logger.debug") as mock_log:
ModelRouter(settings).resolve("claude-2.1")
mock_log.assert_called()
args = mock_log.call_args[0]
assert "MODEL MAPPING" in args[0]
assert args[1] == "claude-2.1"
assert args[2] == "fallback-model"