mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 19:40:54 +00:00
158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
from fastapi.testclient import TestClient
|
|
from api.app import app
|
|
from api.dependencies import get_provider
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from providers.nvidia_nim import NvidiaNimProvider
|
|
|
|
# Mock provider
|
|
mock_provider = MagicMock(spec=NvidiaNimProvider)
|
|
mock_provider.complete = AsyncMock()
|
|
mock_provider.stream_response = AsyncMock()
|
|
mock_provider.convert_response = MagicMock()
|
|
|
|
|
|
def override_get_provider():
|
|
return mock_provider
|
|
|
|
|
|
app.dependency_overrides[get_provider] = override_get_provider
|
|
client = TestClient(app)
|
|
|
|
|
|
def test_root():
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.json()["status"] == "ok"
|
|
|
|
|
|
def test_health():
|
|
response = client.get("/health")
|
|
assert response.status_code == 200
|
|
assert response.json()["status"] == "healthy"
|
|
|
|
|
|
def test_create_message_non_stream():
|
|
mock_provider.complete.return_value = {"id": "123", "choices": []}
|
|
mock_provider.convert_response.return_value = {
|
|
"id": "msg_123",
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"model": "test-model",
|
|
"content": [{"type": "text", "text": "Hello"}],
|
|
"usage": {"input_tokens": 10, "output_tokens": 5},
|
|
}
|
|
|
|
payload = {
|
|
"model": "claude-3-sonnet",
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"max_tokens": 100,
|
|
"stream": False,
|
|
}
|
|
|
|
response = client.post("/v1/messages", json=payload)
|
|
assert response.status_code == 200
|
|
assert response.json()["content"][0]["text"] == "Hello"
|
|
mock_provider.complete.assert_called_once()
|
|
|
|
|
|
def test_model_mapping():
|
|
# Test Haiku mapping
|
|
payload_haiku = {
|
|
"model": "claude-3-haiku-20240307",
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"max_tokens": 100,
|
|
}
|
|
client.post("/v1/messages", json=payload_haiku)
|
|
args, _ = mock_provider.complete.call_args
|
|
assert args[0].model != "claude-3-haiku-20240307"
|
|
assert args[0].original_model == "claude-3-haiku-20240307"
|
|
|
|
|
|
def test_error_fallbacks():
|
|
from providers.exceptions import (
|
|
AuthenticationError,
|
|
RateLimitError,
|
|
OverloadedError,
|
|
)
|
|
|
|
# 1. Authentication Error (401)
|
|
mock_provider.complete.side_effect = AuthenticationError("Invalid Key")
|
|
response = client.post(
|
|
"/v1/messages", json={"model": "test", "messages": [], "max_tokens": 10}
|
|
)
|
|
assert response.status_code == 401
|
|
assert response.json()["error"]["type"] == "authentication_error"
|
|
|
|
# 2. Rate Limit (429)
|
|
mock_provider.complete.side_effect = RateLimitError("Too Many Requests")
|
|
response = client.post(
|
|
"/v1/messages", json={"model": "test", "messages": [], "max_tokens": 10}
|
|
)
|
|
assert response.status_code == 429
|
|
assert response.json()["error"]["type"] == "rate_limit_error"
|
|
|
|
# 3. Overloaded (529)
|
|
mock_provider.complete.side_effect = OverloadedError("Server Overloaded")
|
|
response = client.post(
|
|
"/v1/messages", json={"model": "test", "messages": [], "max_tokens": 10}
|
|
)
|
|
assert response.status_code == 529
|
|
assert response.json()["error"]["type"] == "overloaded_error"
|
|
|
|
# Reset side_effect for subsequent tests
|
|
mock_provider.complete.side_effect = None
|
|
|
|
|
|
def test_generic_exception_returns_500():
|
|
"""Non-ProviderError exceptions are caught and returned as HTTPException(500)."""
|
|
mock_provider.complete.side_effect = RuntimeError("unexpected crash")
|
|
response = client.post(
|
|
"/v1/messages",
|
|
json={
|
|
"model": "test",
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"max_tokens": 10,
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert response.status_code == 500
|
|
mock_provider.complete.side_effect = None
|
|
|
|
|
|
def test_generic_exception_with_status_code():
|
|
"""Exception with status_code attribute uses that status."""
|
|
exc = RuntimeError("bad gateway")
|
|
exc.status_code = 502
|
|
mock_provider.complete.side_effect = exc
|
|
response = client.post(
|
|
"/v1/messages",
|
|
json={
|
|
"model": "test",
|
|
"messages": [{"role": "user", "content": "Hi"}],
|
|
"max_tokens": 10,
|
|
"stream": False,
|
|
},
|
|
)
|
|
assert response.status_code == 502
|
|
mock_provider.complete.side_effect = None
|
|
|
|
|
|
def test_count_tokens_endpoint():
|
|
"""count_tokens endpoint returns token count."""
|
|
response = client.post(
|
|
"/v1/messages/count_tokens",
|
|
json={"model": "test", "messages": [{"role": "user", "content": "Hello"}]},
|
|
)
|
|
assert response.status_code == 200
|
|
assert "input_tokens" in response.json()
|
|
|
|
|
|
def test_stop_endpoint_no_handler_no_cli_503():
|
|
"""POST /stop without handler or cli_manager returns 503."""
|
|
# Ensure no handler or cli_manager on app state
|
|
if hasattr(app.state, "message_handler"):
|
|
delattr(app.state, "message_handler")
|
|
if hasattr(app.state, "cli_manager"):
|
|
delattr(app.state, "cli_manager")
|
|
response = client.post("/stop")
|
|
assert response.status_code == 503
|