mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-04-28 11:40:25 +00:00
Co-authored-by: Douglas <douglas.ym.lai@gmail.com> Co-authored-by: a7m-1st <ahmed.jimi.awelkair500@gmail.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tong Chen <web_chentong@163.com>
500 lines
17 KiB
Python
500 lines
17 KiB
Python
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
|
|
|
"""Redis utilities for managing WebSocket sessions and real-time data."""
|
|
|
|
import redis
|
|
from redis import Redis
|
|
from typing import Optional, Dict, Any, Set, Callable
|
|
from datetime import datetime, timezone
|
|
import json
|
|
import logging
|
|
import os
|
|
import asyncio
|
|
|
|
logger = logging.getLogger("server_redis_utils")


class RedisSessionManager:
    """Manages WebSocket sessions in Redis for scalability and persistence.

    Key layout (all prefixes are set in ``__init__``):

    * ``ws:session:<session_id>``     -- JSON blob describing one session
    * ``ws:user:sessions:<user_id>``  -- set of session IDs owned by a user
    * ``ws:pending:<session_id>``     -- set of execution IDs awaiting delivery
    * ``ws:delivery:<execution_id>``  -- short-lived delivery confirmation
    * ``ws:executions``               -- pub/sub channel for execution events

    Apart from the ``client`` property, which re-raises connection failures,
    every synchronous method logs Redis errors and reports them through its
    return value (``False`` / ``None`` / empty set) instead of raising.
    """

    def __init__(self, redis_url: Optional[str] = None):
        """Record configuration; no connection is opened here.

        Args:
            redis_url: Redis connection URL. If None (or empty), the
                ``SESSION_REDIS_URL`` environment variable is used, falling
                back to ``redis://localhost:6379/0``.
        """
        self.redis_url = redis_url or os.getenv(
            "SESSION_REDIS_URL", "redis://localhost:6379/0"
        )
        # Created lazily by the ``client`` property.
        self._client: Optional[Redis] = None

        # Key prefixes
        self.SESSION_PREFIX = "ws:session:"
        self.USER_SESSIONS_PREFIX = "ws:user:sessions:"
        self.PENDING_PREFIX = "ws:pending:"
        self.PUBSUB_CHANNEL = "ws:executions"
        self.DELIVERY_CONFIRMATION_PREFIX = "ws:delivery:"

        # TTL for sessions (24 hours)
        self.SESSION_TTL = 86400
        # TTL for delivery confirmations (5 minutes)
        self.DELIVERY_TTL = 300

        # Pub/Sub state, populated by subscribe_to_execution_events().
        self._pubsub = None
        self._pubsub_client: Optional[Redis] = None

    @property
    def client(self) -> "Redis":
        """Get or create the Redis client.

        The client is only cached once ``ping()`` has succeeded, so a failed
        connection attempt is retried (and re-verified) on the next access
        instead of leaving an unchecked client behind.

        Raises:
            Exception: Whatever ``redis.from_url`` or ``ping`` raised; the
                error is logged before being re-raised.
        """
        if self._client is None:
            try:
                candidate = redis.from_url(
                    self.redis_url,
                    decode_responses=True,
                    socket_connect_timeout=5,
                    socket_timeout=5,
                )
                # Test connection before publishing the client on self.
                candidate.ping()
            except Exception as e:
                logger.error(
                    "Failed to connect to Redis",
                    extra={"error": str(e)},
                    exc_info=True,
                )
                raise
            self._client = candidate
            # NOTE(review): the URL is logged verbatim; if it can carry
            # credentials, consider redacting them before logging.
            logger.info(
                "Redis connection established", extra={"url": self.redis_url}
            )
        return self._client

    def store_session(
        self,
        session_id: str,
        user_id: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Store a WebSocket session in Redis.

        Args:
            session_id: Unique session identifier
            user_id: User ID associated with the session
            metadata: Additional metadata to store. It may override
                ``connected_at`` but never ``user_id`` / ``session_id``.

        Returns:
            True if successful, False otherwise
        """
        try:
            # Identity fields go last so metadata cannot overwrite them;
            # remove_session() relies on the stored user_id to find the
            # per-user index entry.
            session_data = {
                "connected_at": datetime.now(timezone.utc).isoformat(),
                **(metadata or {}),
                "user_id": user_id,
                "session_id": session_id,
            }

            session_key = f"{self.SESSION_PREFIX}{session_id}"
            user_sessions_key = f"{self.USER_SESSIONS_PREFIX}{user_id}"

            # Store session data
            self.client.setex(
                session_key,
                self.SESSION_TTL,
                json.dumps(session_data),
            )

            # Add session to user's session set
            self.client.sadd(user_sessions_key, session_id)
            self.client.expire(user_sessions_key, self.SESSION_TTL)

            logger.debug("Session stored in Redis", extra={
                "session_id": session_id,
                "user_id": user_id
            })
            return True

        except Exception as e:
            logger.error("Failed to store session in Redis", extra={
                "session_id": session_id,
                "error": str(e)
            }, exc_info=True)
            return False

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session data from Redis.

        Args:
            session_id: Session identifier

        Returns:
            Session data dictionary, or None if the key is missing or the
            lookup/decoding failed (failures are logged).
        """
        try:
            session_key = f"{self.SESSION_PREFIX}{session_id}"
            data = self.client.get(session_key)

            if data:
                return json.loads(data)
            return None

        except Exception as e:
            logger.error("Failed to get session from Redis", extra={
                "session_id": session_id,
                "error": str(e)
            })
            return None

    def remove_session(self, session_id: str) -> bool:
        """Remove a session from Redis.

        Deletes the session record, its entry in the owning user's session
        set and its pending-execution set.

        Args:
            session_id: Session identifier

        Returns:
            True if the session record existed and was removed, False if it
            was not found or an error occurred.
        """
        try:
            pending_key = f"{self.PENDING_PREFIX}{session_id}"

            # Get session data to find user_id
            session = self.get_session(session_id)
            if not session:
                # The pending set carries its own TTL (refreshed on every
                # add), so it can outlive the session record; drop it anyway.
                self.client.delete(pending_key)
                return False

            user_id = session.get("user_id")

            # Remove session data
            session_key = f"{self.SESSION_PREFIX}{session_id}"
            self.client.delete(session_key)

            # Remove from user's session set
            if user_id:
                user_sessions_key = f"{self.USER_SESSIONS_PREFIX}{user_id}"
                self.client.srem(user_sessions_key, session_id)

            # Remove pending executions
            self.client.delete(pending_key)

            logger.debug("Session removed from Redis", extra={
                "session_id": session_id,
                "user_id": user_id
            })
            return True

        except Exception as e:
            logger.error("Failed to remove session from Redis", extra={
                "session_id": session_id,
                "error": str(e)
            }, exc_info=True)
            return False

    def get_user_sessions(self, user_id: str) -> Set[str]:
        """Get all session IDs recorded for a user.

        NOTE(review): members are returned as stored; they are not checked
        against the session records, so an ID whose session key expired
        without remove_session() may still appear until the set expires.

        Args:
            user_id: User identifier

        Returns:
            Set of session IDs (empty on error or when none are recorded)
        """
        try:
            user_sessions_key = f"{self.USER_SESSIONS_PREFIX}{user_id}"
            sessions = self.client.smembers(user_sessions_key)
            return sessions if sessions else set()

        except Exception as e:
            logger.error("Failed to get user sessions from Redis", extra={
                "user_id": user_id,
                "error": str(e)
            })
            return set()

    def add_pending_execution(self, session_id: str, execution_id: str) -> bool:
        """Add a pending execution to a session.

        Also resets the pending set's TTL to ``SESSION_TTL``.

        Args:
            session_id: Session identifier
            execution_id: Execution identifier

        Returns:
            True if successful, False otherwise
        """
        try:
            pending_key = f"{self.PENDING_PREFIX}{session_id}"
            self.client.sadd(pending_key, execution_id)
            self.client.expire(pending_key, self.SESSION_TTL)
            return True

        except Exception as e:
            logger.error("Failed to add pending execution", extra={
                "session_id": session_id,
                "execution_id": execution_id,
                "error": str(e)
            })
            return False

    def remove_pending_execution(self, session_id: str, execution_id: str) -> bool:
        """Remove a pending execution from a session.

        Args:
            session_id: Session identifier
            execution_id: Execution identifier

        Returns:
            True if the command was issued without error (even when the
            execution was not in the set), False otherwise
        """
        try:
            pending_key = f"{self.PENDING_PREFIX}{session_id}"
            self.client.srem(pending_key, execution_id)
            return True

        except Exception as e:
            logger.error("Failed to remove pending execution", extra={
                "session_id": session_id,
                "execution_id": execution_id,
                "error": str(e)
            })
            return False

    def get_pending_executions(self, session_id: str) -> Set[str]:
        """Get all pending executions for a session.

        Args:
            session_id: Session identifier

        Returns:
            Set of execution IDs (empty on error or when none are pending)
        """
        try:
            pending_key = f"{self.PENDING_PREFIX}{session_id}"
            pending = self.client.smembers(pending_key)
            return pending if pending else set()

        except Exception as e:
            logger.error("Failed to get pending executions", extra={
                "session_id": session_id,
                "error": str(e)
            })
            return set()

    def update_session_ttl(self, session_id: str) -> bool:
        """Refresh the TTL for a session and its related keys.

        Besides the session record and the pending set, the owning user's
        session set is refreshed too; otherwise a session kept alive beyond
        ``SESSION_TTL`` would drop out of get_user_sessions() once that set
        expired.

        Args:
            session_id: Session identifier

        Returns:
            True if the commands were issued without error, False otherwise
        """
        try:
            session_key = f"{self.SESSION_PREFIX}{session_id}"
            self.client.expire(session_key, self.SESSION_TTL)

            pending_key = f"{self.PENDING_PREFIX}{session_id}"
            self.client.expire(pending_key, self.SESSION_TTL)

            # The user index key is derived from the stored user_id.
            session = self.get_session(session_id)
            user_id = session.get("user_id") if session else None
            if user_id:
                user_sessions_key = f"{self.USER_SESSIONS_PREFIX}{user_id}"
                self.client.expire(user_sessions_key, self.SESSION_TTL)

            return True

        except Exception as e:
            logger.error("Failed to update session TTL", extra={
                "session_id": session_id,
                "error": str(e)
            })
            return False

    def confirm_delivery(self, execution_id: str, session_id: str) -> bool:
        """Confirm that a message was delivered to a WebSocket client.

        The confirmation is stored under the execution ID and expires after
        ``DELIVERY_TTL`` seconds.

        Args:
            execution_id: The execution ID that was delivered
            session_id: The session ID that received the message

        Returns:
            True if confirmation was stored, False otherwise
        """
        try:
            confirmation_key = f"{self.DELIVERY_CONFIRMATION_PREFIX}{execution_id}"
            confirmation_data = json.dumps({
                "execution_id": execution_id,
                "session_id": session_id,
                "delivered_at": datetime.now(timezone.utc).isoformat()
            })
            self.client.setex(confirmation_key, self.DELIVERY_TTL, confirmation_data)
            logger.debug("Delivery confirmed", extra={
                "execution_id": execution_id,
                "session_id": session_id
            })
            return True
        except Exception as e:
            logger.error("Failed to confirm delivery", extra={
                "execution_id": execution_id,
                "session_id": session_id,
                "error": str(e)
            })
            return False

    def _take_confirmation(self, confirmation_key: str) -> Optional[str]:
        """Fetch a delivery confirmation and delete it if present."""
        data = self.client.get(confirmation_key)
        if data:
            # Clean up the confirmation key
            self.client.delete(confirmation_key)
        return data

    async def wait_for_delivery(
        self,
        execution_id: str,
        timeout: float = 10.0,
        poll_interval: float = 0.1
    ) -> Optional[Dict[str, Any]]:
        """Wait for delivery confirmation of an execution.

        Polls Redis until a confirmation appears or ``timeout`` seconds have
        elapsed on the event loop's clock. The Redis calls run in the default
        executor so the event loop is not blocked by socket I/O.

        Args:
            execution_id: The execution ID to wait for
            timeout: Maximum time to wait in seconds
            poll_interval: Time between checks in seconds; values <= 0 poll
                as fast as the loop allows but still honour ``timeout``

        Returns:
            Confirmation data if delivered, None if timeout
        """
        confirmation_key = f"{self.DELIVERY_CONFIRMATION_PREFIX}{execution_id}"
        loop = asyncio.get_running_loop()
        # A deadline (rather than summing poll_interval) accounts for time
        # spent in Redis and guarantees termination for any poll_interval.
        deadline = loop.time() + timeout

        while loop.time() < deadline:
            try:
                data = await loop.run_in_executor(
                    None, self._take_confirmation, confirmation_key
                )
                if data:
                    return json.loads(data)
            except Exception as e:
                logger.error("Error checking delivery confirmation", extra={
                    "execution_id": execution_id,
                    "error": str(e)
                })

            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            await asyncio.sleep(min(max(poll_interval, 0.0), remaining))

        logger.warning("Delivery confirmation timeout", extra={
            "execution_id": execution_id,
            "timeout": timeout
        })
        return None

    def has_active_sessions_for_user(self, user_id: str) -> bool:
        """Check if a user has any recorded WebSocket sessions.

        Args:
            user_id: User identifier

        Returns:
            True if get_user_sessions() reports at least one session ID,
            False otherwise (including on error)
        """
        try:
            return bool(self.get_user_sessions(user_id))
        except Exception as e:
            logger.error("Failed to check user sessions", extra={
                "user_id": user_id,
                "error": str(e)
            })
            return False

    def close(self):
        """Close the pub/sub handle and both Redis clients, if open."""
        if self._pubsub:
            self._pubsub.close()
            self._pubsub = None
        if self._pubsub_client:
            self._pubsub_client.close()
            self._pubsub_client = None
        if self._client:
            self._client.close()
            self._client = None

    def publish_execution_event(self, event_data: Dict[str, Any]) -> bool:
        """Publish an execution event to all workers via Redis pub/sub.

        Args:
            event_data: Event data to broadcast; must be JSON-serialisable

        Returns:
            True if successful, False otherwise
        """
        try:
            message = json.dumps(event_data)
            self.client.publish(self.PUBSUB_CHANNEL, message)
            logger.debug("Published execution event to Redis", extra={
                "execution_id": event_data.get("execution_id"),
                "type": event_data.get("type")
            })
            return True
        except Exception as e:
            logger.error("Failed to publish execution event", extra={
                "error": str(e)
            }, exc_info=True)
            return False

    async def subscribe_to_execution_events(
        self, callback: Callable[[Dict[str, Any]], Any]
    ):
        """Subscribe to execution events from Redis pub/sub.

        This should be run in a background task. It will call the callback
        for each message received on the pub/sub channel. The loop runs
        until an error escapes it (logged as "Pub/sub subscription error")
        or the task is cancelled.

        Args:
            callback: Function called with each decoded event. It may be a
                coroutine function (its result is awaited) or a plain
                callable. Exceptions it raises are logged and do not stop
                the subscription.
        """
        try:
            loop = asyncio.get_running_loop()

            # Create separate Redis client for pub/sub (can't use the same one)
            if self._pubsub_client is None:
                self._pubsub_client = redis.from_url(
                    self.redis_url,
                    decode_responses=True,
                    socket_connect_timeout=5,
                    socket_timeout=5,
                )

            self._pubsub = self._pubsub_client.pubsub()
            await loop.run_in_executor(
                None,
                self._pubsub.subscribe,
                self.PUBSUB_CHANNEL,
            )

            logger.info("Subscribed to execution events", extra={
                "channel": self.PUBSUB_CHANNEL
            })

            # Listen for messages
            while True:
                # self._pubsub is re-read each iteration on purpose: close()
                # resets it to None, which makes this lookup fail and ends
                # the loop through the outer handler.
                message = await loop.run_in_executor(
                    None,
                    self._pubsub.get_message,
                    True,  # ignore_subscribe_messages
                    1.0,   # timeout
                )

                if message and message['type'] == 'message':
                    try:
                        event_data = json.loads(message['data'])
                        outcome = callback(event_data)
                        # Only await when the callback actually returned an
                        # awaitable, so plain callables work as well.
                        if hasattr(outcome, "__await__"):
                            await outcome
                    except Exception as e:
                        logger.error("Error processing pub/sub message", extra={
                            "error": str(e)
                        }, exc_info=True)

                # Small sleep to prevent tight loop
                await asyncio.sleep(0.01)

        except Exception as e:
            logger.error("Pub/sub subscription error", extra={
                "error": str(e)
            }, exc_info=True)
|
|
|
|
|
|
# Module-wide manager, built lazily by get_redis_manager().
_redis_manager: Optional[RedisSessionManager] = None


def get_redis_manager() -> RedisSessionManager:
    """Return the shared RedisSessionManager, constructing it on first use.

    The instance is created with default arguments and kept in the
    module-level ``_redis_manager`` so later calls return the same object.
    """
    global _redis_manager
    manager = _redis_manager
    if manager is None:
        manager = RedisSessionManager()
        _redis_manager = manager
    return manager
|