mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
bug: nvidia does not support reasoning_budget parameter (#126)
<img width="2538" height="411" alt="image" src="https://github.com/user-attachments/assets/8fc07f00-8869-4548-b40a-a36a15e4e043" /> Fixes #127. --------- Co-authored-by: u011436427 <u011436427@noreply.gitcode.com> Co-authored-by: Alishahryar1 <alishahryar2@gmail.com>
This commit is contained in:
parent
2fe15bd2cd
commit
4afca05318
5 changed files with 170 additions and 12 deletions
|
|
@ -1,12 +1,16 @@
|
|||
"""NVIDIA NIM provider implementation."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from loguru import logger
|
||||
|
||||
from config.nim import NimSettings
|
||||
from providers.base import ProviderConfig
|
||||
from providers.openai_compat import OpenAICompatibleProvider
|
||||
|
||||
from .request import build_request_body
|
||||
from .request import build_request_body, clone_body_without_reasoning_budget
|
||||
|
||||
NVIDIA_NIM_BASE_URL = "https://integrate.api.nvidia.com/v1"
|
||||
|
||||
|
|
@ -30,3 +34,23 @@ class NvidiaNimProvider(OpenAICompatibleProvider):
|
|||
self._nim_settings,
|
||||
thinking_enabled=self._is_thinking_enabled(request),
|
||||
)
|
||||
|
||||
def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
    """Build the body for a single retry when NIM rejects ``reasoning_budget``.

    Args:
        error: The exception raised by the first request attempt.
        body: The request body that was sent on that attempt.

    Returns:
        A copy of ``body`` without ``reasoning_budget``, or ``None`` when the
        error is not a 400 / ``openai.BadRequestError``, does not mention
        ``reasoning_budget``, or the body had nothing to strip.
    """
    reported_status = getattr(error, "status_code", None)
    if not (isinstance(error, openai.BadRequestError) or reported_status == 400):
        return None

    # Search both the exception message and its structured body (if any):
    # the offending field name may appear in either.
    message_parts = [str(error)]
    payload = getattr(error, "body", None)
    if payload is not None:
        message_parts.append(json.dumps(payload, default=str))
    searchable = " ".join(message_parts).lower()
    if "reasoning_budget" not in searchable:
        return None

    stripped_body = clone_body_without_reasoning_budget(body)
    if stripped_body is not None:
        logger.warning("NIM_STREAM: retrying without reasoning_budget after 400 error")
    return stripped_body
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Request builder for NVIDIA NIM provider."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -21,6 +22,31 @@ def _set_extra(
|
|||
extra_body[key] = value
|
||||
|
||||
|
||||
def clone_body_without_reasoning_budget(body: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Clone a request body and strip only reasoning_budget fields."""
|
||||
cloned_body = deepcopy(body)
|
||||
extra_body = cloned_body.get("extra_body")
|
||||
if not isinstance(extra_body, dict):
|
||||
return None
|
||||
|
||||
removed = extra_body.pop("reasoning_budget", None) is not None
|
||||
|
||||
chat_template_kwargs = extra_body.get("chat_template_kwargs")
|
||||
if (
|
||||
isinstance(chat_template_kwargs, dict)
|
||||
and chat_template_kwargs.pop("reasoning_budget", None) is not None
|
||||
):
|
||||
removed = True
|
||||
|
||||
if not extra_body:
|
||||
cloned_body.pop("extra_body", None)
|
||||
|
||||
if not removed:
|
||||
return None
|
||||
|
||||
return cloned_body
|
||||
|
||||
|
||||
def build_request_body(
|
||||
request_data: Any, nim: NimSettings, *, thinking_enabled: bool
|
||||
) -> dict:
|
||||
|
|
@ -69,10 +95,11 @@ def build_request_body(
|
|||
extra_body.update(request_extra)
|
||||
|
||||
if thinking_enabled:
|
||||
extra_body.setdefault(
|
||||
chat_template_kwargs = extra_body.setdefault(
|
||||
"chat_template_kwargs", {"thinking": True, "enable_thinking": True}
|
||||
)
|
||||
_set_extra(extra_body, "reasoning_budget", max_tokens)
|
||||
if isinstance(chat_template_kwargs, dict):
|
||||
chat_template_kwargs.setdefault("reasoning_budget", max_tokens)
|
||||
|
||||
req_top_k = getattr(request_data, "top_k", None)
|
||||
top_k = req_top_k if req_top_k is not None else nim.top_k
|
||||
|
|
|
|||
|
|
@ -84,6 +84,27 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
"""Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details)."""
|
||||
return iter(())
|
||||
|
||||
def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
|
||||
"""Return a modified request body for one retry, or None."""
|
||||
return None
|
||||
|
||||
async def _create_stream(self, body: dict) -> tuple[Any, dict]:
|
||||
"""Create a streaming chat completion, optionally retrying once."""
|
||||
try:
|
||||
stream = await self._global_rate_limiter.execute_with_retry(
|
||||
self._client.chat.completions.create, **body, stream=True
|
||||
)
|
||||
return stream, body
|
||||
except Exception as error:
|
||||
retry_body = self._get_retry_request_body(error, body)
|
||||
if retry_body is None:
|
||||
raise
|
||||
|
||||
stream = await self._global_rate_limiter.execute_with_retry(
|
||||
self._client.chat.completions.create, **retry_body, stream=True
|
||||
)
|
||||
return stream, retry_body
|
||||
|
||||
def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
|
||||
"""Process a single tool call delta and yield SSE events."""
|
||||
tc_index = tc.get("index", 0)
|
||||
|
|
@ -174,9 +195,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
|
||||
async with self._global_rate_limiter.concurrency_slot():
|
||||
try:
|
||||
stream = await self._global_rate_limiter.execute_with_retry(
|
||||
self._client.chat.completions.create, **body, stream=True
|
||||
)
|
||||
stream, body = await self._create_stream(body)
|
||||
async for chunk in stream:
|
||||
if getattr(chunk, "usage", None):
|
||||
usage_info = chunk.usage
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from httpx import Request, Response
|
||||
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
|
|
@ -105,7 +107,8 @@ async def test_build_request_body(provider_config):
|
|||
ctk = body["extra_body"]["chat_template_kwargs"]
|
||||
assert ctk["thinking"] is True
|
||||
assert ctk["enable_thinking"] is True
|
||||
assert body["extra_body"]["reasoning_budget"] == body["max_tokens"]
|
||||
assert ctk["reasoning_budget"] == body["max_tokens"]
|
||||
assert "reasoning_budget" not in body["extra_body"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -265,6 +268,12 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
|
|||
assert "Answer" in event_text
|
||||
|
||||
|
||||
def _make_bad_request_error(message: str) -> openai.BadRequestError:
    """Build an ``openai.BadRequestError`` for tests.

    The error wraps a 400 response to a dummy POST request, and ``message``
    is used both as the exception message and inside the error body.
    """
    dummy_request = Request("POST", "http://test")
    return openai.BadRequestError(
        message,
        response=Response(status_code=400, request=dummy_request),
        body={"error": {"message": message}},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_stream(nim_provider):
|
||||
"""Test streaming tool calls."""
|
||||
|
|
@ -301,3 +310,61 @@ async def test_tool_call_stream(nim_provider):
|
|||
]
|
||||
assert len(starts) == 1
|
||||
assert "search" in starts[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_retries_without_reasoning_budget(nim_provider):
    """A 400 that names reasoning_budget leads to one retry without that field."""
    req = MockRequest()

    recovered_chunk = MagicMock()
    recovered_chunk.choices = [
        MagicMock(
            delta=MagicMock(content="Recovered", reasoning_content=""),
            finish_reason="stop",
        )
    ]
    recovered_chunk.usage = MagicMock(completion_tokens=5)

    async def recovered_stream():
        yield recovered_chunk

    rejection = _make_bad_request_error("Unsupported field: reasoning_budget")

    completions = nim_provider._client.chat.completions
    with patch.object(completions, "create", new_callable=AsyncMock) as create_mock:
        # First await raises the 400, second returns the recovered stream.
        create_mock.side_effect = [rejection, recovered_stream()]
        events = [event async for event in nim_provider.stream_response(req)]

    assert create_mock.await_count == 2
    rejected_kwargs = create_mock.await_args_list[0].kwargs
    retried_kwargs = create_mock.await_args_list[1].kwargs

    # The first attempt carried the budget inside chat_template_kwargs.
    rejected_template = rejected_kwargs["extra_body"]["chat_template_kwargs"]
    assert rejected_template["reasoning_budget"] == rejected_kwargs["max_tokens"]

    # The retry dropped the budget but kept the thinking flag.
    retried_extra = retried_kwargs["extra_body"]
    assert "reasoning_budget" not in retried_extra
    assert "reasoning_budget" not in retried_extra["chat_template_kwargs"]
    assert retried_extra["chat_template_kwargs"]["enable_thinking"] is True

    assert any("Recovered" in event for event in events)
    assert any("message_stop" in event for event in events)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_bad_request_without_reasoning_budget_does_not_retry(
    nim_provider,
):
    """A 400 that does not mention reasoning_budget is reported without a retry."""
    req = MockRequest()
    unrelated_rejection = _make_bad_request_error("Unsupported field: top_k")

    completions = nim_provider._client.chat.completions
    with patch.object(completions, "create", new_callable=AsyncMock) as create_mock:
        create_mock.side_effect = unrelated_rejection
        events = [event async for event in nim_provider.stream_response(req)]

    assert create_mock.await_count == 1
    # The error text reaches the client and the stream is still terminated.
    assert any("Unsupported field: top_k" in event for event in events)
    assert any("message_stop" in event for event in events)
|
||||
|
|
|
|||
|
|
@ -6,10 +6,7 @@ import pytest
|
|||
|
||||
from config.nim import NimSettings
|
||||
from providers.common.utils import set_if_not_none
|
||||
from providers.nvidia_nim.request import (
|
||||
_set_extra,
|
||||
build_request_body,
|
||||
)
|
||||
from providers.nvidia_nim.request import _set_extra, build_request_body
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -104,8 +101,9 @@ class TestBuildRequestBody:
|
|||
assert extra["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
"enable_thinking": True,
|
||||
"reasoning_budget": body["max_tokens"],
|
||||
}
|
||||
assert extra["reasoning_budget"] == body["max_tokens"]
|
||||
assert "reasoning_budget" not in extra
|
||||
|
||||
def test_no_chat_template_kwargs_when_thinking_disabled(self):
|
||||
req = MagicMock()
|
||||
|
|
@ -127,6 +125,29 @@ class TestBuildRequestBody:
|
|||
assert "chat_template_kwargs" not in extra
|
||||
assert "reasoning_budget" not in extra
|
||||
|
||||
def test_reasoning_budget_respects_existing_chat_template_kwargs(self):
    """Caller-supplied chat_template_kwargs are kept; only reasoning_budget is added."""
    req = MagicMock(
        model="test",
        messages=[MagicMock(role="user", content="hi")],
        max_tokens=100,
        system=None,
        temperature=None,
        top_p=None,
        stop_sequences=None,
        tools=None,
        tool_choice=None,
        top_k=None,
        extra_body={
            "chat_template_kwargs": {"enable_thinking": False, "custom": "value"}
        },
    )

    body = build_request_body(req, NimSettings(), thinking_enabled=True)

    expected_template_kwargs = {
        "enable_thinking": False,
        "custom": "value",
        "reasoning_budget": body["max_tokens"],
    }
    assert body["extra_body"]["chat_template_kwargs"] == expected_template_kwargs
|
||||
|
||||
def test_no_reasoning_params_in_extra_body(self):
|
||||
req = MagicMock()
|
||||
req.model = "test"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue