bug: nvidia didn't not support reasoning_budget parameter (#126)

<img width="2538" height="411" alt="image"
src="https://github.com/user-attachments/assets/8fc07f00-8869-4548-b40a-a36a15e4e043"
/>

Fixes #127.

---------

Co-authored-by: u011436427 <u011436427@noreply.gitcode.com>
Co-authored-by: Alishahryar1 <alishahryar2@gmail.com>
This commit is contained in:
Wang Ji 2026-04-23 08:06:46 +08:00 committed by GitHub
parent 2fe15bd2cd
commit 4afca05318
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 170 additions and 12 deletions

View file

@ -1,12 +1,16 @@
"""NVIDIA NIM provider implementation."""
import json
from typing import Any
import openai
from loguru import logger
from config.nim import NimSettings
from providers.base import ProviderConfig
from providers.openai_compat import OpenAICompatibleProvider
from .request import build_request_body
from .request import build_request_body, clone_body_without_reasoning_budget
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
@ -30,3 +34,23 @@ class NvidiaNimProvider(OpenAICompatibleProvider):
self._nim_settings,
thinking_enabled=self._is_thinking_enabled(request),
)
def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
"""Retry once without reasoning_budget when NIM rejects that field."""
status_code = getattr(error, "status_code", None)
if not isinstance(error, openai.BadRequestError) and status_code != 400:
return None
error_text = str(error)
error_body = getattr(error, "body", None)
if error_body is not None:
error_text = f"{error_text} {json.dumps(error_body, default=str)}"
if "reasoning_budget" not in error_text.lower():
return None
retry_body = clone_body_without_reasoning_budget(body)
if retry_body is None:
return None
logger.warning("NIM_STREAM: retrying without reasoning_budget after 400 error")
return retry_body

View file

