free-claude-code/tests/providers/test_lmstudio.py
Shantanu Suryawanshi 24a5e4d968 Fixing sse stream
2026-02-18 21:31:28 -05:00

650 lines
21 KiB
Python

"""Tests for LM Studio provider."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from providers.base import ProviderConfig
from providers.lmstudio import LMStudioProvider
from providers.lmstudio.request import LMSTUDIO_DEFAULT_MAX_TOKENS
class AsyncStreamMock:
"""Async iterable mock that yields chunks then optionally raises."""
def __init__(self, chunks, error=None):
self._chunks = chunks
self._error = error
def __aiter__(self):
return self._aiter()
async def _aiter(self):
for chunk in self._chunks:
yield chunk
if self._error:
raise self._error
def _make_chunk(
content=None, finish_reason=None, tool_calls=None, reasoning_content=None
):
"""Create a mock streaming chunk."""
delta = MagicMock()
delta.content = content
delta.tool_calls = tool_calls
delta.reasoning_content = reasoning_content if reasoning_content else None
choice = MagicMock()
choice.delta = delta
choice.finish_reason = finish_reason
chunk = MagicMock()
chunk.choices = [choice]
chunk.usage = None
return chunk
def _make_request(model="test-model", **kwargs):
"""Create a mock request with all fields build_request_body needs."""
req = MagicMock()
req.model = model
req.messages = [MagicMock(role="user", content="Hello")]
req.system = None
req.max_tokens = 100
req.temperature = None
req.top_p = None
req.stop_sequences = None
req.tools = None
req.tool_choice = None
req.thinking = MagicMock(enabled=True)
for k, v in kwargs.items():
setattr(req, k, v)
return req
async def _collect_stream(provider, request):
"""Collect all SSE events from a stream."""
return [e async for e in provider.stream_response(request)]
class MockMessage:
def __init__(self, role, content):
self.role = role
self.content = content
class MockRequest:
def __init__(self, **kwargs):
self.model = "lmstudio-community/qwen2.5-7b-instruct"
self.messages = [MockMessage("user", "Hello")]
self.max_tokens = 100
self.temperature = 0.5
self.top_p = 0.9
self.system = "System prompt"
self.stop_sequences = None
self.tools = []
self.extra_body = {}
self.thinking = MagicMock()
self.thinking.enabled = True
for k, v in kwargs.items():
setattr(self, k, v)
@pytest.fixture
def lmstudio_config():
return ProviderConfig(
api_key="lm-studio",
base_url="http://localhost:1234/v1",
rate_limit=10,
rate_window=60,
)
@pytest.fixture(autouse=True)
def mock_rate_limiter():
"""Mock the global rate limiter to prevent waiting."""
with patch("providers.openai_compat.GlobalRateLimiter") as mock:
instance = mock.get_instance.return_value
instance.wait_if_blocked = AsyncMock(return_value=False)
async def _passthrough(fn, *args, **kwargs):
return await fn(*args, **kwargs)
instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
yield instance
@pytest.fixture
def lmstudio_provider(lmstudio_config):
return LMStudioProvider(lmstudio_config)
def test_init(lmstudio_config):
"""Test provider initialization."""
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
provider = LMStudioProvider(lmstudio_config)
assert provider._api_key == "lm-studio"
assert provider._base_url == "http://localhost:1234/v1"
mock_openai.assert_called_once()
def test_init_with_empty_api_key():
"""Provider uses lm-studio placeholder when api_key is empty."""
config = ProviderConfig(
api_key="",
base_url="http://localhost:1234/v1",
rate_limit=10,
rate_window=60,
)
with patch("providers.openai_compat.AsyncOpenAI"):
provider = LMStudioProvider(config)
assert provider._api_key == "lm-studio"
def test_init_uses_configurable_timeouts():
"""Test that provider passes configurable read/write/connect timeouts to client."""
config = ProviderConfig(
api_key="lm-studio",
base_url="http://localhost:1234/v1",
http_read_timeout=600.0,
http_write_timeout=15.0,
http_connect_timeout=5.0,
)
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
LMStudioProvider(config)
call_kwargs = mock_openai.call_args[1]
timeout = call_kwargs["timeout"]
assert timeout.read == 600.0
assert timeout.write == 15.0
assert timeout.connect == 5.0
def test_build_request_body_no_extra_body(lmstudio_provider):
"""LM Studio request body does NOT include extra_body/reasoning."""
req = MockRequest()
body = lmstudio_provider._build_request_body(req)
assert body["model"] == "lmstudio-community/qwen2.5-7b-instruct"
assert body["temperature"] == 0.5
assert len(body["messages"]) == 2 # System + User
assert body["messages"][0]["role"] == "system"
assert body["messages"][0]["content"] == "System prompt"
assert "extra_body" not in body
def test_build_request_body_base_url_and_model(lmstudio_provider):
"""Base URL and model are correct in provider config."""
assert lmstudio_provider._base_url == "http://localhost:1234/v1"
req = MockRequest(model="lmstudio-community/qwen2.5-7b-instruct")
body = lmstudio_provider._build_request_body(req)
assert body["model"] == "lmstudio-community/qwen2.5-7b-instruct"
@pytest.mark.asyncio
async def test_stream_response_text(lmstudio_provider):
"""Test streaming text response."""
req = MockRequest()
mock_chunk1 = MagicMock()
mock_chunk1.choices = [
MagicMock(
delta=MagicMock(content="Hello", reasoning_content=None),
finish_reason=None,
)
]
mock_chunk1.usage = None
mock_chunk2 = MagicMock()
mock_chunk2.choices = [
MagicMock(
delta=MagicMock(content=" World", reasoning_content=None),
finish_reason="stop",
)
]
mock_chunk2.usage = MagicMock(completion_tokens=10)
async def mock_stream():
yield mock_chunk1
yield mock_chunk2
with patch.object(
lmstudio_provider._client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_stream()
events = [e async for e in lmstudio_provider.stream_response(req)]
assert len(events) > 0
assert "event: message_start" in events[0]
text_content = ""
for e in events:
if "event: content_block_delta" in e and '"text_delta"' in e:
for line in e.splitlines():
if line.startswith("data: "):
data = json.loads(line[6:])
if "delta" in data and "text" in data["delta"]:
text_content += data["delta"]["text"]
assert "Hello World" in text_content
@pytest.mark.asyncio
async def test_stream_response_reasoning_content(lmstudio_provider):
"""Test streaming with reasoning_content delta (if LM Studio adds support)."""
req = MockRequest()
mock_chunk = MagicMock()
mock_chunk.choices = [
MagicMock(
delta=MagicMock(content=None, reasoning_content="Thinking..."),
finish_reason=None,
)
]
mock_chunk.usage = None
async def mock_stream():
yield mock_chunk
with patch.object(
lmstudio_provider._client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_stream()
events = [e async for e in lmstudio_provider.stream_response(req)]
found_thinking = False
for e in events:
if (
"event: content_block_delta" in e
and '"thinking_delta"' in e
and "Thinking..." in e
):
found_thinking = True
assert found_thinking
# --- Stream Error Handling ---
class TestLMStudioStreamingExceptionHandling:
"""Tests for error paths during stream_response."""
@pytest.mark.asyncio
async def test_api_error_emits_sse_error_event(self, lmstudio_provider):
"""When API raises during streaming, SSE error event is emitted."""
request = _make_request()
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
side_effect=RuntimeError("API failed"),
):
events = await _collect_stream(lmstudio_provider, request)
event_text = "".join(events)
assert "message_start" in event_text
assert "API failed" in event_text
assert "message_stop" in event_text
@pytest.mark.asyncio
async def test_error_after_partial_content(self, lmstudio_provider):
"""Error after partial content: blocks closed, error emitted."""
request = _make_request()
chunk1 = _make_chunk(content="Hello ")
stream_mock = AsyncStreamMock(
[chunk1], error=ConnectionResetError("Connection lost")
)
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
return_value=stream_mock,
):
events = await _collect_stream(lmstudio_provider, request)
event_text = "".join(events)
assert "Hello" in event_text
assert "Connection lost" in event_text
assert "message_stop" in event_text
@pytest.mark.asyncio
async def test_empty_response_gets_space(self, lmstudio_provider):
"""Empty response with no text/tools gets a single space text block."""
request = _make_request()
empty_chunk = _make_chunk(finish_reason="stop")
stream_mock = AsyncStreamMock([empty_chunk])
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
return_value=stream_mock,
):
events = await _collect_stream(lmstudio_provider, request)
event_text = "".join(events)
assert '"text_delta"' in event_text
assert "message_stop" in event_text
# --- Stream Chunk Edge Cases ---
class TestLMStudioStreamChunkEdgeCases:
"""Tests for edge cases in stream chunk handling."""
@pytest.mark.asyncio
async def test_stream_chunk_with_empty_choices_skipped(self, lmstudio_provider):
"""Chunk with choices=[] is skipped without crashing."""
request = _make_request()
empty_choices_chunk = MagicMock()
empty_choices_chunk.choices = []
empty_choices_chunk.usage = None
finish_chunk = _make_chunk(finish_reason="stop")
stream_mock = AsyncStreamMock([empty_choices_chunk, finish_chunk])
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
return_value=stream_mock,
):
events = await _collect_stream(lmstudio_provider, request)
event_text = "".join(events)
assert "message_start" in event_text
assert "message_stop" in event_text
@pytest.mark.asyncio
async def test_stream_chunk_with_none_delta_handled(self, lmstudio_provider):
"""Chunk with choice.delta=None is handled defensively."""
request = _make_request()
none_delta_chunk = MagicMock()
none_delta_chunk.usage = None
choice = MagicMock()
choice.delta = None
choice.finish_reason = None
none_delta_chunk.choices = [choice]
finish_chunk = _make_chunk(finish_reason="stop")
stream_mock = AsyncStreamMock([none_delta_chunk, finish_chunk])
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
return_value=stream_mock,
):
events = await _collect_stream(lmstudio_provider, request)
event_text = "".join(events)
assert "message_start" in event_text
assert "message_stop" in event_text
# --- Native Tool Calls ---
@pytest.mark.asyncio
async def test_stream_response_tool_call(lmstudio_provider):
"""Test streaming tool calls."""
request = _make_request()
mock_tc = MagicMock()
mock_tc.index = 0
mock_tc.id = "call_1"
mock_tc.function.name = "search"
mock_tc.function.arguments = '{"q": "test"}'
mock_chunk = MagicMock()
mock_chunk.choices = [
MagicMock(
delta=MagicMock(content=None, reasoning_content=None, tool_calls=[mock_tc]),
finish_reason=None,
)
]
mock_chunk.usage = None
async def mock_stream():
yield mock_chunk
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
return_value=mock_stream(),
):
events = [e async for e in lmstudio_provider.stream_response(request)]
starts = [
e for e in events if "event: content_block_start" in e and '"tool_use"' in e
]
assert len(starts) == 1
assert "search" in starts[0]
# --- Think Tag Parsing ---
@pytest.mark.asyncio
async def test_stream_response_think_tag_parsing(lmstudio_provider):
"""Thinking content via think tags is emitted as thinking blocks."""
request = _make_request()
chunk1 = _make_chunk(content="<think>reasoning</think>answer")
chunk2 = _make_chunk(finish_reason="stop")
stream_mock = AsyncStreamMock([chunk1, chunk2])
with patch.object(
lmstudio_provider._client.chat.completions,
"create",
new_callable=AsyncMock,
return_value=stream_mock,
):
events = await _collect_stream(lmstudio_provider, request)
event_text = "".join(events)
assert "thinking" in event_text
assert "reasoning" in event_text
assert "answer" in event_text
# --- _process_tool_call and _flush_task_arg_buffers ---
class TestLMStudioProcessToolCall:
"""Tests for _process_tool_call method."""
def test_tool_call_with_id(self, lmstudio_provider):
"""Tool call with id starts a tool block."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
"index": 0,
"id": "call_123",
"function": {"name": "search", "arguments": '{"q": "test"}'},
}
events = list(lmstudio_provider._process_tool_call(tc, sse))
event_text = "".join(events)
assert "tool_use" in event_text
assert "search" in event_text
assert "call_123" in event_text
def test_tool_call_without_id_generates_uuid(self, lmstudio_provider):
"""Tool call without id generates a uuid-based id."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
"index": 0,
"id": None,
"function": {"name": "test", "arguments": "{}"},
}
events = list(lmstudio_provider._process_tool_call(tc, sse))
event_text = "".join(events)
assert "tool_" in event_text
def test_task_tool_forces_background_false(self, lmstudio_provider):
"""Task tool with run_in_background=true is forced to false."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
args = json.dumps({"run_in_background": True, "prompt": "test"})
tc = {
"index": 0,
"id": "call_task",
"function": {"name": "Task", "arguments": args},
}
events = list(lmstudio_provider._process_tool_call(tc, sse))
event_text = "".join(events)
assert "false" in event_text.lower()
def test_task_tool_chunked_args_forces_background_false(self, lmstudio_provider):
"""Chunked Task args are buffered until valid JSON, then forced to false."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc1 = {
"index": 0,
"id": "call_task_chunked",
"function": {"name": "Task", "arguments": '{"run_in_background": true,'},
}
tc2 = {
"index": 0,
"id": "call_task_chunked",
"function": {"name": None, "arguments": ' "prompt": "test"}'},
}
events1 = list(lmstudio_provider._process_tool_call(tc1, sse))
assert len(events1) > 0
assert "false" not in "".join(events1).lower()
events2 = list(lmstudio_provider._process_tool_call(tc2, sse))
event_text = "".join(events1 + events2)
assert "false" in event_text.lower()
def test_task_tool_invalid_json_logs_warning_on_flush(
self, lmstudio_provider, caplog
):
"""Invalid JSON args for Task tool emits {} on flush and logs a warning."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
"index": 0,
"id": "call_task2",
"function": {"name": "Task", "arguments": "not json"},
}
events = list(lmstudio_provider._process_tool_call(tc, sse))
assert len(events) > 0
with caplog.at_level("WARNING"):
flushed = list(lmstudio_provider._flush_task_arg_buffers(sse))
assert len(flushed) > 0
assert "{}" in "".join(flushed)
assert any("Task args invalid JSON" in r.message for r in caplog.records)
def test_negative_tool_index_fallback(self, lmstudio_provider):
"""tc_index < 0 uses len(tool_indices) as fallback."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
"index": -1,
"id": "call_neg",
"function": {"name": "test", "arguments": "{}"},
}
events = list(lmstudio_provider._process_tool_call(tc, sse))
assert len(events) > 0
def test_tool_args_emitted_as_delta(self, lmstudio_provider):
"""Arguments are emitted as input_json_delta events."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc = {
"index": 0,
"id": "call_args",
"function": {"name": "grep", "arguments": '{"pattern": "test"}'},
}
events = list(lmstudio_provider._process_tool_call(tc, sse))
event_text = "".join(events)
assert "input_json_delta" in event_text
def test_stream_malformed_tool_args_chunked(self, lmstudio_provider):
"""Chunked tool args that never form valid JSON are flushed with {}."""
from providers.common import SSEBuilder
sse = SSEBuilder("msg_test", "test-model")
tc1 = {
"index": 0,
"id": "call_malformed",
"function": {"name": "Task", "arguments": '{"broken":'},
}
tc2 = {
"index": 0,
"id": "call_malformed",
"function": {"name": None, "arguments": " never valid }"},
}
events1 = list(lmstudio_provider._process_tool_call(tc1, sse))
events2 = list(lmstudio_provider._process_tool_call(tc2, sse))
flushed = list(lmstudio_provider._flush_task_arg_buffers(sse))
event_text = "".join(events1 + events2 + flushed)
assert "tool_use" in event_text
assert "{}" in event_text
# --- Request Body Edge Cases ---
def test_build_request_body_max_tokens_default(lmstudio_provider):
"""max_tokens=None or 0 uses LMSTUDIO_DEFAULT_MAX_TOKENS."""
req = MockRequest(max_tokens=None)
body = lmstudio_provider._build_request_body(req)
assert body["max_tokens"] == LMSTUDIO_DEFAULT_MAX_TOKENS
assert body["max_tokens"] == 81920
req2 = MockRequest(max_tokens=0)
body2 = lmstudio_provider._build_request_body(req2)
assert body2["max_tokens"] == LMSTUDIO_DEFAULT_MAX_TOKENS
def test_build_request_body_stop_sequences(lmstudio_provider):
"""stop_sequences non-empty adds stop key to body."""
req = MockRequest(stop_sequences=["STOP", "END"])
body = lmstudio_provider._build_request_body(req)
assert body["stop"] == ["STOP", "END"]
def test_build_request_body_tools_and_tool_choice(lmstudio_provider):
"""tools and tool_choice non-empty add to body."""
tool = MagicMock()
tool.name = "test_tool"
tool.description = "A test"
tool.input_schema = {"type": "object"}
req = MockRequest(tools=[tool], tool_choice="auto")
body = lmstudio_provider._build_request_body(req)
assert "tools" in body
assert body["tool_choice"] == "auto"
# --- Base URL Trailing Slash ---
def test_init_base_url_strips_trailing_slash():
"""Config with base_url trailing slash is stored without it."""
config = ProviderConfig(
api_key="lm-studio",
base_url="http://localhost:1234/v1/",
rate_limit=10,
rate_window=60,
)
with patch("providers.openai_compat.AsyncOpenAI"):
provider = LMStudioProvider(config)
assert provider._base_url == "http://localhost:1234/v1"