supermemory/packages/openai-sdk-python/src/supermemory_openai/middleware.py
2026-04-25 01:43:14 +00:00

670 lines
23 KiB
Python

"""Supermemory middleware for OpenAI clients."""
import asyncio
import inspect
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union, cast
import supermemory
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
)
from .exceptions import (
SupermemoryAPIError,
SupermemoryConfigurationError,
SupermemoryMemoryOperationError,
SupermemoryNetworkError,
)
from .utils import (
Logger,
convert_profile_to_markdown,
create_logger,
deduplicate_memories,
get_conversation_content,
get_last_user_message,
)
@dataclass
class OpenAIMiddlewareOptions:
    """Settings that control how the Supermemory middleware behaves.

    ``container_tag`` and ``custom_id`` have no defaults and must be supplied
    (positionally in that order, or by keyword); the other fields are optional.
    """

    # Required: identifies the user/container whose memories are used.
    container_tag: str

    # Required: groups messages into the same document.
    custom_id: str

    # Forwarded to ``create_logger`` when the wrapper is constructed.
    verbose: bool = False

    # "profile": no query text is sent with the profile search.
    # "query":   the last user message is sent as the query and the profile
    #            markdown is left out of the injected prompt.
    # "full":    the query is sent and both profile and search results are used.
    mode: Literal["profile", "query", "full"] = "profile"

    # "always": the conversation is stored on each completion call that has a
    #           non-blank last user message.
    # "never":  nothing is stored.
    add_memory: Literal["always", "never"] = "always"
class SupermemoryProfileSearch:
    """Parsed body of a Supermemory profile search response.

    Attributes:
        profile: Value of the ``"profile"`` key of the response, or ``{}``
            when the key is missing or null.
        search_results: Value of the ``"searchResults"`` key of the response,
            or ``{}`` when the key is missing or null.
    """

    def __init__(self, data: dict[str, Any]):
        # ``data.get(key, {})`` only substitutes the default when the key is
        # *absent*; an explicit JSON ``null`` would come through as ``None``
        # and break the declared ``dict`` type. ``or {}`` covers both cases.
        self.profile: dict[str, Any] = data.get("profile") or {}
        self.search_results: dict[str, Any] = data.get("searchResults") or {}
async def supermemory_profile_search(
    container_tag: str,
    query_text: str,
    api_key: str,
) -> SupermemoryProfileSearch:
    """Search for memories using the SuperMemory profile API.

    Sends a POST request to the ``/v4/profile`` endpoint. ``aiohttp`` is used
    when it can be imported; otherwise the request is made with ``requests``
    on a worker thread so the event loop is not blocked.

    Args:
        container_tag: Sent as ``containerTag`` in the request body.
        query_text: Sent as ``q`` in the request body; omitted when empty.
        api_key: Sent as a bearer token in the ``Authorization`` header.

    Returns:
        The decoded JSON response wrapped in ``SupermemoryProfileSearch``.

    Raises:
        SupermemoryAPIError: If the API answers with a non-OK status.
        ImportError: If neither ``aiohttp`` nor ``requests`` is installed.
    """
    url = "https://api.supermemory.ai/v4/profile"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {"containerTag": container_tag}
    if query_text:
        payload["q"] = query_text

    # Only the import itself is guarded: an ImportError raised while the
    # request is in flight must not silently trigger a second request.
    try:
        import aiohttp
    except ImportError:
        aiohttp = None

    if aiohttp is not None:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                if not response.ok:
                    error_text = await response.text()
                    raise SupermemoryAPIError(
                        "Supermemory profile search failed",
                        status_code=response.status,
                        response_text=error_text,
                    )
                data = await response.json()
                return SupermemoryProfileSearch(data)

    # Fallback to requests if aiohttp is not available.
    import requests

    # requests never times out unless told to. 300 s matches the 5-minute
    # total that aiohttp applies by default, so the fallback cannot hang
    # forever on an unresponsive server.
    fallback_timeout_s = 300

    # requests is blocking; run it on a worker thread so other tasks on the
    # event loop keep making progress while the request is in flight.
    response = await asyncio.to_thread(
        requests.post,
        url,
        headers=headers,
        json=payload,
        timeout=fallback_timeout_s,
    )
    if not response.ok:
        raise SupermemoryAPIError(
            "Supermemory profile search failed",
            status_code=response.status_code,
            response_text=response.text,
        )
    return SupermemoryProfileSearch(response.json())
