Fix: Exclude chat_template for Mistral tokenizers in NVIDIA NIM (#130) (#131)

Fixes #130. This PR updates the NVIDIA NIM provider to omit
`chat_template_kwargs` and `chat_template` when using a Mistral
tokenizer model. This resolves the 400 Bad Request error returned by the
API.

Co-authored-by: Alishahryar1 <alishahryar2@gmail.com>
This commit is contained in:
Anuj Nitin Bharambe 2026-04-23 05:46:45 +05:30 committed by GitHub
parent 4afca05318
commit 4fdf7e8b7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 183 additions and 14 deletions

View file

@@ -39,6 +39,15 @@ class MockRequest:
setattr(self, k, v)
def _make_bad_request_error(message: str) -> openai.BadRequestError:
    """Build an openai.BadRequestError backed by a 400 httpx response, for tests.

    NOTE(review): a second ``_make_bad_request_error`` is defined later in this
    module and shadows this one — consider consolidating the two helpers.
    """
    request = Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions")
    error_body = {
        "error": {"message": message, "type": "BadRequestError", "code": 400}
    }
    return openai.BadRequestError(
        message,
        response=Response(status_code=400, request=request),
        body=error_body,
    )
@pytest.fixture(autouse=True)
def mock_rate_limiter():
"""Mock the global rate limiter to prevent waiting."""
@@ -268,10 +277,86 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
assert "Answer" in event_text
def _make_bad_request_error(message: str) -> openai.BadRequestError:
    """Build a minimal openai.BadRequestError wrapping a 400 response.

    NOTE(review): this redefines the helper of the same name declared earlier
    in this module; tests below resolve to this version.
    """
    resp = Response(status_code=400, request=Request("POST", "http://test"))
    return openai.BadRequestError(
        message,
        response=resp,
        body={"error": {"message": message}},
    )
@pytest.mark.asyncio
async def test_stream_response_retries_without_chat_template(provider_config):
    """A Mistral chat_template 400 triggers exactly one retry without chat_template."""
    from config.nim import NimSettings

    provider = NvidiaNimProvider(
        provider_config,
        nim_settings=NimSettings(chat_template="custom_template"),
    )
    req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1")

    # One successful chunk for the retried (second) stream.
    chunk = MagicMock()
    delta = MagicMock(content="OK", reasoning_content="")
    chunk.choices = [MagicMock(delta=delta, finish_reason="stop")]
    chunk.usage = MagicMock(completion_tokens=2)

    async def successful_stream():
        yield chunk

    template_error = _make_bad_request_error(
        "chat_template is not supported for Mistral tokenizers."
    )

    with patch.object(
        provider._client.chat.completions, "create", new_callable=AsyncMock
    ) as mock_create:
        # First attempt raises the Mistral chat_template error; retry succeeds.
        mock_create.side_effect = [template_error, successful_stream()]
        collected = [event async for event in provider.stream_response(req)]

    assert mock_create.await_count == 2
    extra_first = mock_create.call_args_list[0].kwargs["extra_body"]
    extra_retry = mock_create.call_args_list[1].kwargs["extra_body"]

    expected_kwargs = {
        "thinking": True,
        "enable_thinking": True,
        "reasoning_budget": 100,
    }
    # Initial attempt carries both the template and its kwargs.
    assert extra_first["chat_template"] == "custom_template"
    assert extra_first["chat_template_kwargs"] == expected_kwargs
    assert "reasoning_budget" not in extra_first
    # The retry drops chat_template but keeps chat_template_kwargs intact.
    assert "chat_template" not in extra_retry
    assert extra_retry["chat_template_kwargs"] == expected_kwargs
    assert "reasoning_budget" not in extra_retry

    streamed = "".join(collected)
    assert "event: error" not in streamed
    assert "OK" in streamed
@pytest.mark.asyncio
async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config):
    """Bad requests unrelated to chat_template surface as errors with no retry."""
    from config.nim import NimSettings

    provider = NvidiaNimProvider(
        provider_config,
        nim_settings=NimSettings(chat_template="custom_template"),
    )
    req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1")

    with patch.object(
        provider._client.chat.completions, "create", new_callable=AsyncMock
    ) as mock_create:
        mock_create.side_effect = _make_bad_request_error("unrelated bad request")
        collected = [event async for event in provider.stream_response(req)]

    # No second attempt is made for unrelated failures.
    assert mock_create.await_count == 1
    streamed = "".join(collected)
    assert "unrelated bad request" in streamed
    assert "event: message_stop" in streamed
@pytest.mark.asyncio

View file

@@ -6,7 +6,11 @@ import pytest
from config.nim import NimSettings
from providers.common.utils import set_if_not_none
from providers.nvidia_nim.request import _set_extra, build_request_body
from providers.nvidia_nim.request import (
_set_extra,
build_request_body,
clone_body_without_chat_template,
)
@pytest.fixture
@@ -105,6 +109,32 @@ class TestBuildRequestBody:
}
assert "reasoning_budget" not in extra
def test_clone_body_without_chat_template(self):
    """Cloning drops chat_template, keeps sibling keys, and never mutates the input."""
    original = {
        "model": "test",
        "extra_body": {
            "chat_template": "custom_template",
            "chat_template_kwargs": {
                "thinking": True,
                "enable_thinking": True,
                "reasoning_budget": 100,
            },
            "ignore_eos": False,
        },
    }

    cloned = clone_body_without_chat_template(original)

    assert cloned is not None
    extra = cloned["extra_body"]
    # Only chat_template is removed; every sibling key survives unchanged.
    assert "chat_template" not in extra
    assert extra["chat_template_kwargs"] == {
        "thinking": True,
        "enable_thinking": True,
        "reasoning_budget": 100,
    }
    assert extra["ignore_eos"] is False
    # The source body must be left untouched.
    assert original["extra_body"]["chat_template"] == "custom_template"
def test_no_chat_template_kwargs_when_thinking_disabled(self):
req = MagicMock()
req.model = "test"
@@ -148,6 +178,30 @@ class TestBuildRequestBody:
"reasoning_budget": body["max_tokens"],
}
def test_chat_template_fields_present_for_mistral_model(self):
    """For a Mistral-family model, extra_body carries chat_template AND its kwargs."""
    req = MagicMock()
    req.model = "mistralai/mixtral-8x7b-instruct-v0.1"
    req.messages = [MagicMock(role="user", content="hi")]
    req.max_tokens = 100
    # Leave every optional request field unset.
    for attr in (
        "system",
        "temperature",
        "top_p",
        "stop_sequences",
        "tools",
        "tool_choice",
        "extra_body",
        "top_k",
    ):
        setattr(req, attr, None)

    body = build_request_body(
        req, NimSettings(chat_template="custom_template"), thinking_enabled=True
    )

    extra = body.get("extra_body", {})
    assert extra["chat_template"] == "custom_template"
    assert extra["chat_template_kwargs"] == {
        "thinking": True,
        "enable_thinking": True,
        "reasoning_budget": body["max_tokens"],
    }
def test_no_reasoning_params_in_extra_body(self):
req = MagicMock()
req.model = "test"