eigent/backend/app/utils/oauth_state_manager.py
2025-11-10 15:22:06 +08:00

94 lines
3.3 KiB
Python

"""
OAuth authorization state manager for background authorization flows
"""
import threading
from typing import Dict, Optional, Literal, Any
from datetime import datetime
from utils import traceroot_wrapper as traceroot
logger = traceroot.get_logger("main")

# Every status an OAuth flow can be in. "pending" and "authorizing" describe a
# flow that is still running; "success", "failed" and "cancelled" are final.
AuthStatus = Literal[
    "pending",
    "authorizing",
    "success",
    "failed",
    "cancelled",
]
class OAuthState:
    """Mutable record describing a single OAuth authorization attempt.

    Attributes:
        provider: Name of the provider this attempt belongs to.
        status: Current lifecycle status; starts as ``"pending"``.
        error: Error message recorded for the attempt, if any.
        thread: Background thread running the flow; ``None`` unless a caller
            attaches one (this class never assigns it).
        result: Arbitrary payload recorded for the attempt, if any.
        started_at: Naive local timestamp taken when the record was created.
        completed_at: Timestamp of when the attempt finished; ``None`` while
            it is still running.
        server: Local server instance, kept so it can be shut down forcibly.
    """

    def __init__(self, provider: str):
        self.provider = provider
        self.status: AuthStatus = "pending"
        self.error: Optional[str] = None
        self.thread: Optional[threading.Thread] = None
        self.result: Optional[Any] = None
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        # Cross-thread flag the running flow can poll via is_cancelled().
        self._cancel_event = threading.Event()
        # Held so a stuck local callback server can be force-stopped.
        self.server = None

    def is_cancelled(self) -> bool:
        """Return True once cancel() has been called on this attempt."""
        requested = self._cancel_event.is_set()
        return requested

    def cancel(self):
        """Raise the cancel flag and immediately mark the attempt as cancelled."""
        self._cancel_event.set()
        self.status = "cancelled"
        self.completed_at = datetime.now()

    def to_dict(self) -> Dict:
        """Build the JSON-friendly mapping exposed in API responses.

        ``thread``, ``result`` and ``server`` are intentionally not included.
        """
        finished = self.completed_at
        payload = {
            "provider": self.provider,
            "status": self.status,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": None if finished is None else finished.isoformat(),
        }
        return payload
class OAuthStateManager:
    """Thread-safe registry holding the most recent OAuth flow per provider.

    Every read and write of the provider -> OAuthState mapping happens under a
    single lock. Only one flow per provider is tracked: starting a new flow
    cancels a still-running predecessor and replaces it.
    """

    # Statuses of a flow that has not finished yet.
    _ACTIVE_STATUSES = frozenset({"pending", "authorizing"})
    # Statuses after which a flow is considered finished.
    _TERMINAL_STATUSES = frozenset({"success", "failed", "cancelled"})

    def __init__(self):
        self._states: Dict[str, OAuthState] = {}
        self._lock = threading.Lock()

    def create_state(self, provider: str) -> OAuthState:
        """Start tracking a new authorization flow for ``provider``.

        If the provider's current flow is still ``"pending"`` or
        ``"authorizing"``, it is cancelled first. Only its cancel flag and
        status are set; its thread and server are not stopped here.

        Returns:
            The new OAuthState, which becomes the provider's current state.
        """
        with self._lock:
            previous = self._states.get(provider)
            if previous is not None and previous.status in self._ACTIVE_STATUSES:
                previous.cancel()
                logger.info(f"Cancelled previous {provider} authorization")
            state = OAuthState(provider)
            self._states[provider] = state
            return state

    def get_state(self, provider: str) -> Optional[OAuthState]:
        """Return the current state for ``provider``, or None if none exists."""
        with self._lock:
            return self._states.get(provider)

    def update_status(
        self,
        provider: str,
        status: AuthStatus,
        error: Optional[str] = None,
        result: Optional[Any] = None,
        expected_state: Optional[OAuthState] = None,
    ):
        """Record a status change for the current flow of ``provider``.

        Args:
            provider: Provider whose current flow is updated.
            status: New status. A terminal status stamps ``completed_at``.
                ``"cancelled"`` also sets the flow's cancel event, so that
                ``is_cancelled()`` agrees with ``status`` and a polling
                background flow actually stops.
            error: Error message to store. It replaces any previous value,
                including with None.
            result: Result payload to store. It replaces any previous value.
            expected_state: Optional. When given, the update is applied only
                if this object is still the provider's current state.
                Background flows should pass their own OAuthState, so that a
                flow superseded by ``create_state`` cannot overwrite its
                successor.

        Updates for an unknown provider are ignored (with a warning). Once a
        flow has been cancelled, later non-cancel updates are ignored, so a
        late report from its thread cannot overwrite the cancellation.
        """
        with self._lock:
            state = self._states.get(provider)
            if state is None:
                logger.warning(
                    f"Ignoring {provider} OAuth status update to {status}: "
                    f"no authorization is tracked"
                )
                return
            if expected_state is not None and state is not expected_state:
                # The reporting flow was replaced by a newer create_state().
                logger.info(
                    f"Ignoring stale {provider} OAuth status update to {status}"
                )
                return
            if state.is_cancelled() and status != "cancelled":
                # Cancellation is final; don't let a late report undo it.
                logger.info(
                    f"Ignoring {provider} OAuth status update to {status}: "
                    f"authorization was cancelled"
                )
                return
            if status == "cancelled":
                # Keep the cancel event in sync with the reported status.
                state.cancel()
            state.status = status
            state.error = error
            state.result = result
            if status in self._TERMINAL_STATUSES:
                state.completed_at = datetime.now()
            logger.info(f"Updated {provider} OAuth status to {status}")
# Global instance: a single module-level OAuthStateManager, so every importer
# of this module in the same process shares one provider -> flow registry.
oauth_state_manager: OAuthStateManager = OAuthStateManager()