def _append_memories_to_content(content: Any, memories: str) -> Any:
    """Return system-message *content* with *memories* appended."""
    if isinstance(content, (list, tuple)):
        # Content supplied as a sequence of parts: add the memories as one
        # more text part. Formatting the sequence into an f-string would put
        # its Python repr into the prompt.
        return [*content, {"type": "text", "text": memories}]
    if content is None:
        # Avoid the literal text "None" ending up in the prompt.
        content = ""
    return f"{content} \n {memories}"


async def add_system_prompt(
    messages: list[ChatCompletionMessageParam],
    container_tag: str,
    logger: Logger,
    mode: Literal["profile", "query", "full"],
    api_key: str,
) -> list[ChatCompletionMessageParam]:
    """Add memory-enhanced system prompts to chat completion messages.

    Runs a profile search for ``container_tag``, deduplicates the returned
    memories and injects them into the conversation: appended to every
    existing system message, or as a new leading system message when there
    is none. When the search yields no memory text the messages are returned
    unmodified.

    Args:
        messages: Conversation messages (dict-style, accessed with ``.get``).
        container_tag: Container whose memories are searched.
        logger: Receives ``info``/``debug`` progress records.
        mode: ``"profile"`` sends no query text; ``"query"`` leaves the
            profile markdown out of the prompt; ``"full"`` uses both the
            profile and the search results.
        api_key: Supermemory API key passed to the profile search.

    Returns:
        A new list of messages; the input list is not mutated.
    """
    system_prompt_exists = any(msg.get("role") == "system" for msg in messages)

    query_text = ""
    if mode != "profile":
        # ``or ""`` keeps the slicing in the log call below safe when no user
        # message was found.
        query_text = get_last_user_message(messages) or ""

    memories_response = await supermemory_profile_search(
        container_tag, query_text, api_key
    )

    profile = memories_response.profile or {}
    search_results_data = memories_response.search_results or {}

    # ``or []`` also covers keys that are present but null in the response.
    static_memories = profile.get("static") or []
    dynamic_memories = profile.get("dynamic") or []
    search_memories = search_results_data.get("results") or []

    memory_count_static = len(static_memories)
    memory_count_dynamic = len(dynamic_memories)
    memory_count_search = len(search_memories)

    logger.info(
        "Memory search completed",
        {
            "container_tag": container_tag,
            "memory_count_static": memory_count_static,
            "memory_count_dynamic": memory_count_dynamic,
            "query_text": query_text[:100] + ("..." if len(query_text) > 100 else ""),
            "mode": mode,
        },
    )

    deduplicated = deduplicate_memories(
        static=static_memories,
        dynamic=dynamic_memories,
        search_results=search_memories,
    )

    logger.debug(
        "Memory deduplication completed",
        {
            "static": {
                "original": memory_count_static,
                "deduplicated": len(deduplicated.static),
            },
            "dynamic": {
                "original": memory_count_dynamic,
                "deduplicated": len(deduplicated.dynamic),
            },
            "search_results": {
                "original": memory_count_search,
                "deduplicated": len(deduplicated.search_results),
            },
        },
    )

    profile_data = ""
    if mode != "query":
        profile_data = convert_profile_to_markdown(
            {
                "profile": {
                    "static": deduplicated.static,
                    "dynamic": deduplicated.dynamic,
                },
                "searchResults": {"results": []},
            }
        )

    search_results_memories = ""
    if mode != "profile" and deduplicated.search_results:
        search_results_memories = (
            "Search results for user's recent message: \n"
            + "\n".join(f"- {memory}" for memory in deduplicated.search_results)
        )

    memories = f"{profile_data}\n{search_results_memories}".strip()

    if not memories:
        # Nothing to inject: do not pad an existing system prompt with stray
        # whitespace or prepend an empty system message.
        logger.debug("No memories to inject, leaving messages unchanged")
        return list(messages)

    logger.debug(
        "Memory content preview",
        {
            "content": memories,
            "full_length": len(memories),
        },
    )

    if system_prompt_exists:
        logger.debug("Added memories to existing system prompt")
        return [
            {
                **msg,
                "content": _append_memories_to_content(msg.get("content", ""), memories),
            }
            if msg.get("role") == "system"
            else msg
            for msg in messages
        ]

    logger.debug("System prompt does not exist, created system prompt with memories")
    system_message: ChatCompletionSystemMessageParam = {
        "role": "system",
        "content": memories,
    }
    return [system_message] + messages