@ -1,5 +1,6 @@
"""Request builder for NVIDIA NIM provider."""
from copy import deepcopy
from typing import Any
from loguru import logger
@ -21,6 +22,31 @@ def _set_extra(
extra_body[key] = value
def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None:
"""Clone a request body and strip only reasoning_budget fields."""
cloned_body = deepcopy(body)
extra_body = cloned_body.get("extra_body")
if not isinstance(extra_body, dict):
return None
removed = extra_body.pop("reasoning_budget", None) is not None
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if (
isinstance(chat_template_kwargs, dict)
and chat_template_kwargs.pop("reasoning_budget", None) is not None
):
removed = True
if not extra_body:
cloned_body.pop("extra_body", None)
if not removed:
return None
return cloned_body
def build_request_body(
request_data: Any, nim: NimSettings, *, thinking_enabled: bool
) -> dict:
@ -69,10 +95,11 @@ def build_request_body(
extra_body.update(request_extra)
if thinking_enabled:
extra_body.setdefault(
chat_template_kwargs = extra_body.setdefault(
"chat_template_kwargs", {"thinking": True, "enable_thinking": True}
)
_set_extra(extra_body, "reasoning_budget", max_tokens)
if isinstance(chat_template_kwargs, dict):
chat_template_kwargs.setdefault("reasoning_budget", max_tokens)
req_top_k = getattr(request_data, "top_k", None)
top_k = req_top_k if req_top_k is not None else nim.top_k

View file

@ -84,6 +84,27 @@ class OpenAICompatibleProvider(BaseProvider):
"""Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details)."""
return iter(())
def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
"""Return a modified request body for one retry, or None."""
return None
async def _create_stream(self, body: dict) -> tuple[Any, dict]:
"""Create a streaming chat completion, optionally retrying once."""
try:
stream = await self._global_rate_limiter.execute_with_retry(
self._client.chat.completions.create, **body, stream=True
)
return stream, body
except Exception as error:
retry_body = self._get_retry_request_body(error, body)
if retry_body is None:
raise
stream = await self._global_rate_limiter.execute_with_retry(
self._client.chat.completions.create, **retry_body, stream=True
)
return stream, retry_body
def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
"""Process a single tool call delta and yield SSE events."""
tc_index = tc.get("index", 0)
@ -174,9 +195,7 @@ class OpenAICompatibleProvider(BaseProvider):
async with self._global_rate_limiter.concurrency_slot():
try:
stream = await self._global_rate_limiter.execute_with_retry(
self._client.chat.completions.create, **body, stream=True
)
stream, body = await self._create_stream(body)
async for chunk in stream:
if getattr(chunk, "usage", None):
usage_info = chunk.usage

View file

@ -1,7 +1,9 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import openai
import pytest
from httpx import Request, Response
from providers.nvidia_nim import NvidiaNimProvider
@ -105,7 +107,8 @@ async def test_build_request_body(provider_config):
ctk = body["extra_body"]["chat_template_kwargs"]
assert ctk["thinking"] is True
assert ctk["enable_thinking"] is True
assert body["extra_body"]["reasoning_budget"] == body["max_tokens"]
assert ctk["reasoning_budget"] == body["max_tokens"]
assert "reasoning_budget" not in body["extra_body"]
@pytest.mark.asyncio
@ -265,6 +268,12 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
assert "Answer" in event_text
def _make_bad_request_error(message: str) -> openai.BadRequestError:
response = Response(status_code=400, request=Request("POST", "http://test"))
body = {"error": {"message": message}}
return openai.BadRequestError(message, response=response, body=body)
@pytest.mark.asyncio
async def test_tool_call_stream(nim_provider):
"""Test streaming tool calls."""
@ -301,3 +310,61 @@ async def test_tool_call_stream(nim_provider):
]
assert len(starts) == 1
assert "search" in starts[0]
@pytest.mark.asyncio
async def test_stream_response_retries_without_reasoning_budget(nim_provider):
req = MockRequest()
mock_chunk = MagicMock()
mock_chunk.choices = [
MagicMock(
delta=MagicMock(content="Recovered", reasoning_content=""),
finish_reason="stop",
)
]
mock_chunk.usage = MagicMock(completion_tokens=5)
async def mock_stream():
yield mock_chunk
error = _make_bad_request_error("Unsupported field: reasoning_budget")
with patch.object(
nim_provider._client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.side_effect = [error, mock_stream()]
events = [e async for e in nim_provider.stream_response(req)]
assert mock_create.await_count == 2
first_call = mock_create.await_args_list[0].kwargs
second_call = mock_create.await_args_list[1].kwargs
assert (
first_call["extra_body"]["chat_template_kwargs"]["reasoning_budget"]
== first_call["max_tokens"]
)
assert "reasoning_budget" not in second_call["extra_body"]
assert "reasoning_budget" not in second_call["extra_body"]["chat_template_kwargs"]
assert second_call["extra_body"]["chat_template_kwargs"]["enable_thinking"] is True
assert any("Recovered" in event for event in events)
assert any("message_stop" in event for event in events)
@pytest.mark.asyncio
async def test_stream_response_bad_request_without_reasoning_budget_does_not_retry(
nim_provider,
):
req = MockRequest()
error = _make_bad_request_error("Unsupported field: top_k")
with patch.object(
nim_provider._client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.side_effect = error
events = [e async for e in nim_provider.stream_response(req)]
assert mock_create.await_count == 1
assert any("Unsupported field: top_k" in event for event in events)
assert any("message_stop" in event for event in events)

View file

@ -6,10 +6,7 @@ import pytest
from config.nim import NimSettings
from providers.common.utils import set_if_not_none
from providers.nvidia_nim.request import (
_set_extra,
build_request_body,
)
from providers.nvidia_nim.request import _set_extra, build_request_body
@pytest.fixture
@ -104,8 +101,9 @@ class TestBuildRequestBody:
assert extra["chat_template_kwargs"] == {
"thinking": True,
"enable_thinking": True,
"reasoning_budget": body["max_tokens"],
}
assert extra["reasoning_budget"] == body["max_tokens"]
assert "reasoning_budget" not in extra
def test_no_chat_template_kwargs_when_thinking_disabled(self):
req = MagicMock()
@ -127,6 +125,29 @@ class TestBuildRequestBody:
assert "chat_template_kwargs" not in extra
assert "reasoning_budget" not in extra
def test_reasoning_budget_respects_existing_chat_template_kwargs(self):
req = MagicMock()
req.model = "test"
req.messages = [MagicMock(role="user", content="hi")]
req.max_tokens = 100
req.system = None
req.temperature = None
req.top_p = None
req.stop_sequences = None
req.tools = None
req.tool_choice = None
req.top_k = None
req.extra_body = {
"chat_template_kwargs": {"enable_thinking": False, "custom": "value"}
}
body = build_request_body(req, NimSettings(), thinking_enabled=True)
assert body["extra_body"]["chat_template_kwargs"] == {
"enable_thinking": False,
"custom": "value",
"reasoning_budget": body["max_tokens"],
}
def test_no_reasoning_params_in_extra_body(self):
req = MagicMock()
req.model = "test"