From 4afca05318b4657fb5c8d2b37b1bdb3ec95ebfd0 Mon Sep 17 00:00:00 2001 From: Wang Ji <48763621+Jiwangreal@users.noreply.github.com> Date: Thu, 23 Apr 2026 08:06:46 +0800 Subject: [PATCH] bug: nvidia didn't not support reasoning_budget parameter (#126) image Fixes #127. --------- Co-authored-by: u011436427 Co-authored-by: Alishahryar1 --- providers/nvidia_nim/client.py | 26 +++++++- providers/nvidia_nim/request.py | 31 +++++++++- providers/openai_compat.py | 25 +++++++- tests/providers/test_nvidia_nim.py | 69 +++++++++++++++++++++- tests/providers/test_nvidia_nim_request.py | 31 ++++++++-- 5 files changed, 170 insertions(+), 12 deletions(-) diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index 47d0300..6adba75 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -1,12 +1,16 @@ """NVIDIA NIM provider implementation.""" +import json from typing import Any +import openai +from loguru import logger + from config.nim import NimSettings from providers.base import ProviderConfig from providers.openai_compat import OpenAICompatibleProvider -from .request import build_request_body +from .request import build_request_body, clone_body_without_reasoning_budget NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1" @@ -30,3 +34,23 @@ class NvidiaNimProvider(OpenAICompatibleProvider): self._nim_settings, thinking_enabled=self._is_thinking_enabled(request), ) + + def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None: + """Retry once without reasoning_budget when NIM rejects that field.""" + status_code = getattr(error, "status_code", None) + if not isinstance(error, openai.BadRequestError) and status_code != 400: + return None + + error_text = str(error) + error_body = getattr(error, "body", None) + if error_body is not None: + error_text = f"{error_text} {json.dumps(error_body, default=str)}" + if "reasoning_budget" not in error_text.lower(): + return None + + retry_body = clone_body_without_reasoning_budget(body) + if retry_body is None: + return None + + logger.warning("NIM_STREAM: retrying without reasoning_budget after 400 error") + return retry_body diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index c82719e..41d480e 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -1,5 +1,6 @@ """Request builder for NVIDIA NIM provider.""" +from copy import deepcopy from typing import Any from loguru import logger @@ -21,6 +22,31 @@ def _set_extra( extra_body[key] = value +def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None: + """Clone a request body and strip only reasoning_budget fields.""" + cloned_body = deepcopy(body) + extra_body = cloned_body.get("extra_body") + if not isinstance(extra_body, dict): + return None + + removed = extra_body.pop("reasoning_budget", None) is not None + + chat_template_kwargs = extra_body.get("chat_template_kwargs") + if ( + isinstance(chat_template_kwargs, dict) + and chat_template_kwargs.pop("reasoning_budget", None) is not None + ): + removed = True + + if not extra_body: + cloned_body.pop("extra_body", None) + + if not removed: + return None + + return cloned_body + + def build_request_body( request_data: Any, nim: NimSettings, *, thinking_enabled: bool ) -> dict: @@ -69,10 +95,11 @@ def build_request_body( extra_body.update(request_extra) if thinking_enabled: - extra_body.setdefault( + chat_template_kwargs = extra_body.setdefault( "chat_template_kwargs", {"thinking": True, "enable_thinking": True} ) - _set_extra(extra_body, "reasoning_budget", max_tokens) + if isinstance(chat_template_kwargs, dict): + chat_template_kwargs.setdefault("reasoning_budget", max_tokens) req_top_k = getattr(request_data, "top_k", None) top_k = req_top_k if req_top_k is not None else nim.top_k diff --git a/providers/openai_compat.py b/providers/openai_compat.py index 76ff8cb..979ec3e 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -84,6 +84,27 @@ class OpenAICompatibleProvider(BaseProvider): """Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details).""" return iter(()) + def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None: + """Return a modified request body for one retry, or None.""" + return None + + async def _create_stream(self, body: dict) -> tuple[Any, dict]: + """Create a streaming chat completion, optionally retrying once.""" + try: + stream = await self._global_rate_limiter.execute_with_retry( + self._client.chat.completions.create, **body, stream=True + ) + return stream, body + except Exception as error: + retry_body = self._get_retry_request_body(error, body) + if retry_body is None: + raise + + stream = await self._global_rate_limiter.execute_with_retry( + self._client.chat.completions.create, **retry_body, stream=True + ) + return stream, retry_body + def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]: """Process a single tool call delta and yield SSE events.""" tc_index = tc.get("index", 0) @@ -174,9 +195,7 @@ class OpenAICompatibleProvider(BaseProvider): async with self._global_rate_limiter.concurrency_slot(): try: - stream = await self._global_rate_limiter.execute_with_retry( - self._client.chat.completions.create, **body, stream=True - ) + stream, body = await self._create_stream(body) async for chunk in stream: if getattr(chunk, "usage", None): usage_info = chunk.usage diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py index 3f15095..d8875ab 100644 --- a/tests/providers/test_nvidia_nim.py +++ b/tests/providers/test_nvidia_nim.py @@ -1,7 +1,9 @@ import json from unittest.mock import AsyncMock, MagicMock, patch +import openai import pytest +from httpx import Request, Response from providers.nvidia_nim import NvidiaNimProvider @@ -105,7 +107,8 @@ async def test_build_request_body(provider_config): ctk = body["extra_body"]["chat_template_kwargs"] assert ctk["thinking"] is True assert ctk["enable_thinking"] is True - assert body["extra_body"]["reasoning_budget"] == body["max_tokens"] + assert ctk["reasoning_budget"] == body["max_tokens"] + assert "reasoning_budget" not in body["extra_body"] @pytest.mark.asyncio @@ -265,6 +268,12 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config assert "Answer" in event_text +def _make_bad_request_error(message: str) -> openai.BadRequestError: + response = Response(status_code=400, request=Request("POST", "http://test")) + body = {"error": {"message": message}} + return openai.BadRequestError(message, response=response, body=body) + + @pytest.mark.asyncio async def test_tool_call_stream(nim_provider): """Test streaming tool calls.""" @@ -301,3 +310,61 @@ async def test_tool_call_stream(nim_provider): ] assert len(starts) == 1 assert "search" in starts[0] + + +@pytest.mark.asyncio +async def test_stream_response_retries_without_reasoning_budget(nim_provider): + req = MockRequest() + + mock_chunk = MagicMock() + mock_chunk.choices = [ + MagicMock( + delta=MagicMock(content="Recovered", reasoning_content=""), + finish_reason="stop", + ) + ] + mock_chunk.usage = MagicMock(completion_tokens=5) + + async def mock_stream(): + yield mock_chunk + + error = _make_bad_request_error("Unsupported field: reasoning_budget") + + with patch.object( + nim_provider._client.chat.completions, "create", new_callable=AsyncMock + ) as mock_create: + mock_create.side_effect = [error, mock_stream()] + + events = [e async for e in nim_provider.stream_response(req)] + + assert mock_create.await_count == 2 + first_call = mock_create.await_args_list[0].kwargs + second_call = mock_create.await_args_list[1].kwargs + assert ( + first_call["extra_body"]["chat_template_kwargs"]["reasoning_budget"] + == first_call["max_tokens"] + ) + assert "reasoning_budget" not in second_call["extra_body"] + assert "reasoning_budget" not in second_call["extra_body"]["chat_template_kwargs"] + assert second_call["extra_body"]["chat_template_kwargs"]["enable_thinking"] is True + assert any("Recovered" in event for event in events) + assert any("message_stop" in event for event in events) + + +@pytest.mark.asyncio +async def test_stream_response_bad_request_without_reasoning_budget_does_not_retry( + nim_provider, +): + req = MockRequest() + error = _make_bad_request_error("Unsupported field: top_k") + + with patch.object( + nim_provider._client.chat.completions, "create", new_callable=AsyncMock + ) as mock_create: + mock_create.side_effect = error + + events = [e async for e in nim_provider.stream_response(req)] + + assert mock_create.await_count == 1 + assert any("Unsupported field: top_k" in event for event in events) + assert any("message_stop" in event for event in events) diff --git a/tests/providers/test_nvidia_nim_request.py b/tests/providers/test_nvidia_nim_request.py index b7508d6..2dd040f 100644 --- a/tests/providers/test_nvidia_nim_request.py +++ b/tests/providers/test_nvidia_nim_request.py @@ -6,10 +6,7 @@ import pytest from config.nim import NimSettings from providers.common.utils import set_if_not_none -from providers.nvidia_nim.request import ( - _set_extra, - build_request_body, -) +from providers.nvidia_nim.request import _set_extra, build_request_body @pytest.fixture @@ -104,8 +101,9 @@ class TestBuildRequestBody: assert extra["chat_template_kwargs"] == { "thinking": True, "enable_thinking": True, + "reasoning_budget": body["max_tokens"], } - assert extra["reasoning_budget"] == body["max_tokens"] + assert "reasoning_budget" not in extra def test_no_chat_template_kwargs_when_thinking_disabled(self): req = MagicMock() @@ -127,6 +125,29 @@ class TestBuildRequestBody: assert "chat_template_kwargs" not in extra assert "reasoning_budget" not in extra + def test_reasoning_budget_respects_existing_chat_template_kwargs(self): + req = MagicMock() + req.model = "test" + req.messages = [MagicMock(role="user", content="hi")] + req.max_tokens = 100 + req.system = None + req.temperature = None + req.top_p = None + req.stop_sequences = None + req.tools = None + req.tool_choice = None + req.top_k = None + req.extra_body = { + "chat_template_kwargs": {"enable_thinking": False, "custom": "value"} + } + + body = build_request_body(req, NimSettings(), thinking_enabled=True) + assert body["extra_body"]["chat_template_kwargs"] == { + "enable_thinking": False, + "custom": "value", + "reasoning_budget": body["max_tokens"], + } + def test_no_reasoning_params_in_extra_body(self): req = MagicMock() req.model = "test"