async def add_memory_tool(
    client: supermemory.Supermemory,
    container_tag: str,
    content: str,
    custom_id: Optional[str],
    logger: Logger,
) -> None:
    """Add a new memory to the SuperMemory system.

    Args:
        client: Supermemory client; its ``add`` method may be synchronous or
            a coroutine function.
        container_tag: Sent as the single entry of ``container_tags``.
        content: Text to store.
        custom_id: Sent as ``custom_id`` when not ``None``.
        logger: Receives an ``info`` record on success and an ``error``
            record on failure.

    Raises:
        SupermemoryNetworkError: If the call fails with an ``OSError``
            (which includes ``ConnectionError``).
        SupermemoryMemoryOperationError: If the call fails with any other
            ``Exception``.
    """
    add_params: dict[str, Any] = {
        "content": content,
        "container_tags": [container_tag],
    }
    if custom_id is not None:
        add_params["custom_id"] = custom_id

    try:
        add = client.add
        if inspect.iscoroutinefunction(add):
            response = await add(**add_params)
        else:
            # A synchronous client performs blocking I/O. Run it on a worker
            # thread so this coroutine does not freeze the event loop, which
            # matters when it is scheduled as a background task.
            result = await asyncio.to_thread(add, **add_params)
            # Some callables are not coroutine functions but still hand back
            # an awaitable; await it here on the loop.
            response = await result if inspect.isawaitable(result) else result

        logger.info(
            "Memory saved successfully",
            {
                "container_tag": container_tag,
                "custom_id": custom_id,
                "content_length": len(content),
                # The save already succeeded at this point; a response
                # without ``id`` must not be reported as a failure.
                "memory_id": getattr(response, "id", None),
            },
        )
    except (OSError, ConnectionError) as network_error:
        logger.error(
            "Network error while saving memory",
            {"error": str(network_error)},
        )
        raise SupermemoryNetworkError(
            "Failed to save memory due to network error", network_error
        ) from network_error
    except Exception as error:
        logger.error(
            "Error saving memory",
            {"error": str(error)},
        )
        raise SupermemoryMemoryOperationError("Failed to save memory", error) from error
