# free-claude-code/tests/providers/test_lmstudio.py
# 2026-04-24 00:26:15 -07:00
#
# 353 lines
# 10 KiB
# Python

"""Tests for LM Studio native Anthropic provider."""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from providers.base import ProviderConfig
from providers.lmstudio import LMStudioProvider
class MockMessage:
    """Minimal stand-in for a chat message: a role plus its content."""

    def __init__(self, role, content):
        # Both values are stored untouched as public attributes.
        self.role, self.content = role, content
class MockRequest:
    """Lightweight stand-in for the request object passed to the provider.

    ``__init__`` assigns a default to every attribute; any keyword argument
    then overrides (or adds) the attribute of the same name.
    """

    def __init__(self, **kwargs):
        self.model = "lmstudio-community/qwen2.5-7b-instruct"
        self.messages = [MockMessage("user", "Hello")]
        self.max_tokens = 100
        self.temperature = 0.5
        self.top_p = 0.9
        self.system = "System prompt"
        self.stop_sequences = None
        self.tools = []
        self.extra_body = {}
        # A MagicMock lets tests flip ``thinking.enabled`` freely.
        self.thinking = MagicMock()
        self.thinking.enabled = True
        # Applied last so callers can replace any default set above.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def model_dump(self, exclude_none=True):
        """Return a plain-dict snapshot of the request.

        Only ``model``, ``messages``, ``max_tokens``, ``temperature``,
        ``extra_body`` and ``thinking`` are included.

        Args:
            exclude_none: When true (the default), top-level keys whose value
                is ``None`` are left out of the result. When false, they are
                kept with a ``None`` value.

        Returns:
            A new dict; ``messages`` is a list of ``{"role", "content"}``
            dicts and ``thinking`` is ``{"enabled": ...}`` when
            ``self.thinking`` is truthy.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": m.role, "content": m.content} for m in self.messages
            ],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "extra_body": self.extra_body,
            "thinking": (
                {"enabled": self.thinking.enabled} if self.thinking else None
            ),
        }
        if exclude_none:
            # Honour the flag instead of silently ignoring it.
            payload = {k: v for k, v in payload.items() if v is not None}
        return payload
@pytest.fixture
def lmstudio_config():
    """Provide a ProviderConfig pointing at http://localhost:1234/v1."""
    settings = {
        "api_key": "lm-studio",
        "base_url": "http://localhost:1234/v1",
        "rate_limit": 10,
        "rate_window": 60,
    }
    return ProviderConfig(**settings)
@pytest.fixture(autouse=True)
def mock_rate_limiter():
    """Replace the global rate limiter with a non-blocking mock for each test.

    Yields the mocked limiter instance returned by ``get_instance``.
    """

    async def _passthrough(fn, *args, **kwargs):
        # Await the wrapped callable directly, with no retry handling.
        return await fn(*args, **kwargs)

    with patch("providers.lmstudio.client.GlobalRateLimiter") as limiter_cls:
        limiter = limiter_cls.get_instance.return_value
        limiter.wait_if_blocked = AsyncMock(return_value=False)
        limiter.execute_with_retry = AsyncMock(side_effect=_passthrough)
        # Yield inside the ``with`` so the patch stays active during the test.
        yield limiter
@pytest.fixture
def lmstudio_provider(lmstudio_config):
    """Build an LMStudioProvider from the ``lmstudio_config`` fixture."""
    provider = LMStudioProvider(lmstudio_config)
    return provider
def test_init(lmstudio_config):
    """The provider records the configured base URL and its provider name."""
    # httpx.AsyncClient is patched only while the provider is constructed.
    with patch("httpx.AsyncClient"):
        created = LMStudioProvider(lmstudio_config)

    assert created._base_url == "http://localhost:1234/v1"
    assert created._provider_name == "LMSTUDIO"
