mirror of
https://github.com/supermemoryai/supermemory.git
synced 2026-04-30 20:49:56 +00:00
387 lines
12 KiB
Python
387 lines
12 KiB
Python
"""Tests for infinite_chat module."""
|
|
|
|
import os
from typing import List, cast

import pytest
from openai.types.chat import ChatCompletionMessageParam

from ..src import (
    SupermemoryOpenAI,
    SupermemoryInfiniteChatConfigWithProviderName,
    SupermemoryInfiniteChatConfigWithProviderUrl,
    ProviderName,
)
|
|
|
|
|
|
# Provider names accepted by the `test_provider_name` fixture (via PROVIDER_NAME).
PROVIDERS: List[ProviderName] = [
    "openai", "anthropic", "openrouter", "deepinfra",
    "groq", "google", "cloudflare",
]
|
|
|
|
|
|
@pytest.fixture
def test_api_key() -> str:
    """Return the Supermemory key from ``SUPERMEMORY_API_KEY``.

    The requesting test is skipped when the variable is unset or empty.
    """
    key = os.environ.get("SUPERMEMORY_API_KEY")
    if key:
        return key
    pytest.skip("SUPERMEMORY_API_KEY environment variable is required for tests")
|
|
|
|
|
|
@pytest.fixture
def test_provider_api_key() -> str:
    """Return the upstream provider key from ``PROVIDER_API_KEY``.

    The requesting test is skipped when the variable is unset or empty.
    """
    key = os.environ.get("PROVIDER_API_KEY")
    if key:
        return key
    pytest.skip("PROVIDER_API_KEY environment variable is required for tests")
|
|
|
|
|
|
@pytest.fixture
def test_provider_name() -> ProviderName:
    """Return the provider name from ``PROVIDER_NAME`` (default ``"openai"``).

    An unset *or empty* variable falls back to the default, matching how the
    API-key fixtures treat empty values as unset. Any other value that is not
    listed in ``PROVIDERS`` fails the test (instead of skipping), because it
    indicates a misconfigured environment rather than missing credentials.
    """
    provider_name = os.getenv("PROVIDER_NAME") or "openai"
    if provider_name not in PROVIDERS:
        pytest.fail(
            f"Invalid provider name: {provider_name} "
            f"(expected one of: {', '.join(PROVIDERS)})"
        )
    # Membership in PROVIDERS was verified above, so this narrowing is safe.
    return cast(ProviderName, provider_name)
|
|
|
|
|
|
@pytest.fixture
def test_provider_url() -> str:
    """Return ``PROVIDER_URL`` from the environment, or ``""`` when unset."""
    fallback = ""
    return os.environ.get("PROVIDER_URL", fallback)
|
|
|
|
|
|
@pytest.fixture
def test_model_name() -> str:
    """Return ``MODEL_NAME`` from the environment, defaulting to ``"gpt-4o-mini"``."""
    default_model = "gpt-4o-mini"
    return os.environ.get("MODEL_NAME", default_model)
|
|
|
|
|
|
@pytest.fixture
def test_headers() -> dict:
    """Provide a one-entry custom header mapping for client configuration tests."""
    header_name, header_value = "custom-header", "test-value"
    return {header_name: header_value}
|
|
|
|
|
|
@pytest.fixture
def test_messages() -> List[List[ChatCompletionMessageParam]]:
    """Provide three conversations of increasing complexity.

    Index 0: a single user greeting.
    Index 1: a system prompt followed by a user question.
    Index 2: a user / assistant / user multi-turn exchange.
    """
    single_turn = [{"role": "user", "content": "Hello!"}]

    with_system_prompt = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is AI?"},
    ]

    multi_turn = [
        {"role": "user", "content": "Tell me a joke"},
        {
            "role": "assistant",
            "content": "Why don't scientists trust atoms? Because they make up everything!",
        },
        {"role": "user", "content": "Tell me another one"},
    ]

    return [single_turn, with_system_prompt, multi_turn]
|
|
|
|
|
|
class TestClientCreation:
    """Client construction from each supported configuration variant."""

    def test_create_client_with_provider_name(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_headers: dict,
    ):
        """A provider-name config (provider taken from the environment) builds a client."""
        client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name=test_provider_name,
                provider_api_key=test_provider_api_key,
                headers=test_headers,
            ),
        )

        assert client is not None
        assert client.chat is not None

    def test_create_client_with_openai_provider(
        self, test_api_key: str, test_provider_api_key: str, test_headers: dict
    ):
        """A provider-name config pinned to ``"openai"`` builds a client."""
        client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name="openai",
                provider_api_key=test_provider_api_key,
                headers=test_headers,
            ),
        )

        assert client is not None

    def test_create_client_with_custom_provider_url(
        self, test_api_key: str, test_provider_api_key: str, test_headers: dict
    ):
        """A provider-URL config pointing at an arbitrary endpoint builds a client."""
        endpoint_url = "https://custom-provider.com/v1"
        client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderUrl(
                provider_url=endpoint_url,
                provider_api_key=test_provider_api_key,
                headers=test_headers,
            ),
        )

        assert client is not None