class SupermemoryOpenAIWrapper:
    """Wrapper for OpenAI client with Supermemory middleware.

    On construction the wrapped client's ``chat.completions.create`` is
    replaced in place with a version that:

    * stores the conversation in Supermemory when ``add_memory == "always"``
      and the last user message is non-blank, and
    * injects stored memories into the messages via ``add_system_prompt``
      before delegating to the original ``create``.

    Every attribute not defined on the wrapper is delegated to the wrapped
    client. The wrapper is both a sync and an async context manager; on exit
    it waits for (or cancels) outstanding background memory-storage tasks.
    """

    def __init__(
        self,
        openai_client: Union[OpenAI, AsyncOpenAI],
        options: OpenAIMiddlewareOptions,
    ):
        """Wrap ``openai_client`` and patch its chat completion method.

        Raises:
            SupermemoryConfigurationError: If the ``supermemory`` package has
                no ``Supermemory`` class, if ``SUPERMEMORY_API_KEY`` is not
                set, or if the Supermemory client cannot be constructed.
        """
        self._client: Union[OpenAI, AsyncOpenAI] = openai_client
        self._container_tag: str = options.container_tag
        self._options: OpenAIMiddlewareOptions = options
        self._logger: Logger = create_logger(self._options.verbose)

        # Track background tasks to ensure they complete
        self._background_tasks: set[asyncio.Task] = set()

        if not hasattr(supermemory, "Supermemory"):
            raise SupermemoryConfigurationError(
                "supermemory package is required but not found",
                ImportError("supermemory package not installed"),
            )

        api_key = self._get_api_key()
        try:
            self._supermemory_client: supermemory.Supermemory = supermemory.Supermemory(
                api_key=api_key
            )
        except Exception as e:
            raise SupermemoryConfigurationError(
                f"Failed to initialize Supermemory client: {e}", e
            ) from e

        # Wrap the chat completions create method
        self._wrap_chat_completions()

    def _get_api_key(self) -> str:
        """Get Supermemory API key from environment.

        Raises:
            SupermemoryConfigurationError: If ``SUPERMEMORY_API_KEY`` is
                unset or empty.
        """
        api_key = os.getenv("SUPERMEMORY_API_KEY")
        if not api_key:
            raise SupermemoryConfigurationError(
                "SUPERMEMORY_API_KEY environment variable is required but not set"
            )
        return api_key

    def _wrap_chat_completions(self) -> None:
        """Wrap the chat completions create method with memory injection."""
        original_create = self._client.chat.completions.create

        if asyncio.iscoroutinefunction(original_create):

            async def create_with_memory(**kwargs: Any) -> Any:
                return await self._create_with_memory_async(original_create, **kwargs)

        else:

            def create_with_memory(**kwargs: Any) -> Any:
                return self._create_with_memory_sync(original_create, **kwargs)

        # Replace the create method with our wrapper
        setattr(self._client.chat.completions, "create", create_with_memory)

    def _memory_to_store(
        self, messages: Any
    ) -> Optional[tuple[str, Optional[str]]]:
        """Return ``(content, custom_id)`` to persist for this call, or ``None``.

        ``None`` means nothing should be stored: memory saving is disabled or
        the last user message is missing or blank.
        """
        if self._options.add_memory != "always":
            return None
        user_message = get_last_user_message(messages)
        if not (user_message and user_message.strip()):
            return None
        if self._options.custom_id:
            # With a conversation id the whole conversation is stored under it.
            return (
                get_conversation_content(messages),
                f"conversation:{self._options.custom_id}",
            )
        return user_message, None

    def _log_search_start(self) -> None:
        """Emit the info record that precedes a memory search."""
        self._logger.info(
            "Starting memory search",
            {
                "container_tag": self._container_tag,
                "conversation_id": self._options.custom_id,
                "mode": self._options.mode,
            },
        )

    def _log_background_result(self, task_obj: Any) -> None:
        """Done-callback: log a background task's failure without raising."""
        try:
            exception = task_obj.exception()
        except asyncio.CancelledError:
            self._logger.debug("Memory storage task was cancelled")
            return
        if exception is None:
            return
        details = {"error": str(exception), "type": type(exception).__name__}
        if isinstance(exception, (SupermemoryNetworkError, SupermemoryAPIError)):
            self._logger.warn("Background memory storage failed", details)
        else:
            self._logger.error(
                "Unexpected error in background memory storage", details
            )

    @staticmethod
    def _run_coroutine_blocking(coro_factory: Any) -> Any:
        """Run the coroutine built by ``coro_factory`` to completion, blocking.

        ``asyncio.run`` is used directly when the calling thread has no
        running event loop. When one is running, the coroutine is executed
        with ``asyncio.run`` on a worker thread. The coroutine is only
        created where it will actually be run, so no un-awaited coroutine is
        left behind.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread.
            return asyncio.run(coro_factory())

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(lambda: asyncio.run(coro_factory())).result()

    async def _create_with_memory_async(
        self,
        original_create: Any,
        **kwargs: Any,
    ) -> Any:
        """Async version of create with memory injection."""
        messages = kwargs.get("messages", [])

        pending = self._memory_to_store(messages)
        if pending is not None:
            content, custom_id = pending
            # Create background task for memory storage
            task = asyncio.create_task(
                add_memory_tool(
                    self._supermemory_client,
                    self._container_tag,
                    content,
                    custom_id,
                    self._logger,
                )
            )
            # Track the task and set up cleanup
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            # Log any exceptions but don't fail the main request
            task.add_done_callback(self._log_background_result)

        if self._options.mode != "profile" and not get_last_user_message(messages):
            self._logger.debug("No user message found, skipping memory search")
            return await original_create(**kwargs)

        self._log_search_start()
        kwargs["messages"] = await add_system_prompt(
            messages,
            self._container_tag,
            self._logger,
            self._options.mode,
            self._get_api_key(),
        )
        return await original_create(**kwargs)

    def _create_with_memory_sync(
        self,
        original_create: Any,
        **kwargs: Any,
    ) -> Any:
        """Sync version of create with memory injection.

        Memory storage and memory search both run to completion before the
        original ``create`` is called. Storage failures are logged and never
        propagate; search failures do propagate.
        """
        messages = kwargs.get("messages", [])

        # Handle memory addition synchronously if needed
        pending = self._memory_to_store(messages)
        if pending is not None:
            content, custom_id = pending
            try:
                self._run_coroutine_blocking(
                    lambda: add_memory_tool(
                        self._supermemory_client,
                        self._container_tag,
                        content,
                        custom_id,
                        self._logger,
                    )
                )
            except RuntimeError:
                # Event-loop/executor level failures are not storage errors;
                # let them surface instead of logging them away.
                raise
            except SupermemoryNetworkError as e:
                # Network errors are expected, log as warning
                self._logger.warn("Network error saving memory", {"error": str(e)})
            except (SupermemoryAPIError, SupermemoryMemoryOperationError) as e:
                # API/memory errors are concerning, log as error
                self._logger.error("Failed to save memory", {"error": str(e)})
            except Exception as e:
                # Unexpected errors should be investigated
                self._logger.error(
                    "Unexpected error saving memory",
                    {"error": str(e), "type": type(e).__name__},
                )

        # Handle memory search and injection
        if self._options.mode != "profile" and not get_last_user_message(messages):
            self._logger.debug("No user message found, skipping memory search")
            return original_create(**kwargs)

        self._log_search_start()
        kwargs["messages"] = self._run_coroutine_blocking(
            lambda: add_system_prompt(
                messages,
                self._container_tag,
                self._logger,
                self._options.mode,
                self._get_api_key(),
            )
        )
        return original_create(**kwargs)

    async def wait_for_background_tasks(self, timeout: Optional[float] = 10.0) -> None:
        """
        Wait for all background memory storage tasks to complete.

        Args:
            timeout: Maximum time to wait in seconds. None for no timeout.

        Raises:
            asyncio.TimeoutError: If tasks don't complete within timeout
        """
        if not self._background_tasks:
            return

        self._logger.debug(
            f"Waiting for {len(self._background_tasks)} background tasks to complete"
        )
        try:
            # ``wait_for`` with ``timeout=None`` simply awaits the gather.
            await asyncio.wait_for(
                asyncio.gather(*self._background_tasks, return_exceptions=True),
                timeout=timeout,
            )
            self._logger.debug("All background tasks completed")
        except asyncio.TimeoutError:
            self._logger.warn(
                f"Background tasks did not complete within {timeout}s timeout"
            )
            # Cancel remaining tasks
            tasks_to_cancel = [
                task for task in self._background_tasks if not task.done()
            ]
            for task in tasks_to_cancel:
                task.cancel()
            if tasks_to_cancel:
                await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
            raise

    def cancel_background_tasks(self) -> None:
        """Cancel all pending background tasks."""
        cancelled_count = 0
        for task in self._background_tasks:
            if not task.done():
                task.cancel()
                cancelled_count += 1
        if cancelled_count > 0:
            self._logger.debug(f"Cancelled {cancelled_count} pending background tasks")

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - wait for background tasks."""
        try:
            await self.wait_for_background_tasks(timeout=5.0)
        except asyncio.TimeoutError:
            self._logger.warn("Some background memory tasks did not complete on exit")

    def __enter__(self):
        """Sync context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Sync context manager exit - attempt to wait for background tasks."""
        if not self._background_tasks:
            return

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop: safe to block on asyncio.run() below.
        else:
            # In async context, just cancel the tasks. Checking up front
            # avoids creating a coroutine that asyncio.run() would reject
            # and leave un-awaited.
            self._logger.warn(
                "Cannot wait for background tasks in sync context from async environment. "
                "Use async context manager or call wait_for_background_tasks() manually."
            )
            self.cancel_background_tasks()
            return

        try:
            # Try to wait for background tasks in sync context
            asyncio.run(self.wait_for_background_tasks(timeout=5.0))
        except asyncio.TimeoutError:
            self._logger.warn("Some background memory tasks did not complete on exit")
            self.cancel_background_tasks()

    def __getattr__(self, name: str) -> Any:
        """Delegate all other attributes to the wrapped client."""
        # __getattr__ only runs after normal lookup has failed. If ``_client``
        # itself is missing (instance created without __init__, e.g. during
        # copy or unpickling), ``self._client`` would re-enter this method
        # endlessly; fail with a normal AttributeError instead.
        if name == "_client":
            raise AttributeError(name)
        return getattr(self._client, name)
