diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index 6adba75..78acc7a 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -10,7 +10,11 @@ from config.nim import NimSettings from providers.base import ProviderConfig from providers.openai_compat import OpenAICompatibleProvider -from .request import build_request_body, clone_body_without_reasoning_budget +from .request import ( + build_request_body, + clone_body_without_chat_template, + clone_body_without_reasoning_budget, +) NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1" @@ -36,7 +40,7 @@ class NvidiaNimProvider(OpenAICompatibleProvider): ) def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None: - """Retry once without reasoning_budget when NIM rejects that field.""" + """Retry once with a downgraded body when NIM rejects a known field.""" status_code = getattr(error, "status_code", None) if not isinstance(error, openai.BadRequestError) and status_code != 400: return None @@ -45,12 +49,22 @@ class NvidiaNimProvider(OpenAICompatibleProvider): error_body = getattr(error, "body", None) if error_body is not None: error_text = f"{error_text} {json.dumps(error_body, default=str)}" - if "reasoning_budget" not in error_text.lower(): - return None + error_text = error_text.lower() - retry_body = clone_body_without_reasoning_budget(body) - if retry_body is None: - return None + if "reasoning_budget" in error_text: + retry_body = clone_body_without_reasoning_budget(body) + if retry_body is None: + return None + logger.warning( + "NIM_STREAM: retrying without reasoning_budget after 400 error" + ) + return retry_body - logger.warning("NIM_STREAM: retrying without reasoning_budget after 400 error") - return retry_body + if "chat_template" in error_text: + retry_body = clone_body_without_chat_template(body) + if retry_body is None: + return None + logger.warning("NIM_STREAM: retrying without chat_template after 400 error") + return retry_body + + return None diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index 41d480e..0c2f26d 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -47,6 +47,22 @@ def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] return cloned_body +def clone_body_without_chat_template(body: dict[str, Any]) -> dict[str, Any] | None: + """Clone a request body and strip only chat_template.""" + cloned_body = deepcopy(body) + extra_body = cloned_body.get("extra_body") + if not isinstance(extra_body, dict): + return None + + if extra_body.pop("chat_template", None) is None: + return None + + if not extra_body: + cloned_body.pop("extra_body", None) + + return cloned_body + + def build_request_body( request_data: Any, nim: NimSettings, *, thinking_enabled: bool ) -> dict: diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py index d8875ab..41d1e38 100644 --- a/tests/providers/test_nvidia_nim.py +++ b/tests/providers/test_nvidia_nim.py @@ -39,6 +39,15 @@ class MockRequest: setattr(self, k, v) +def _make_bad_request_error(message: str) -> openai.BadRequestError: + response = Response( + status_code=400, + request=Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions"), + ) + body = {"error": {"message": message, "type": "BadRequestError", "code": 400}} + return openai.BadRequestError(message, response=response, body=body) + + @pytest.fixture(autouse=True) def mock_rate_limiter(): """Mock the global rate limiter to prevent waiting.""" @@ -268,10 +277,86 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config assert "Answer" in event_text -def _make_bad_request_error(message: str) -> openai.BadRequestError: - response = Response(status_code=400, request=Request("POST", "http://test")) - body = {"error": {"message": message}} - return openai.BadRequestError(message, response=response, body=body) +@pytest.mark.asyncio +async def test_stream_response_retries_without_chat_template(provider_config): + from config.nim import NimSettings + + provider = NvidiaNimProvider( + provider_config, + nim_settings=NimSettings(chat_template="custom_template"), + ) + req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1") + + mock_chunk = MagicMock() + mock_chunk.choices = [ + MagicMock( + delta=MagicMock(content="OK", reasoning_content=""), + finish_reason="stop", + ) + ] + mock_chunk.usage = MagicMock(completion_tokens=2) + + async def mock_stream(): + yield mock_chunk + + first_error = _make_bad_request_error( + "chat_template is not supported for Mistral tokenizers." + ) + + with patch.object( + provider._client.chat.completions, "create", new_callable=AsyncMock + ) as mock_create: + mock_create.side_effect = [first_error, mock_stream()] + + events = [e async for e in provider.stream_response(req)] + + assert mock_create.await_count == 2 + + first_extra = mock_create.call_args_list[0].kwargs["extra_body"] + second_extra = mock_create.call_args_list[1].kwargs["extra_body"] + + assert first_extra["chat_template"] == "custom_template" + assert first_extra["chat_template_kwargs"] == { + "thinking": True, + "enable_thinking": True, + "reasoning_budget": 100, + } + assert "reasoning_budget" not in first_extra + + assert "chat_template" not in second_extra + assert second_extra["chat_template_kwargs"] == { + "thinking": True, + "enable_thinking": True, + "reasoning_budget": 100, + } + assert "reasoning_budget" not in second_extra + + event_text = "".join(events) + assert "event: error" not in event_text + assert "OK" in event_text + + +@pytest.mark.asyncio +async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config): + from config.nim import NimSettings + + provider = NvidiaNimProvider( + provider_config, + nim_settings=NimSettings(chat_template="custom_template"), + ) + req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1") + + with patch.object( + provider._client.chat.completions, "create", new_callable=AsyncMock + ) as mock_create: + mock_create.side_effect = _make_bad_request_error("unrelated bad request") + + events = [e async for e in provider.stream_response(req)] + + assert mock_create.await_count == 1 + event_text = "".join(events) + assert "unrelated bad request" in event_text + assert "event: message_stop" in event_text @pytest.mark.asyncio diff --git a/tests/providers/test_nvidia_nim_request.py b/tests/providers/test_nvidia_nim_request.py index 2dd040f..e4b9953 100644 --- a/tests/providers/test_nvidia_nim_request.py +++ b/tests/providers/test_nvidia_nim_request.py @@ -6,7 +6,11 @@ import pytest from config.nim import NimSettings from providers.common.utils import set_if_not_none -from providers.nvidia_nim.request import _set_extra, build_request_body +from providers.nvidia_nim.request import ( + _set_extra, + build_request_body, + clone_body_without_chat_template, +) @pytest.fixture @@ -105,6 +109,32 @@ class TestBuildRequestBody: } assert "reasoning_budget" not in extra + def test_clone_body_without_chat_template(self): + body = { + "model": "test", + "extra_body": { + "chat_template": "custom_template", + "chat_template_kwargs": { + "thinking": True, + "enable_thinking": True, + "reasoning_budget": 100, + }, + "ignore_eos": False, + }, + } + + cloned = clone_body_without_chat_template(body) + + assert cloned is not None + assert "chat_template" not in cloned["extra_body"] + assert cloned["extra_body"]["chat_template_kwargs"] == { + "thinking": True, + "enable_thinking": True, + "reasoning_budget": 100, + } + assert cloned["extra_body"]["ignore_eos"] is False + assert body["extra_body"]["chat_template"] == "custom_template" + def test_no_chat_template_kwargs_when_thinking_disabled(self): req = MagicMock() req.model = "test" @@ -148,6 +178,30 @@ class TestBuildRequestBody: "reasoning_budget": body["max_tokens"], } + def test_chat_template_fields_present_for_mistral_model(self): + req = MagicMock() + req.model = "mistralai/mixtral-8x7b-instruct-v0.1" + req.messages = [MagicMock(role="user", content="hi")] + req.max_tokens = 100 + req.system = None + req.temperature = None + req.top_p = None + req.stop_sequences = None + req.tools = None + req.tool_choice = None + req.extra_body = None + req.top_k = None + + nim = NimSettings(chat_template="custom_template") + body = build_request_body(req, nim, thinking_enabled=True) + extra = body.get("extra_body", {}) + assert extra["chat_template_kwargs"] == { + "thinking": True, + "enable_thinking": True, + "reasoning_budget": body["max_tokens"], + } + assert extra["chat_template"] == "custom_template" + def test_no_reasoning_params_in_extra_body(self): req = MagicMock() req.model = "test"