Per-tier claude model mapping (#66)

This commit is contained in:
Ali Khokhar 2026-03-01 21:32:23 -08:00 committed by GitHub
parent 763c8b62b7
commit 0b324e0421
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 454 additions and 81 deletions

View file

@ -106,8 +106,6 @@ class TestSettings:
# --- NimSettings Validation Tests ---
class TestNimSettingsValidBounds:
"""Test that valid values within bounds are accepted."""
@ -310,3 +308,151 @@ class TestSettingsOptionalStr:
monkeypatch.setenv("WHISPER_DEVICE", device)
s = Settings()
assert s.whisper_device == device
class TestPerTierModelMapping:
    """Tests for the per-tier model override fields and resolve_model()."""

    def test_tier_fields_default_none(self):
        """All per-tier model fields are None when no env vars are set."""
        from config.settings import Settings

        settings = Settings()
        assert settings.model_opus is None
        assert settings.model_sonnet is None
        assert settings.model_haiku is None

    def test_model_opus_from_env(self, monkeypatch):
        """The MODEL_OPUS environment variable populates model_opus."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "open_router/deepseek/deepseek-r1")
        settings = Settings()
        assert settings.model_opus == "open_router/deepseek/deepseek-r1"

    def test_model_sonnet_from_env(self, monkeypatch):
        """The MODEL_SONNET environment variable populates model_sonnet."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_SONNET", "nvidia_nim/meta/llama-3.3-70b-instruct")
        settings = Settings()
        assert settings.model_sonnet == "nvidia_nim/meta/llama-3.3-70b-instruct"

    def test_model_haiku_from_env(self, monkeypatch):
        """The MODEL_HAIKU environment variable populates model_haiku."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_HAIKU", "lmstudio/qwen2.5-7b")
        settings = Settings()
        assert settings.model_haiku == "lmstudio/qwen2.5-7b"

    def test_model_opus_invalid_provider_raises(self, monkeypatch):
        """An unknown provider prefix in MODEL_OPUS fails validation."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "bad_provider/some-model")
        with pytest.raises(ValidationError, match="Invalid provider"):
            Settings()

    def test_model_opus_no_slash_raises(self, monkeypatch):
        """A MODEL_OPUS value missing the provider prefix fails validation."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "noprefix")
        with pytest.raises(ValidationError, match="provider type"):
            Settings()

    def test_model_haiku_invalid_provider_raises(self, monkeypatch):
        """An unknown provider prefix in MODEL_HAIKU fails validation."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_HAIKU", "invalid/model")
        with pytest.raises(ValidationError, match="Invalid provider"):
            Settings()

    def test_resolve_model_opus_override(self):
        """Opus-family model names resolve to the model_opus override."""
        from config.settings import Settings

        settings = Settings()
        settings.model_opus = "open_router/deepseek/deepseek-r1"
        expected = "open_router/deepseek/deepseek-r1"
        assert settings.resolve_model("claude-opus-4-20250514") == expected
        assert settings.resolve_model("claude-3-opus") == expected
        assert settings.resolve_model("claude-3-opus-20240229") == expected

    def test_resolve_model_sonnet_override(self):
        """Sonnet-family model names resolve to the model_sonnet override."""
        from config.settings import Settings

        settings = Settings()
        settings.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct"
        expected = "nvidia_nim/meta/llama-3.3-70b-instruct"
        assert settings.resolve_model("claude-sonnet-4-20250514") == expected
        assert settings.resolve_model("claude-3-5-sonnet-20241022") == expected

    def test_resolve_model_haiku_override(self):
        """Haiku-family model names resolve to the model_haiku override."""
        from config.settings import Settings

        settings = Settings()
        settings.model_haiku = "lmstudio/qwen2.5-7b"
        expected = "lmstudio/qwen2.5-7b"
        assert settings.resolve_model("claude-3-haiku-20240307") == expected
        assert settings.resolve_model("claude-3-5-haiku-20241022") == expected
        assert settings.resolve_model("claude-haiku-4-20250514") == expected

    def test_resolve_model_fallback_when_tier_not_set(self):
        """With no tier override configured, resolve_model returns MODEL."""
        from config.settings import Settings

        settings = Settings()
        settings.model = "nvidia_nim/fallback-model"
        # No tier overrides set — every tier falls back to the base model.
        fallback = "nvidia_nim/fallback-model"
        assert settings.resolve_model("claude-opus-4-20250514") == fallback
        assert settings.resolve_model("claude-sonnet-4-20250514") == fallback
        assert settings.resolve_model("claude-3-haiku-20240307") == fallback

    def test_resolve_model_unknown_tier_falls_back(self):
        """Model names that match no tier resolve to MODEL, ignoring overrides."""
        from config.settings import Settings

        settings = Settings()
        settings.model = "nvidia_nim/fallback-model"
        settings.model_opus = "open_router/opus-model"
        assert settings.resolve_model("claude-2.1") == "nvidia_nim/fallback-model"
        assert settings.resolve_model("some-unknown-model") == "nvidia_nim/fallback-model"

    def test_resolve_model_case_insensitive(self):
        """Tier detection ignores the case of the requested model name."""
        from config.settings import Settings

        settings = Settings()
        settings.model_opus = "open_router/opus-model"
        assert settings.resolve_model("Claude-OPUS-4") == "open_router/opus-model"

    def test_parse_provider_type(self):
        """parse_provider_type returns the segment before the first slash."""
        from config.settings import Settings

        assert Settings.parse_provider_type("nvidia_nim/meta/llama") == "nvidia_nim"
        assert Settings.parse_provider_type("open_router/deepseek/r1") == "open_router"
        assert Settings.parse_provider_type("lmstudio/qwen") == "lmstudio"

    def test_parse_model_name(self):
        """parse_model_name returns everything after the first slash."""
        from config.settings import Settings

        assert Settings.parse_model_name("nvidia_nim/meta/llama") == "meta/llama"
        assert Settings.parse_model_name("lmstudio/qwen") == "qwen"