mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
added more tests
This commit is contained in:
parent
c97b2aaa55
commit
4fb585ed5b
5 changed files with 504 additions and 1 deletions
|
|
@ -57,6 +57,30 @@ class TestSettings:
|
|||
|
||||
assert NVIDIA_NIM_BASE_URL == "https://integrate.api.nvidia.com/v1"
|
||||
|
||||
def test_lm_studio_base_url_from_env(self, monkeypatch):
|
||||
"""LM_STUDIO_BASE_URL env var is loaded into settings."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setenv("LM_STUDIO_BASE_URL", "http://custom:5678/v1")
|
||||
settings = Settings()
|
||||
assert settings.lm_studio_base_url == "http://custom:5678/v1"
|
||||
|
||||
def test_provider_rate_limit_from_env(self, monkeypatch):
|
||||
"""PROVIDER_RATE_LIMIT env var is loaded into settings."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setenv("PROVIDER_RATE_LIMIT", "20")
|
||||
settings = Settings()
|
||||
assert settings.provider_rate_limit == 20
|
||||
|
||||
def test_provider_rate_window_from_env(self, monkeypatch):
|
||||
"""PROVIDER_RATE_WINDOW env var is loaded into settings."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setenv("PROVIDER_RATE_WINDOW", "30")
|
||||
settings = Settings()
|
||||
assert settings.provider_rate_window == 30
|
||||
|
||||
|
||||
# --- NimSettings Validation Tests ---
|
||||
|
||||
|
|
|
|||
|
|
@ -104,6 +104,21 @@ async def test_get_provider_lmstudio():
|
|||
assert provider._api_key == "lm-studio"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_lmstudio_uses_lm_studio_base_url():
|
||||
"""LM Studio provider uses lm_studio_base_url from settings."""
|
||||
with patch("api.dependencies.get_settings") as mock_settings:
|
||||
mock_settings.return_value = _make_mock_settings(
|
||||
provider_type="lmstudio",
|
||||
lm_studio_base_url="http://custom:9999/v1",
|
||||
)
|
||||
|
||||
provider = get_provider()
|
||||
|
||||
assert isinstance(provider, LMStudioProvider)
|
||||
assert provider._base_url == "http://custom:9999/v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_unknown_type():
|
||||
"""Test that unknown provider_type raises ValueError."""
|
||||
|
|
|
|||
|
|
@ -6,9 +6,72 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
from providers.base import ProviderConfig
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.lmstudio.request import LMSTUDIO_DEFAULT_MAX_TOKENS
|
||||
from config.nim import NimSettings
|
||||
|
||||
|
||||
class AsyncStreamMock:
|
||||
"""Async iterable mock that yields chunks then optionally raises."""
|
||||
|
||||
def __init__(self, chunks, error=None):
|
||||
self._chunks = chunks
|
||||
self._error = error
|
||||
|
||||
def __aiter__(self):
|
||||
return self._aiter()
|
||||
|
||||
async def _aiter(self):
|
||||
for chunk in self._chunks:
|
||||
yield chunk
|
||||
if self._error:
|
||||
raise self._error
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
content=None, finish_reason=None, tool_calls=None, reasoning_content=None
|
||||
):
|
||||
"""Create a mock streaming chunk."""
|
||||
delta = MagicMock()
|
||||
delta.content = content
|
||||
delta.tool_calls = tool_calls
|
||||
delta.reasoning_content = reasoning_content if reasoning_content else None
|
||||
|
||||
choice = MagicMock()
|
||||
choice.delta = delta
|
||||
choice.finish_reason = finish_reason
|
||||
|
||||
chunk = MagicMock()
|
||||
chunk.choices = [choice]
|
||||
chunk.usage = None
|
||||
return chunk
|
||||
|
||||
|
||||
def _make_request(model="test-model", **kwargs):
|
||||
"""Create a mock request with all fields build_request_body needs."""
|
||||
req = MagicMock()
|
||||
req.model = model
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.system = None
|
||||
req.max_tokens = 100
|
||||
req.temperature = None
|
||||
req.top_p = None
|
||||
req.stop_sequences = None
|
||||
req.tools = None
|
||||
req.tool_choice = None
|
||||
req.thinking = MagicMock(enabled=True)
|
||||
for k, v in kwargs.items():
|
||||
setattr(req, k, v)
|
||||
return req
|
||||
|
||||
|
||||
async def _collect_stream(provider, request):
|
||||
"""Collect all SSE events from a stream."""
|
||||
events = []
|
||||
async for event in provider.stream_response(request):
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
class MockMessage:
|
||||
def __init__(self, role, content):
|
||||
self.role = role
|
||||
|
|
@ -190,3 +253,395 @@ async def test_stream_response_reasoning_content(lmstudio_provider):
|
|||
if "Thinking..." in e:
|
||||
found_thinking = True
|
||||
assert found_thinking
|
||||
|
||||
|
||||
# --- Stream Error Handling ---
|
||||
|
||||
|
||||
class TestLMStudioStreamingExceptionHandling:
|
||||
"""Tests for error paths during stream_response."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_emits_sse_error_event(self, lmstudio_provider):
|
||||
"""When API raises during streaming, SSE error event is emitted."""
|
||||
request = _make_request()
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("API failed"),
|
||||
):
|
||||
events = await _collect_stream(lmstudio_provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "message_start" in event_text
|
||||
assert "API failed" in event_text
|
||||
assert "message_stop" in event_text
|
||||
assert "[DONE]" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_after_partial_content(self, lmstudio_provider):
|
||||
"""Error after partial content: blocks closed, error emitted."""
|
||||
request = _make_request()
|
||||
chunk1 = _make_chunk(content="Hello ")
|
||||
stream_mock = AsyncStreamMock(
|
||||
[chunk1], error=ConnectionResetError("Connection lost")
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_mock,
|
||||
):
|
||||
events = await _collect_stream(lmstudio_provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "Hello" in event_text
|
||||
assert "Connection lost" in event_text
|
||||
assert "message_stop" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_gets_space(self, lmstudio_provider):
|
||||
"""Empty response with no text/tools gets a single space text block."""
|
||||
request = _make_request()
|
||||
empty_chunk = _make_chunk(finish_reason="stop")
|
||||
stream_mock = AsyncStreamMock([empty_chunk])
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_mock,
|
||||
):
|
||||
events = await _collect_stream(lmstudio_provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert '"text_delta"' in event_text
|
||||
assert "message_stop" in event_text
|
||||
|
||||
|
||||
# --- Stream Chunk Edge Cases ---
|
||||
|
||||
|
||||
class TestLMStudioStreamChunkEdgeCases:
|
||||
"""Tests for edge cases in stream chunk handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_chunk_with_empty_choices_skipped(self, lmstudio_provider):
|
||||
"""Chunk with choices=[] is skipped without crashing."""
|
||||
request = _make_request()
|
||||
empty_choices_chunk = MagicMock()
|
||||
empty_choices_chunk.choices = []
|
||||
empty_choices_chunk.usage = None
|
||||
finish_chunk = _make_chunk(finish_reason="stop")
|
||||
stream_mock = AsyncStreamMock([empty_choices_chunk, finish_chunk])
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_mock,
|
||||
):
|
||||
events = await _collect_stream(lmstudio_provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "message_start" in event_text
|
||||
assert "message_stop" in event_text
|
||||
assert "[DONE]" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_chunk_with_none_delta_handled(self, lmstudio_provider):
|
||||
"""Chunk with choice.delta=None is handled defensively."""
|
||||
request = _make_request()
|
||||
none_delta_chunk = MagicMock()
|
||||
none_delta_chunk.usage = None
|
||||
choice = MagicMock()
|
||||
choice.delta = None
|
||||
choice.finish_reason = None
|
||||
none_delta_chunk.choices = [choice]
|
||||
finish_chunk = _make_chunk(finish_reason="stop")
|
||||
stream_mock = AsyncStreamMock([none_delta_chunk, finish_chunk])
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_mock,
|
||||
):
|
||||
events = await _collect_stream(lmstudio_provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "message_start" in event_text
|
||||
assert "message_stop" in event_text
|
||||
assert "[DONE]" in event_text
|
||||
|
||||
|
||||
# --- Native Tool Calls ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_tool_call(lmstudio_provider):
|
||||
"""Test streaming tool calls."""
|
||||
request = _make_request()
|
||||
mock_tc = MagicMock()
|
||||
mock_tc.index = 0
|
||||
mock_tc.id = "call_1"
|
||||
mock_tc.function.name = "search"
|
||||
mock_tc.function.arguments = '{"q": "test"}'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(content=None, reasoning_content=None, tool_calls=[mock_tc]),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = None
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_stream(),
|
||||
):
|
||||
events = []
|
||||
async for event in lmstudio_provider.stream_response(request):
|
||||
events.append(event)
|
||||
|
||||
starts = [
|
||||
e for e in events if "event: content_block_start" in e and '"tool_use"' in e
|
||||
]
|
||||
assert len(starts) == 1
|
||||
assert "search" in starts[0]
|
||||
|
||||
|
||||
# --- Think Tag Parsing ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_think_tag_parsing(lmstudio_provider):
|
||||
"""Thinking content via think tags is emitted as thinking blocks."""
|
||||
request = _make_request()
|
||||
chunk1 = _make_chunk(content="<think>reasoning</think>answer")
|
||||
chunk2 = _make_chunk(finish_reason="stop")
|
||||
stream_mock = AsyncStreamMock([chunk1, chunk2])
|
||||
|
||||
with patch.object(
|
||||
lmstudio_provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=stream_mock,
|
||||
):
|
||||
events = await _collect_stream(lmstudio_provider, request)
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "thinking" in event_text
|
||||
assert "reasoning" in event_text
|
||||
assert "answer" in event_text
|
||||
|
||||
|
||||
# --- _process_tool_call and _flush_task_arg_buffers ---
|
||||
|
||||
|
||||
class TestLMStudioProcessToolCall:
|
||||
"""Tests for _process_tool_call method."""
|
||||
|
||||
def test_tool_call_with_id(self, lmstudio_provider):
|
||||
"""Tool call with id starts a tool block."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc = {
|
||||
"index": 0,
|
||||
"id": "call_123",
|
||||
"function": {"name": "search", "arguments": '{"q": "test"}'},
|
||||
}
|
||||
events = list(lmstudio_provider._process_tool_call(tc, sse))
|
||||
event_text = "".join(events)
|
||||
assert "tool_use" in event_text
|
||||
assert "search" in event_text
|
||||
assert "call_123" in event_text
|
||||
|
||||
def test_tool_call_without_id_generates_uuid(self, lmstudio_provider):
|
||||
"""Tool call without id generates a uuid-based id."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc = {
|
||||
"index": 0,
|
||||
"id": None,
|
||||
"function": {"name": "test", "arguments": "{}"},
|
||||
}
|
||||
events = list(lmstudio_provider._process_tool_call(tc, sse))
|
||||
event_text = "".join(events)
|
||||
assert "tool_" in event_text
|
||||
|
||||
def test_task_tool_forces_background_false(self, lmstudio_provider):
|
||||
"""Task tool with run_in_background=true is forced to false."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
args = json.dumps({"run_in_background": True, "prompt": "test"})
|
||||
tc = {
|
||||
"index": 0,
|
||||
"id": "call_task",
|
||||
"function": {"name": "Task", "arguments": args},
|
||||
}
|
||||
events = list(lmstudio_provider._process_tool_call(tc, sse))
|
||||
event_text = "".join(events)
|
||||
assert "false" in event_text.lower()
|
||||
|
||||
def test_task_tool_chunked_args_forces_background_false(self, lmstudio_provider):
|
||||
"""Chunked Task args are buffered until valid JSON, then forced to false."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc1 = {
|
||||
"index": 0,
|
||||
"id": "call_task_chunked",
|
||||
"function": {"name": "Task", "arguments": '{"run_in_background": true,'},
|
||||
}
|
||||
tc2 = {
|
||||
"index": 0,
|
||||
"id": "call_task_chunked",
|
||||
"function": {"name": None, "arguments": ' "prompt": "test"}'},
|
||||
}
|
||||
|
||||
events1 = list(lmstudio_provider._process_tool_call(tc1, sse))
|
||||
assert len(events1) > 0
|
||||
assert "false" not in "".join(events1).lower()
|
||||
|
||||
events2 = list(lmstudio_provider._process_tool_call(tc2, sse))
|
||||
event_text = "".join(events1 + events2)
|
||||
assert "false" in event_text.lower()
|
||||
|
||||
def test_task_tool_invalid_json_logs_warning_on_flush(
|
||||
self, lmstudio_provider, caplog
|
||||
):
|
||||
"""Invalid JSON args for Task tool emits {} on flush and logs a warning."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc = {
|
||||
"index": 0,
|
||||
"id": "call_task2",
|
||||
"function": {"name": "Task", "arguments": "not json"},
|
||||
}
|
||||
events = list(lmstudio_provider._process_tool_call(tc, sse))
|
||||
assert len(events) > 0
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
flushed = list(lmstudio_provider._flush_task_arg_buffers(sse))
|
||||
assert len(flushed) > 0
|
||||
assert "{}" in "".join(flushed)
|
||||
assert any(
|
||||
"LMSTUDIO_INTERCEPT: Task args invalid JSON" in r.message
|
||||
for r in caplog.records
|
||||
)
|
||||
|
||||
def test_negative_tool_index_fallback(self, lmstudio_provider):
|
||||
"""tc_index < 0 uses len(tool_indices) as fallback."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc = {
|
||||
"index": -1,
|
||||
"id": "call_neg",
|
||||
"function": {"name": "test", "arguments": "{}"},
|
||||
}
|
||||
events = list(lmstudio_provider._process_tool_call(tc, sse))
|
||||
assert len(events) > 0
|
||||
|
||||
def test_tool_args_emitted_as_delta(self, lmstudio_provider):
|
||||
"""Arguments are emitted as input_json_delta events."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc = {
|
||||
"index": 0,
|
||||
"id": "call_args",
|
||||
"function": {"name": "grep", "arguments": '{"pattern": "test"}'},
|
||||
}
|
||||
events = list(lmstudio_provider._process_tool_call(tc, sse))
|
||||
event_text = "".join(events)
|
||||
assert "input_json_delta" in event_text
|
||||
|
||||
def test_stream_malformed_tool_args_chunked(self, lmstudio_provider):
|
||||
"""Chunked tool args that never form valid JSON are flushed with {}."""
|
||||
from providers.nvidia_nim.utils import SSEBuilder
|
||||
|
||||
sse = SSEBuilder("msg_test", "test-model")
|
||||
tc1 = {
|
||||
"index": 0,
|
||||
"id": "call_malformed",
|
||||
"function": {"name": "Task", "arguments": '{"broken":'},
|
||||
}
|
||||
tc2 = {
|
||||
"index": 0,
|
||||
"id": "call_malformed",
|
||||
"function": {"name": None, "arguments": " never valid }"},
|
||||
}
|
||||
|
||||
events1 = list(lmstudio_provider._process_tool_call(tc1, sse))
|
||||
events2 = list(lmstudio_provider._process_tool_call(tc2, sse))
|
||||
flushed = list(lmstudio_provider._flush_task_arg_buffers(sse))
|
||||
|
||||
event_text = "".join(events1 + events2 + flushed)
|
||||
assert "tool_use" in event_text
|
||||
assert "{}" in event_text
|
||||
|
||||
|
||||
# --- Request Body Edge Cases ---
|
||||
|
||||
|
||||
def test_build_request_body_max_tokens_default(lmstudio_provider):
|
||||
"""max_tokens=None or 0 uses LMSTUDIO_DEFAULT_MAX_TOKENS."""
|
||||
req = MockRequest(max_tokens=None)
|
||||
body = lmstudio_provider._build_request_body(req)
|
||||
assert body["max_tokens"] == LMSTUDIO_DEFAULT_MAX_TOKENS
|
||||
assert body["max_tokens"] == 81920
|
||||
|
||||
req2 = MockRequest(max_tokens=0)
|
||||
body2 = lmstudio_provider._build_request_body(req2)
|
||||
assert body2["max_tokens"] == LMSTUDIO_DEFAULT_MAX_TOKENS
|
||||
|
||||
|
||||
def test_build_request_body_stop_sequences(lmstudio_provider):
|
||||
"""stop_sequences non-empty adds stop key to body."""
|
||||
req = MockRequest(stop_sequences=["STOP", "END"])
|
||||
body = lmstudio_provider._build_request_body(req)
|
||||
assert body["stop"] == ["STOP", "END"]
|
||||
|
||||
|
||||
def test_build_request_body_tools_and_tool_choice(lmstudio_provider):
|
||||
"""tools and tool_choice non-empty add to body."""
|
||||
tool = MagicMock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "A test"
|
||||
tool.input_schema = {"type": "object"}
|
||||
req = MockRequest(tools=[tool], tool_choice="auto")
|
||||
body = lmstudio_provider._build_request_body(req)
|
||||
assert "tools" in body
|
||||
assert body["tool_choice"] == "auto"
|
||||
|
||||
|
||||
# --- Base URL Trailing Slash ---
|
||||
|
||||
|
||||
def test_init_base_url_strips_trailing_slash():
|
||||
"""Config with base_url trailing slash is stored without it."""
|
||||
config = ProviderConfig(
|
||||
api_key="lm-studio",
|
||||
base_url="http://localhost:1234/v1/",
|
||||
rate_limit=10,
|
||||
rate_window=60,
|
||||
nim_settings=NimSettings(),
|
||||
)
|
||||
with patch("providers.lmstudio.client.AsyncOpenAI") as mock_openai:
|
||||
provider = LMStudioProvider(config)
|
||||
assert provider._base_url == "http://localhost:1234/v1"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
import json
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from providers.open_router import OpenRouterProvider
|
||||
from providers.open_router.request import OPENROUTER_DEFAULT_MAX_TOKENS
|
||||
from providers.base import ProviderConfig
|
||||
from config.nim import NimSettings
|
||||
|
||||
|
|
@ -94,6 +95,14 @@ def test_build_request_body_base_url_and_model(open_router_provider):
|
|||
assert body["model"] == "stepfun/step-3.5-flash:free"
|
||||
|
||||
|
||||
def test_build_request_body_default_max_tokens(open_router_provider):
|
||||
"""max_tokens=None uses OPENROUTER_DEFAULT_MAX_TOKENS (81920)."""
|
||||
req = MockRequest(max_tokens=None)
|
||||
body = open_router_provider._build_request_body(req)
|
||||
assert body["max_tokens"] == OPENROUTER_DEFAULT_MAX_TOKENS
|
||||
assert body["max_tokens"] == 81920
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_text(open_router_provider):
|
||||
"""Test streaming text response."""
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -1,6 +1,6 @@
|
|||
version = 1
|
||||
revision = 3
|
||||
requires-python = "==3.14.2"
|
||||
requires-python = ">=3.14.2"
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue