Fix: Exclude chat_template for Mistral tokenizers in NVIDIA NIM (#130) (#131)

Fixes #130. This PR updates the NVIDIA NIM provider to omit
\chat_template_kwargs\ and \chat_template\ when using a Mistral
tokenizer model. This resolves the 400 Bad Request error returned by the
API.

Co-authored-by: Alishahryar1 <alishahryar2@gmail.com>
This commit is contained in:
Anuj Nitin Bharambe 2026-04-23 05:46:45 +05:30 committed by GitHub
parent 4afca05318
commit 4fdf7e8b7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 183 additions and 14 deletions

View file

@ -10,7 +10,11 @@ from config.nim import NimSettings
from providers.base import ProviderConfig
from providers.openai_compat import OpenAICompatibleProvider
from .request import build_request_body, clone_body_without_reasoning_budget
from .request import (
build_request_body,
clone_body_without_chat_template,
clone_body_without_reasoning_budget,
)
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
@ -36,7 +40,7 @@ class NvidiaNimProvider(OpenAICompatibleProvider):
)
def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
"""Retry once without reasoning_budget when NIM rejects that field."""
"""Retry once with a downgraded body when NIM rejects a known field."""
status_code = getattr(error, "status_code", None)
if not isinstance(error, openai.BadRequestError) and status_code != 400:
return None
@ -45,12 +49,22 @@ class NvidiaNimProvider(OpenAICompatibleProvider):
error_body = getattr(error, "body", None)
if error_body is not None:
error_text = f"{error_text} {json.dumps(error_body, default=str)}"
if "reasoning_budget" not in error_text.lower():
return None
error_text = error_text.lower()
retry_body = clone_body_without_reasoning_budget(body)
if retry_body is None:
return None
if "reasoning_budget" in error_text:
retry_body = clone_body_without_reasoning_budget(body)
if retry_body is None:
return None
logger.warning(
"NIM_STREAM: retrying without reasoning_budget after 400 error"
)
return retry_body
logger.warning("NIM_STREAM: retrying without reasoning_budget after 400 error")
return retry_body
if "chat_template" in error_text:
retry_body = clone_body_without_chat_template(body)
if retry_body is None:
return None
logger.warning("NIM_STREAM: retrying without chat_template after 400 error")
return retry_body
return None

View file

@ -47,6 +47,22 @@ def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any]
return cloned_body
def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None:
"""Clone a request body and strip only chat_template."""
cloned_body = deepcopy(body)
extra_body = cloned_body.get("extra_body")
if not isinstance(extra_body, dict):
return None
if extra_body.pop("chat_template", None) is None:
return None
if not extra_body:
cloned_body.pop("extra_body", None)
return cloned_body
def build_request_body(
request_data: Any, nim: NimSettings, *, thinking_enabled: bool
) -> dict:

View file