def with_supermemory(
    openai_client: Union[OpenAI, AsyncOpenAI],
    options: OpenAIMiddlewareOptions,
) -> Union[OpenAI, AsyncOpenAI]:
    """
    Wrap an OpenAI client so chat completions are enriched with SuperMemory memories.

    The client is wrapped in a ``SupermemoryOpenAIWrapper``. Its
    ``chat.completions.create`` method then looks up memories for the
    configured container tag and either appends them to an existing system
    prompt or creates a new system prompt containing them.

    Args:
        openai_client: The OpenAI client to wrap with SuperMemory middleware
        options: Configuration options for the middleware (container_tag and custom_id are required)

    Returns:
        The wrapper object, typed as the OpenAI client it delegates to

    Example:
        ```python
        import os

        from supermemory_openai import with_supermemory, OpenAIMiddlewareOptions
        from openai import OpenAI

        # Create OpenAI client with supermemory middleware
        openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        openai_with_supermemory = with_supermemory(
            openai,
            OpenAIMiddlewareOptions(
                container_tag="user-123",
                custom_id="conversation-456",
                mode="full",
                add_memory="always"
            )
        )

        # Use normally - memories will be automatically injected.
        # (With AsyncOpenAI, await this call instead.)
        response = openai_with_supermemory.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "user", "content": "What's my favorite programming language?"}
            ]
        )
        ```

    Raises:
        SupermemoryConfigurationError: When SUPERMEMORY_API_KEY environment variable
            is not set or the Supermemory client cannot be initialized

    Note:
        Completion calls made through the returned client may raise
        SupermemoryAPIError when the profile search returns a non-OK response.
    """
    # The wrapper delegates all other attributes to the original client, so it
    # is handed back under the client's type.
    return cast(
        Union[OpenAI, AsyncOpenAI],
        SupermemoryOpenAIWrapper(openai_client, options),
    )