def test_init_uses_configurable_timeouts():
    """Configured read/write/connect timeouts reach the client constructor."""
    config = ProviderConfig(
        api_key="lm-studio",
        base_url="http://localhost:1234/v1",
        http_read_timeout=600.0,
        http_write_timeout=15.0,
        http_connect_timeout=5.0,
    )

    with patch("httpx.AsyncClient") as client_cls:
        LMStudioProvider(config)

    # call_args unpacks to (positional args, keyword args).
    _, ctor_kwargs = client_cls.call_args
    timeout = ctor_kwargs["timeout"]
    assert timeout.read == 600.0
    assert timeout.write == 15.0
    assert timeout.connect == 5.0
def test_init_base_url_strips_trailing_slash():
    """A trailing slash on ``base_url`` is absent from the stored URL."""
    settings = {
        "api_key": "lm-studio",
        "base_url": "http://localhost:1234/v1/",
        "rate_limit": 10,
        "rate_window": 60,
    }
    config = ProviderConfig(**settings)

    with patch("httpx.AsyncClient"):
        created = LMStudioProvider(config)

    assert created._base_url == "http://localhost:1234/v1"
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_globally_disabled(lmstudio_config):
    """With ``enable_thinking`` False in the config, no ``thinking`` key is sent."""
    config_without_thinking = lmstudio_config.model_copy(
        update={"enable_thinking": False}
    )
    provider = LMStudioProvider(config_without_thinking)
    req = MockRequest()

    async def empty_aiter():
        # The unreachable ``yield`` makes this an async generator with no items.
        return
        yield ""

    mock_response = MagicMock(status_code=200)
    mock_response.aiter_lines = empty_aiter

    client = provider._client
    with (
        patch.object(client, "build_request") as mock_build,
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
    ):
        # Drain the stream; only the request that was built matters here.
        async for _ in provider.stream_response(req):
            pass

    _, kwargs = mock_build.call_args
    assert "thinking" not in kwargs["json"]
@pytest.mark.asyncio
async def test_stream_response_omits_thinking_when_request_disables_it(
    lmstudio_provider,
):
    """With ``thinking.enabled`` False on the request, no ``thinking`` key is sent."""
    req = MockRequest()
    req.thinking.enabled = False

    async def empty_aiter():
        # The unreachable ``yield`` makes this an async generator with no items.
        return
        yield ""

    mock_response = MagicMock(status_code=200)
    mock_response.aiter_lines = empty_aiter

    client = lmstudio_provider._client
    with (
        patch.object(client, "build_request") as mock_build,
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
    ):
        # Drain the stream; only the request that was built matters here.
        async for _ in lmstudio_provider.stream_response(req):
            pass

    _, kwargs = mock_build.call_args
    assert "thinking" not in kwargs["json"]
@pytest.mark.asyncio
async def test_stream_response(lmstudio_provider):
    """Stream a native Anthropic response and check the request and events."""
    req = MockRequest()

    # Three SSE frames, each an event line, a data line and a blank separator.
    sse_lines = [
        "event: message_start",
        'data: {"type":"message_start","message":{}}',
        "",
        "event: content_block_delta",
        'data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello World"}}',
        "",
        "event: message_stop",
        'data: {"type":"message_stop"}',
        "",
    ]

    async def mock_aiter_lines():
        for line in sse_lines:
            yield line

    mock_response = MagicMock(status_code=200)
    mock_response.aiter_lines = mock_aiter_lines

    client = lmstudio_provider._client
    with (
        patch.object(client, "build_request", return_value=MagicMock()) as mock_build,
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
    ):
        events = [e async for e in lmstudio_provider.stream_response(req)]

    # Request construction: one POST to /messages carrying the request model.
    mock_build.assert_called_once()
    args, kwargs = mock_build.call_args
    payload = kwargs["json"]
    assert args[0] == "POST"
    assert args[1] == "/messages"
    assert payload["model"] == "lmstudio-community/qwen2.5-7b-instruct"
    # Internal-only fields must not be in the outgoing body.
    assert "extra_body" not in payload
    assert payload["max_tokens"] == 100
    # The internal thinking config is sent in Anthropic API format.
    assert payload["thinking"] == {"type": "enabled"}

    # Each input line comes back as one event with a trailing newline.
    assert len(events) == 9
    assert events[0] == "event: message_start\n"
    assert events[1] == 'data: {"type":"message_start","message":{}}\n'