@ -39,6 +39,15 @@ class MockRequest:
setattr(self, k, v)
def _make_bad_request_error(message: str) -> openai.BadRequestError:
response = Response(
status_code=400,
request=Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions"),
)
body = {"error": {"message": message, "type": "BadRequestError", "code": 400}}
return openai.BadRequestError(message, response=response, body=body)
@pytest.fixture(autouse=True)
def mock_rate_limiter():
"""Mock the global rate limiter to prevent waiting."""
@ -268,10 +277,86 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
assert "Answer" in event_text
def _make_bad_request_error(message: str) -> openai.BadRequestError:
response = Response(status_code=400, request=Request("POST", "http://test"))
body = {"error": {"message": message}}
return openai.BadRequestError(message, response=response, body=body)
@pytest.mark.asyncio
async def test_stream_response_retries_without_chat_template(provider_config):
from config.nim import NimSettings
provider = NvidiaNimProvider(
provider_config,
nim_settings=NimSettings(chat_template="custom_template"),
)
req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1")
mock_chunk = MagicMock()
mock_chunk.choices = [
MagicMock(
delta=MagicMock(content="OK", reasoning_content=""),
finish_reason="stop",
)
]
mock_chunk.usage = MagicMock(completion_tokens=2)
async def mock_stream():
yield mock_chunk
first_error = _make_bad_request_error(
"chat_template is not supported for Mistral tokenizers."
)
with patch.object(
provider._client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.side_effect = [first_error, mock_stream()]
events = [e async for e in provider.stream_response(req)]
assert mock_create.await_count == 2
first_extra = mock_create.call_args_list[0].kwargs["extra_body"]
second_extra = mock_create.call_args_list[1].kwargs["extra_body"]
assert first_extra["chat_template"] == "custom_template"
assert first_extra["chat_template_kwargs"] == {
"thinking": True,
"enable_thinking": True,
"reasoning_budget": 100,
}
assert "reasoning_budget" not in first_extra
assert "chat_template" not in second_extra
assert second_extra["chat_template_kwargs"] == {
"thinking": True,
"enable_thinking": True,
"reasoning_budget": 100,
}
assert "reasoning_budget" not in second_extra
event_text = "".join(events)
assert "event: error" not in event_text
assert "OK" in event_text
@pytest.mark.asyncio
async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config):
from config.nim import NimSettings
provider = NvidiaNimProvider(
provider_config,
nim_settings=NimSettings(chat_template="custom_template"),
)
req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1")
with patch.object(
provider._client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.side_effect = _make_bad_request_error("unrelated bad request")
events = [e async for e in provider.stream_response(req)]
assert mock_create.await_count == 1
event_text = "".join(events)
assert "unrelated bad request" in event_text
assert "event: message_stop" in event_text
@pytest.mark.asyncio

View file

@ -6,7 +6,11 @@ import pytest
from config.nim import NimSettings
from providers.common.utils import set_if_not_none
from providers.nvidia_nim.request import _set_extra, build_request_body
from providers.nvidia_nim.request import (
_set_extra,
build_request_body,
clone_body_without_chat_template,
)
@pytest.fixture
@ -105,6 +109,32 @@ class TestBuildRequestBody:
}
assert "reasoning_budget" not in extra
def test_clone_body_without_chat_template(self):
body = {
"model": "test",
"extra_body": {
"chat_template": "custom_template",
"chat_template_kwargs": {
"thinking": True,
"enable_thinking": True,
"reasoning_budget": 100,
},
"ignore_eos": False,
},
}
cloned = clone_body_without_chat_template(body)
assert cloned is not None
assert "chat_template" not in cloned["extra_body"]
assert cloned["extra_body"]["chat_template_kwargs"] == {
"thinking": True,
"enable_thinking": True,
"reasoning_budget": 100,
}
assert cloned["extra_body"]["ignore_eos"] is False
assert body["extra_body"]["chat_template"] == "custom_template"
def test_no_chat_template_kwargs_when_thinking_disabled(self):
req = MagicMock()
req.model = "test"
@ -148,6 +178,30 @@ class TestBuildRequestBody:
"reasoning_budget": body["max_tokens"],
}
def test_chat_template_fields_present_for_mistral_model(self):
req = MagicMock()
req.model = "mistralai/mixtral-8x7b-instruct-v0.1"
req.messages = [MagicMock(role="user", content="hi")]
req.max_tokens = 100
req.system = None
req.temperature = None
req.top_p = None
req.stop_sequences = None
req.tools = None
req.tool_choice = None
req.extra_body = None
req.top_k = None
nim = NimSettings(chat_template="custom_template")
body = build_request_body(req, nim, thinking_enabled=True)
extra = body.get("extra_body", {})
assert extra["chat_template_kwargs"] == {
"thinking": True,
"enable_thinking": True,
"reasoning_budget": body["max_tokens"],
}
assert extra["chat_template"] == "custom_template"
def test_no_reasoning_params_in_extra_body(self):
req = MagicMock()
req.model = "test"