|
|
|
|
|
|
class TestChatCompletions:
    """Test chat completions functionality.

    These tests call the completion methods with credentials read from the
    environment (the key fixtures skip when they are missing) and are marked
    ``pytest.mark.asyncio``, so they need an asyncio-capable pytest setup.
    """

    @staticmethod
    def _make_client(
        api_key: str,
        provider_api_key: str,
        provider_name: ProviderName,
        headers: dict,
    ) -> SupermemoryOpenAI:
        """Build a client for *provider_name* that sends the given extra *headers*."""
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name=provider_name,
            provider_api_key=provider_api_key,
            headers=headers,
        )
        return SupermemoryOpenAI(api_key, config)

    @staticmethod
    def _assert_has_content(result) -> None:
        """Assert *result* has at least one choice whose message content is set."""
        assert result is not None
        assert hasattr(result, "choices")
        assert len(result.choices) > 0
        assert result.choices[0].message.content is not None

    @pytest.mark.asyncio
    async def test_create_chat_completion_simple_message(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
        test_messages: List[List[ChatCompletionMessageParam]],
    ):
        """Test creating chat completion with simple message."""
        client = self._make_client(
            test_api_key, test_provider_api_key, test_provider_name, headers={}
        )

        result = await client.create_chat_completion(
            model=test_model_name,
            messages=test_messages[0],  # "Hello!"
        )

        self._assert_has_content(result)

    @pytest.mark.asyncio
    async def test_chat_completion_convenience_method(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
        test_messages: List[List[ChatCompletionMessageParam]],
    ):
        """Test chat completion using convenience method."""
        client = self._make_client(
            test_api_key, test_provider_api_key, test_provider_name, headers={}
        )

        result = await client.chat_completion(
            messages=test_messages[1],  # System + user messages
            model=test_model_name,
            temperature=0.7,
        )

        self._assert_has_content(result)

    @pytest.mark.asyncio
    async def test_handle_conversation_history(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
        test_messages: List[List[ChatCompletionMessageParam]],
    ):
        """Test handling conversation history."""
        client = self._make_client(
            test_api_key, test_provider_api_key, test_provider_name, headers={}
        )

        result = await client.chat_completion(
            messages=test_messages[2],  # Multi-turn conversation
            model=test_model_name,
        )

        self._assert_has_content(result)

    @pytest.mark.asyncio
    async def test_custom_headers(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
    ):
        """Test working with custom headers."""
        client = self._make_client(
            test_api_key,
            test_provider_api_key,
            test_provider_name,
            headers={"x-custom-header": "test-value"},
        )

        result = await client.chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
            model=test_model_name,
        )

        # Same bar as the sibling tests: a response with no choices (or no
        # content) must not count as success just because headers were added.
        self._assert_has_content(result)
|
|
|
|
|
|
class TestConfigurationValidation:
    """Client construction with edge-case configuration values."""

    def test_handle_empty_headers_object(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
    ):
        """An explicitly empty headers mapping is accepted."""
        sm_client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name=test_provider_name,
                provider_api_key=test_provider_api_key,
                headers={},
            ),
        )
        assert sm_client is not None

    def test_handle_configuration_without_headers(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
    ):
        """Omitting the headers argument entirely is accepted."""
        sm_client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name=test_provider_name,
                provider_api_key=test_provider_api_key,
            ),
        )
        assert sm_client is not None

    def test_handle_different_api_keys(self):
        """Placeholder keys are accepted at construction time (no env vars needed)."""
        sm_client = SupermemoryOpenAI(
            "different-sm-key",
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name="openai",
                provider_api_key="different-provider-key",
            ),
        )
        assert sm_client is not None
|
|
|
|
|
|
class TestDisabledEndpoints:
    """Test that non-chat endpoints are disabled."""

    # Client attributes that must refuse access with a RuntimeError.
    _DISABLED_ENDPOINTS = (
        "embeddings",
        "fine_tuning",
        "images",
        "audio",
        "models",
        "moderations",
        "files",
        "batches",
        "uploads",
        "beta",
    )
    # Pattern the raised RuntimeError's message must contain (used as `match=`;
    # it has no regex metacharacters, so it behaves as a plain substring).
    _DISABLED_MESSAGE = "Supermemory only supports chat completions"

    @staticmethod
    def _make_openai_client(
        api_key: str, provider_api_key: str
    ) -> SupermemoryOpenAI:
        """Build a client configured for the ``"openai"`` provider."""
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name="openai",
            provider_api_key=provider_api_key,
        )
        return SupermemoryOpenAI(api_key, config)

    # Parametrized so every endpoint is checked and reported independently:
    # one failing endpoint no longer hides the rest, and the test ID names it.
    @pytest.mark.parametrize("endpoint", _DISABLED_ENDPOINTS)
    def test_disabled_endpoints_throw_errors(
        self, test_api_key: str, test_provider_api_key: str, endpoint: str
    ):
        """Test that each disabled endpoint raises the Supermemory RuntimeError."""
        client = self._make_openai_client(test_api_key, test_provider_api_key)

        with pytest.raises(RuntimeError, match=self._DISABLED_MESSAGE):
            getattr(client, endpoint)

    def test_chat_completions_still_work(
        self, test_api_key: str, test_provider_api_key: str
    ):
        """Test that chat completions still work after disabling other endpoints."""
        client = self._make_openai_client(test_api_key, test_provider_api_key)

        # Chat completions should still be accessible
        assert client.chat is not None
        assert client.chat.completions is not None
        assert callable(client.create_chat_completion)
        assert callable(client.chat_completion)
|