@pytest.mark.asyncio
async def test_stream_response_adds_max_tokens_if_missing(lmstudio_provider):
    """``max_tokens`` falls back to 81920 when the dumped request lacks it."""
    req = MockRequest()

    async def empty_aiter():
        # The unreachable ``yield`` makes this an async generator with no items.
        return
        yield ""

    mock_response = MagicMock(status_code=200)
    mock_response.aiter_lines = empty_aiter

    client = lmstudio_provider._client
    with (
        # Force a dump that has no max_tokens entry.
        patch.object(req, "model_dump", return_value={"model": "test"}),
        patch.object(client, "build_request") as mock_build,
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
    ):
        # Run the generator to completion; the events are irrelevant here.
        async for _ in lmstudio_provider.stream_response(req):
            pass

    _, kwargs = mock_build.call_args
    assert kwargs["json"]["max_tokens"] == 81920
@pytest.mark.asyncio
async def test_stream_error_status_code(lmstudio_provider):
    """A 500 response whose ``raise_for_status`` raises yields one SSE error event."""
    req = MockRequest()

    mock_response = MagicMock(status_code=500)
    mock_response.aread = AsyncMock(return_value=b"Internal Server Error")
    status_error = httpx.HTTPStatusError(
        "Internal Server Error", request=MagicMock(), response=mock_response
    )
    mock_response.raise_for_status = MagicMock(side_effect=status_error)

    client = lmstudio_provider._client
    with (
        patch.object(client, "build_request", return_value=MagicMock()),
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
    ):
        stream = lmstudio_provider.stream_response(req, request_id="TEST_ID")
        events = [e async for e in stream]

    assert len(events) == 1
    error_event = events[0]
    assert error_event.startswith("event: error\ndata: {")
    assert "Internal Server Error" in error_event
    assert "TEST_ID" in error_event
@pytest.mark.asyncio
async def test_stream_network_error(lmstudio_provider):
    """A connection failure from ``send`` yields one SSE error event."""
    req = MockRequest()
    connect_failure = httpx.ConnectError("Connection refused")

    client = lmstudio_provider._client
    with (
        patch.object(client, "build_request", return_value=MagicMock()),
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            side_effect=connect_failure,
        ),
    ):
        stream = lmstudio_provider.stream_response(req, request_id="TEST_ID2")
        events = [e async for e in stream]

    assert len(events) == 1
    error_event = events[0]
    assert error_event.startswith("event: error\ndata: {")
    assert "Connection refused" in error_event
    assert "TEST_ID2" in error_event
@pytest.mark.asyncio
async def test_stream_error_405_mentions_upstream_provider(lmstudio_provider):
    """A 405 response yields an error naming LMSTUDIO and the request id."""
    req = MockRequest()

    mock_response = MagicMock(status_code=405)
    mock_response.aread = AsyncMock(return_value=b"Method Not Allowed")
    method_error = httpx.HTTPStatusError(
        "Method Not Allowed", request=MagicMock(), response=mock_response
    )
    mock_response.raise_for_status = MagicMock(side_effect=method_error)

    client = lmstudio_provider._client
    with (
        patch.object(client, "build_request", return_value=MagicMock()),
        patch.object(
            client,
            "send",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
    ):
        stream = lmstudio_provider.stream_response(req, request_id="REQ405")
        events = [e async for e in stream]

    first_event = events[0]
    expected_message = (
        "Upstream provider LMSTUDIO rejected the request method or endpoint (HTTP 405)."
    )
    assert expected_message in first_event
    assert "REQ405" in first_event