mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Fixes #130. This PR updates the NVIDIA NIM provider to omit `chat_template_kwargs` and `chat_template` when using a Mistral tokenizer model. This resolves the 400 Bad Request error returned by the API.

Co-authored-by: Alishahryar1 <alishahryar2@gmail.com>
This commit is contained in:
parent
4afca05318
commit
4fdf7e8b7e
4 changed files with 183 additions and 14 deletions
|
|
@@ -39,6 +39,15 @@ class MockRequest:
|
|||
setattr(self, k, v)
|
||||
|
||||
|
||||
def _make_bad_request_error(message: str) -> openai.BadRequestError:
|
||||
response = Response(
|
||||
status_code=400,
|
||||
request=Request("POST", "https://integrate.api.nvidia.com/v1/chat/completions"),
|
||||
)
|
||||
body = {"error": {"message": message, "type": "BadRequestError", "code": 400}}
|
||||
return openai.BadRequestError(message, response=response, body=body)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_rate_limiter():
|
||||
"""Mock the global rate limiter to prevent waiting."""
|
||||
|
|
@@ -268,10 +277,86 @@ async def test_stream_response_suppresses_thinking_when_disabled(provider_config
|
|||
assert "Answer" in event_text
|
||||
|
||||
|
||||
def _make_bad_request_error(message: str) -> openai.BadRequestError:
|
||||
response = Response(status_code=400, request=Request("POST", "http://test"))
|
||||
body = {"error": {"message": message}}
|
||||
return openai.BadRequestError(message, response=response, body=body)
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_retries_without_chat_template(provider_config):
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config,
|
||||
nim_settings=NimSettings(chat_template="custom_template"),
|
||||
)
|
||||
req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1")
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(content="OK", reasoning_content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = MagicMock(completion_tokens=2)
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
first_error = _make_bad_request_error(
|
||||
"chat_template is not supported for Mistral tokenizers."
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.side_effect = [first_error, mock_stream()]
|
||||
|
||||
events = [e async for e in provider.stream_response(req)]
|
||||
|
||||
assert mock_create.await_count == 2
|
||||
|
||||
first_extra = mock_create.call_args_list[0].kwargs["extra_body"]
|
||||
second_extra = mock_create.call_args_list[1].kwargs["extra_body"]
|
||||
|
||||
assert first_extra["chat_template"] == "custom_template"
|
||||
assert first_extra["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
"enable_thinking": True,
|
||||
"reasoning_budget": 100,
|
||||
}
|
||||
assert "reasoning_budget" not in first_extra
|
||||
|
||||
assert "chat_template" not in second_extra
|
||||
assert second_extra["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
"enable_thinking": True,
|
||||
"reasoning_budget": 100,
|
||||
}
|
||||
assert "reasoning_budget" not in second_extra
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "event: error" not in event_text
|
||||
assert "OK" in event_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_does_not_retry_unrelated_bad_request(provider_config):
|
||||
from config.nim import NimSettings
|
||||
|
||||
provider = NvidiaNimProvider(
|
||||
provider_config,
|
||||
nim_settings=NimSettings(chat_template="custom_template"),
|
||||
)
|
||||
req = MockRequest(model="mistralai/mixtral-8x7b-instruct-v0.1")
|
||||
|
||||
with patch.object(
|
||||
provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.side_effect = _make_bad_request_error("unrelated bad request")
|
||||
|
||||
events = [e async for e in provider.stream_response(req)]
|
||||
|
||||
assert mock_create.await_count == 1
|
||||
event_text = "".join(events)
|
||||
assert "unrelated bad request" in event_text
|
||||
assert "event: message_stop" in event_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,11 @@ import pytest
|
|||
|
||||
from config.nim import NimSettings
|
||||
from providers.common.utils import set_if_not_none
|
||||
from providers.nvidia_nim.request import _set_extra, build_request_body
|
||||
from providers.nvidia_nim.request import (
|
||||
_set_extra,
|
||||
build_request_body,
|
||||
clone_body_without_chat_template,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@@ -105,6 +109,32 @@ class TestBuildRequestBody:
|
|||
}
|
||||
assert "reasoning_budget" not in extra
|
||||
|
||||
def test_clone_body_without_chat_template(self):
|
||||
body = {
|
||||
"model": "test",
|
||||
"extra_body": {
|
||||
"chat_template": "custom_template",
|
||||
"chat_template_kwargs": {
|
||||
"thinking": True,
|
||||
"enable_thinking": True,
|
||||
"reasoning_budget": 100,
|
||||
},
|
||||
"ignore_eos": False,
|
||||
},
|
||||
}
|
||||
|
||||
cloned = clone_body_without_chat_template(body)
|
||||
|
||||
assert cloned is not None
|
||||
assert "chat_template" not in cloned["extra_body"]
|
||||
assert cloned["extra_body"]["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
"enable_thinking": True,
|
||||
"reasoning_budget": 100,
|
||||
}
|
||||
assert cloned["extra_body"]["ignore_eos"] is False
|
||||
assert body["extra_body"]["chat_template"] == "custom_template"
|
||||
|
||||
def test_no_chat_template_kwargs_when_thinking_disabled(self):
|
||||
req = MagicMock()
|
||||
req.model = "test"
|
||||
|
|
@@ -148,6 +178,30 @@ class TestBuildRequestBody:
|
|||
"reasoning_budget": body["max_tokens"],
|
||||
}
|
||||
|
||||
def test_chat_template_fields_present_for_mistral_model(self):
|
||||
req = MagicMock()
|
||||
req.model = "mistralai/mixtral-8x7b-instruct-v0.1"
|
||||
req.messages = [MagicMock(role="user", content="hi")]
|
||||
req.max_tokens = 100
|
||||
req.system = None
|
||||
req.temperature = None
|
||||
req.top_p = None
|
||||
req.stop_sequences = None
|
||||
req.tools = None
|
||||
req.tool_choice = None
|
||||
req.extra_body = None
|
||||
req.top_k = None
|
||||
|
||||
nim = NimSettings(chat_template="custom_template")
|
||||
body = build_request_body(req, nim, thinking_enabled=True)
|
||||
extra = body.get("extra_body", {})
|
||||
assert extra["chat_template_kwargs"] == {
|
||||
"thinking": True,
|
||||
"enable_thinking": True,
|
||||
"reasoning_budget": body["max_tokens"],
|
||||
}
|
||||
assert extra["chat_template"] == "custom_template"
|
||||
|
||||
def test_no_reasoning_params_in_extra_body(self):
|
||||
req = MagicMock()
|
||||
req.model = "test"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue