supermemory/packages/openai-sdk-python/tests/test_infinite_chat.py
CodeWithShreyans 3a0e264b7e
Some checks are pending
Publish AI SDK / publish (push) Waiting to run
feat: openai js and python sdk utilities (#389)
needs testing
2025-08-27 23:34:49 +00:00

387 lines
12 KiB
Python

"""Tests for infinite_chat module."""
import os
import pytest
from typing import List
from openai.types.chat import ChatCompletionMessageParam
from ..src import (
SupermemoryOpenAI,
SupermemoryInfiniteChatConfigWithProviderName,
SupermemoryInfiniteChatConfigWithProviderUrl,
ProviderName,
)
# Test configuration
# Every provider name the suite recognises. The test_provider_name fixture
# fails the run when PROVIDER_NAME is set to anything outside this list.
PROVIDERS: List[ProviderName] = [
    "openai", "anthropic", "openrouter", "deepinfra",
    "groq", "google", "cloudflare",
]
@pytest.fixture
def test_api_key() -> str:
    """Return the Supermemory API key from SUPERMEMORY_API_KEY.

    Any test requesting this fixture is skipped when the variable is unset
    or empty.
    """
    key = os.environ.get("SUPERMEMORY_API_KEY")
    if key:
        return key
    pytest.skip("SUPERMEMORY_API_KEY environment variable is required for tests")


@pytest.fixture
def test_provider_api_key() -> str:
    """Return the provider API key from PROVIDER_API_KEY.

    Any test requesting this fixture is skipped when the variable is unset
    or empty.
    """
    key = os.environ.get("PROVIDER_API_KEY")
    if key:
        return key
    pytest.skip("PROVIDER_API_KEY environment variable is required for tests")
@pytest.fixture
def test_provider_name() -> ProviderName:
    """Return the provider under test, read from PROVIDER_NAME (default "openai").

    A value not listed in PROVIDERS makes the requesting test fail via
    ``pytest.fail`` rather than skip.
    """
    name = os.getenv("PROVIDER_NAME", "openai")
    if name in PROVIDERS:
        return name  # type: ignore
    pytest.fail(f"Invalid provider name: {name}")
@pytest.fixture
def test_provider_url() -> str:
    """Return PROVIDER_URL from the environment, or "" when it is unset."""
    return os.environ.get("PROVIDER_URL", "")


@pytest.fixture
def test_model_name() -> str:
    """Return the model to request: MODEL_NAME, defaulting to "gpt-4o-mini"."""
    return os.environ.get("MODEL_NAME", "gpt-4o-mini")


@pytest.fixture
def test_headers() -> dict:
    """Return a fresh dict holding a single custom header."""
    headers = {"custom-header": "test-value"}
    return headers
@pytest.fixture
def test_messages() -> List[List[ChatCompletionMessageParam]]:
    """Return three conversations for the chat-completion tests, by index.

    [0] a single user greeting; [1] a system prompt followed by a user
    question; [2] a multi-turn exchange ending in a user follow-up.
    """
    greeting: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Hello!"},
    ]
    with_system_prompt: List[ChatCompletionMessageParam] = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is AI?"},
    ]
    multi_turn: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Tell me a joke"},
        {
            "role": "assistant",
            "content": "Why don't scientists trust atoms? Because they make up everything!",
        },
        {"role": "user", "content": "Tell me another one"},
    ]
    return [greeting, with_system_prompt, multi_turn]
class TestClientCreation:
    """Test client creation.

    These tests only construct clients and make no completion requests. Each
    one checks that the chat-completions surface is reachable on the result,
    not merely that construction returned something.
    """

    @staticmethod
    def _assert_chat_available(client: SupermemoryOpenAI) -> None:
        """Assert *client* was built and exposes ``chat.completions``."""
        assert client is not None
        assert client.chat is not None
        assert client.chat.completions is not None

    def test_create_client_with_provider_name(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_headers: dict,
    ):
        """Test creating client with provider name configuration."""
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name=test_provider_name,
            provider_api_key=test_provider_api_key,
            headers=test_headers,
        )
        client = SupermemoryOpenAI(test_api_key, config)
        self._assert_chat_available(client)

    def test_create_client_with_openai_provider(
        self, test_api_key: str, test_provider_api_key: str, test_headers: dict
    ):
        """Test creating client with OpenAI provider configuration."""
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name="openai",
            provider_api_key=test_provider_api_key,
            headers=test_headers,
        )
        client = SupermemoryOpenAI(test_api_key, config)
        self._assert_chat_available(client)

    def test_create_client_with_custom_provider_url(
        self, test_api_key: str, test_provider_api_key: str, test_headers: dict
    ):
        """Test creating client with custom provider URL."""
        custom_url = "https://custom-provider.com/v1"
        config = SupermemoryInfiniteChatConfigWithProviderUrl(
            provider_url=custom_url,
            provider_api_key=test_provider_api_key,
            headers=test_headers,
        )
        client = SupermemoryOpenAI(test_api_key, config)
        self._assert_chat_available(client)
class TestChatCompletions:
    """Test chat completions functionality.

    Every test awaits a real completion call using keys from the environment
    fixtures, so the class is skipped when SUPERMEMORY_API_KEY or
    PROVIDER_API_KEY is unset. All tests apply the same response checks via
    ``_assert_valid_completion``.
    """

    @staticmethod
    def _make_client(
        api_key: str,
        provider_api_key: str,
        provider_name: ProviderName,
        headers: dict,
    ) -> SupermemoryOpenAI:
        """Build a client configured by provider name for the tests below."""
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name=provider_name,
            provider_api_key=provider_api_key,
            headers=headers,
        )
        return SupermemoryOpenAI(api_key, config)

    @staticmethod
    def _assert_valid_completion(result) -> None:
        """Assert *result* has at least one choice whose message has content."""
        assert result is not None
        assert hasattr(result, "choices")
        assert len(result.choices) > 0
        assert result.choices[0].message.content is not None

    @pytest.mark.asyncio
    async def test_create_chat_completion_simple_message(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
        test_messages: List[List[ChatCompletionMessageParam]],
    ):
        """Test creating chat completion with simple message."""
        client = self._make_client(
            test_api_key, test_provider_api_key, test_provider_name, headers={}
        )
        result = await client.create_chat_completion(
            model=test_model_name,
            messages=test_messages[0],  # "Hello!"
        )
        self._assert_valid_completion(result)

    @pytest.mark.asyncio
    async def test_chat_completion_convenience_method(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
        test_messages: List[List[ChatCompletionMessageParam]],
    ):
        """Test chat completion using convenience method."""
        client = self._make_client(
            test_api_key, test_provider_api_key, test_provider_name, headers={}
        )
        result = await client.chat_completion(
            messages=test_messages[1],  # System + user messages
            model=test_model_name,
            temperature=0.7,
        )
        self._assert_valid_completion(result)

    @pytest.mark.asyncio
    async def test_handle_conversation_history(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
        test_messages: List[List[ChatCompletionMessageParam]],
    ):
        """Test handling conversation history."""
        client = self._make_client(
            test_api_key, test_provider_api_key, test_provider_name, headers={}
        )
        result = await client.chat_completion(
            messages=test_messages[2],  # Multi-turn conversation
            model=test_model_name,
        )
        self._assert_valid_completion(result)

    @pytest.mark.asyncio
    async def test_custom_headers(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
        test_model_name: str,
    ):
        """Test working with custom headers.

        Applies the same response checks as the other tests. Previously this
        test accepted any object with a ``choices`` attribute, even an empty one.
        """
        client = self._make_client(
            test_api_key,
            test_provider_api_key,
            test_provider_name,
            headers={"x-custom-header": "test-value"},
        )
        result = await client.chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
            model=test_model_name,
        )
        self._assert_valid_completion(result)
class TestConfigurationValidation:
    """Test configuration validation.

    Each test only checks that a client can be constructed from a given
    configuration shape.
    """

    def test_handle_empty_headers_object(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
    ):
        """Test handling empty headers object."""
        client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name=test_provider_name,
                provider_api_key=test_provider_api_key,
                headers={},
            ),
        )
        assert client is not None

    def test_handle_configuration_without_headers(
        self,
        test_api_key: str,
        test_provider_api_key: str,
        test_provider_name: ProviderName,
    ):
        """Test handling configuration without headers."""
        client = SupermemoryOpenAI(
            test_api_key,
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name=test_provider_name,
                provider_api_key=test_provider_api_key,
            ),
        )
        assert client is not None

    def test_handle_different_api_keys(self):
        """Test handling different API keys."""
        client = SupermemoryOpenAI(
            "different-sm-key",
            SupermemoryInfiniteChatConfigWithProviderName(
                provider_name="openai",
                provider_api_key="different-provider-key",
            ),
        )
        assert client is not None
class TestDisabledEndpoints:
    """Test that non-chat endpoints are disabled."""

    # Client attributes that must be blocked, and the text their RuntimeError
    # is expected to contain.
    DISABLED_ENDPOINTS = (
        "embeddings",
        "fine_tuning",
        "images",
        "audio",
        "models",
        "moderations",
        "files",
        "batches",
        "uploads",
        "beta",
    )
    DISABLED_MESSAGE = "Supermemory only supports chat completions"

    def test_disabled_endpoints_throw_errors(
        self, test_api_key: str, test_provider_api_key: str
    ):
        """Test that all disabled endpoints throw appropriate errors.

        Each endpoint in DISABLED_ENDPOINTS is accessed in turn. A failure
        names the offending attribute instead of reporting a bare
        "DID NOT RAISE".
        """
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name="openai",
            provider_api_key=test_provider_api_key,
        )
        client = SupermemoryOpenAI(test_api_key, config)
        for endpoint in self.DISABLED_ENDPOINTS:
            try:
                getattr(client, endpoint)
            except RuntimeError as err:
                # A substring check is equivalent to pytest.raises(match=...)
                # here because the message has no regex metacharacters.
                assert self.DISABLED_MESSAGE in str(err), (
                    f"client.{endpoint} raised an unexpected RuntimeError: {err}"
                )
            else:
                pytest.fail(f"client.{endpoint} should be disabled but was accessible")

    def test_chat_completions_still_work(
        self, test_api_key: str, test_provider_api_key: str
    ):
        """Test that chat completions still work after disabling other endpoints."""
        config = SupermemoryInfiniteChatConfigWithProviderName(
            provider_name="openai",
            provider_api_key=test_provider_api_key,
        )
        client = SupermemoryOpenAI(test_api_key, config)
        # Chat completions should still be accessible
        assert client.chat is not None
        assert client.chat.completions is not None
        assert callable(client.create_chat_completion)
        assert callable(client.chat_completion)