From 4afca05318b4657fb5c8d2b37b1bdb3ec95ebfd0 Mon Sep 17 00:00:00 2001
From: Wang Ji <48763621+Jiwangreal@users.noreply.github.com>
Date: Thu, 23 Apr 2026 08:06:46 +0800
Subject: [PATCH] bug: nvidia didn't not support reasoning_budget parameter
(#126)
Fixes #127.
---------
Co-authored-by: u011436427
Co-authored-by: Alishahryar1
---
providers/nvidia_nim/client.py | 26 +++++++-
providers/nvidia_nim/request.py | 31 +++++++++-
providers/openai_compat.py | 25 +++++++-
tests/providers/test_nvidia_nim.py | 69 +++++++++++++++++++++-
tests/providers/test_nvidia_nim_request.py | 31 ++++++++--
5 files changed, 170 insertions(+), 12 deletions(-)
diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py
index 47d0300..6adba75 100644
--- a/providers/nvidia_nim/client.py
+++ b/providers/nvidia_nim/client.py
@@ -1,12 +1,16 @@
"""NVIDIA NIM provider implementation."""
+import json
from typing import Any
+import openai
+from loguru import logger
+
from config.nim import NimSettings
from providers.base import ProviderConfig
from providers.openai_compat import OpenAICompatibleProvider
-from .request import build_request_body
+from .request import build_request_body, clone_body_without_reasoning_budget
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
@@ -30,3 +34,23 @@ class NvidiaNimProvider(OpenAICompatibleProvider):
self._nim_settings,
thinking_enabled=self._is_thinking_enabled(request),
)
+
+ def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
+ """Retry once without reasoning_budget when NIM rejects that field."""
+ status_code = getattr(error, "status_code", None)
+ if not isinstance(error, openai.BadRequestError) and status_code != 400:
+ return None
+
+ error_text = str(error)
+ error_body = getattr(error, "body", None)
+ if error_body is not None:
+ error_text = f"{error_text} {json.dumps(error_body, default=str)}"
+ if "reasoning_budget" not in error_text.lower():
+ return None
+
+ retry_body = clone_body_without_reasoning_budget(body)
+ if retry_body is None:
+ return None
+
+ logger.warning("NIM_STREAM: retrying without reasoning_budget after 400 error")
+ return retry_body
diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py
index c82719e..41d480e 100644
--- a/providers/nvidia_nim/request.py
+++ b/providers/nvidia_nim/request.py
@@ -1,5 +1,6 @@
"""Request builder for NVIDIA NIM provider."""
+from copy import deepcopy
from typing import Any
from loguru import logger
@@ -21,6 +22,31 @@ def _set_extra(
extra_body[key] = value
+def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None:
+ """Clone a request body and strip only reasoning_budget fields."""
+ cloned_body = deepcopy(body)
+ extra_body = cloned_body.get("extra_body")
+ if not isinstance(extra_body, dict):
+ return None
+
+ removed = extra_body.pop("reasoning_budget", None) is not None
+
+ chat_template_kwargs = extra_body.get("chat_template_kwargs")
+ if (
+ isinstance(chat_template_kwargs, dict)
+ and chat_template_kwargs.pop("reasoning_budget", None) is not None
+ ):
+ removed = True
+
+ if not extra_body:
+ cloned_body.pop("extra_body", None)
+
+ if not removed:
+ return None
+
+ return cloned_body
+
+
def build_request_body(
request_data: Any, nim: NimSettings, *, thinking_enabled: bool
) -> dict:
@@ -69,10 +95,11 @@ def build_request_body(
extra_body.update(request_extra)
if thinking_enabled:
- extra_body.setdefault(
+ chat_template_kwargs = extra_body.setdefault(
"chat_template_kwargs", {"thinking": True, "enable_thinking": True}
)
- _set_extra(extra_body, "reasoning_budget", max_tokens)
+ if isinstance(chat_template_kwargs, dict):
+ chat_template_kwargs.setdefault("reasoning_budget", max_tokens)
req_top_k = getattr(request_data, "top_k", None)
top_k = req_top_k if req_top_k is not None else nim.top_k
diff --git a/providers/openai_compat.py b/providers/openai_compat.py
index 76ff8cb..979ec3e 100644
--- a/providers/openai_compat.py
+++ b/providers/openai_compat.py
@@ -84,6 +84,27 @@ class OpenAICompatibleProvider(BaseProvider):
"""Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details)."""
return iter(())
+ def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
+ """Return a modified request body for one retry, or None."""
+ return None
+
+ async def _create_stream(self, body: dict) -> tuple[Any, dict]:
+ """Create a streaming chat completion, optionally retrying once."""
+ try:
+ stream = await self._global_rate_limiter.execute_with_retry(
+ self._client.chat.completions.create, **body, stream=True
+ )
+ return stream, body
+ except Exception as error:
+ retry_body = self._get_retry_request_body(error, body)
+ if retry_body is None:
+ raise
+
+ stream = await self._global_rate_limiter.execute_with_retry(
+ self._client.chat.completions.create, **retry_body, stream=True
+ )
+ return stream, retry_body
+
def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
"""Process a single tool call delta and yield SSE events."""
tc_index = tc.get("index", 0)
@@ -174,9 +195,7 @@ class OpenAICompatibleProvider(BaseProvider):
async with self._global_rate_limiter.concurrency_slot():
try:
- stream = await self._global_rate_limiter.execute_with_retry(
- self._client.chat.completions.create, **body, stream=True
- )
+ stream, body = await self._create_stream(body)
async for chunk in stream:
if getattr(chunk, "usage", None):
usage_info = chunk.usage
diff --git a/tests/providers/test_nvidia_nim.py b/tests/providers/test_nvidia_nim.py
index 3f15095..d8875ab 100644
--- a/tests/providers/test_nvidia_nim.py
+++ b/tests/providers/test_nvidia_nim.py
@@ -1,7 +1,9 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
+import openai
import pytest
+from httpx import Request, Response
from providers.nvidia_nim import NvidiaNimProvider
@@ -105,7 +107,8 @@ async def test_build_request_body(provider_config):
ctk = body["extra_body"]["chat_template_kwargs"]
assert ctk["thinking"] is True
assert ctk["enable_thinking"] is True
- assert body["extra_body"]["reasoning_budget"] == body["max_tokens"]
+ assert ctk["reasoning_budget"] == body["max_tokens"]
+ assert "reasoning_budget" not in body["extra_body"]
@pytest.mark.asyncio
@@ -265,6 +268,12 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
assert "Answer" in event_text
+def _make_bad_request_error(message: str) -> openai.BadRequestError:
+ response = Response(status_code=400, request=Request("POST", "http://test"))
+ body = {"error": {"message": message}}
+ return openai.BadRequestError(message, response=response, body=body)
+
+
@pytest.mark.asyncio
async def test_tool_call_stream(nim_provider):
"""Test streaming tool calls."""
@@ -301,3 +310,61 @@ async def test_tool_call_stream(nim_provider):
]
assert len(starts) == 1
assert "search" in starts[0]
+
+
+@pytest.mark.asyncio
+async def test_stream_response_retries_without_reasoning_budget(nim_provider):
+ req = MockRequest()
+
+ mock_chunk = MagicMock()
+ mock_chunk.choices = [
+ MagicMock(
+ delta=MagicMock(content="Recovered", reasoning_content=""),
+ finish_reason="stop",
+ )
+ ]
+ mock_chunk.usage = MagicMock(completion_tokens=5)
+
+ async def mock_stream():
+ yield mock_chunk
+
+ error = _make_bad_request_error("Unsupported field: reasoning_budget")
+
+ with patch.object(
+ nim_provider._client.chat.completions, "create", new_callable=AsyncMock
+ ) as mock_create:
+ mock_create.side_effect = [error, mock_stream()]
+
+ events = [e async for e in nim_provider.stream_response(req)]
+
+ assert mock_create.await_count == 2
+ first_call = mock_create.await_args_list[0].kwargs
+ second_call = mock_create.await_args_list[1].kwargs
+ assert (
+ first_call["extra_body"]["chat_template_kwargs"]["reasoning_budget"]
+ == first_call["max_tokens"]
+ )
+ assert "reasoning_budget" not in second_call["extra_body"]
+ assert "reasoning_budget" not in second_call["extra_body"]["chat_template_kwargs"]
+ assert second_call["extra_body"]["chat_template_kwargs"]["enable_thinking"] is True
+ assert any("Recovered" in event for event in events)
+ assert any("message_stop" in event for event in events)
+
+
+@pytest.mark.asyncio
+async def test_stream_response_bad_request_without_reasoning_budget_does_not_retry(
+ nim_provider,
+):
+ req = MockRequest()
+ error = _make_bad_request_error("Unsupported field: top_k")
+
+ with patch.object(
+ nim_provider._client.chat.completions, "create", new_callable=AsyncMock
+ ) as mock_create:
+ mock_create.side_effect = error
+
+ events = [e async for e in nim_provider.stream_response(req)]
+
+ assert mock_create.await_count == 1
+ assert any("Unsupported field: top_k" in event for event in events)
+ assert any("message_stop" in event for event in events)
diff --git a/tests/providers/test_nvidia_nim_request.py b/tests/providers/test_nvidia_nim_request.py
index b7508d6..2dd040f 100644
--- a/tests/providers/test_nvidia_nim_request.py
+++ b/tests/providers/test_nvidia_nim_request.py
@@ -6,10 +6,7 @@ import pytest
from config.nim import NimSettings
from providers.common.utils import set_if_not_none
-from providers.nvidia_nim.request import (
- _set_extra,
- build_request_body,
-)
+from providers.nvidia_nim.request import _set_extra, build_request_body
@pytest.fixture
@@ -104,8 +101,9 @@ class TestBuildRequestBody:
assert extra["chat_template_kwargs"] == {
"thinking": True,
"enable_thinking": True,
+ "reasoning_budget": body["max_tokens"],
}
- assert extra["reasoning_budget"] == body["max_tokens"]
+ assert "reasoning_budget" not in extra
def test_no_chat_template_kwargs_when_thinking_disabled(self):
req = MagicMock()
@@ -127,6 +125,29 @@ class TestBuildRequestBody:
assert "chat_template_kwargs" not in extra
assert "reasoning_budget" not in extra
+ def test_reasoning_budget_respects_existing_chat_template_kwargs(self):
+ req = MagicMock()
+ req.model = "test"
+ req.messages = [MagicMock(role="user", content="hi")]
+ req.max_tokens = 100
+ req.system = None
+ req.temperature = None
+ req.top_p = None
+ req.stop_sequences = None
+ req.tools = None
+ req.tool_choice = None
+ req.top_k = None
+ req.extra_body = {
+ "chat_template_kwargs": {"enable_thinking": False, "custom": "value"}
+ }
+
+ body = build_request_body(req, NimSettings(), thinking_enabled=True)
+ assert body["extra_body"]["chat_template_kwargs"] == {
+ "enable_thinking": False,
+ "custom": "value",
+ "reasoning_budget": body["max_tokens"],
+ }
+
def test_no_reasoning_params_in_extra_body(self):
req = MagicMock()
req.model = "test"