eigent/backend/app/utils/oauth_state_manager.py
bytecraftii 49e148a2f9
Add langfuse and update logger (#952)
Co-authored-by: bytecraftii <bytecraftii@users.noreply.github.com>
Co-authored-by: Wendong-Fan <w3ndong.fan@gmail.com>
2026-01-25 08:13:07 +08:00

108 lines
3.9 KiB
Python

# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
"""
OAuth authorization state manager for background authorization flows
"""
import threading
from typing import Dict, Optional, Literal, Any
from datetime import datetime
import logging
# Logger named "main"; logging.getLogger returns the same instance to every
# caller that asks for this name.
logger = logging.getLogger("main")

# Every status an OAuth authorization flow can report.
AuthStatus = Literal[
    "pending",
    "authorizing",
    "success",
    "failed",
    "cancelled",
]
class OAuthState:
    """Mutable record describing a single OAuth authorization flow.

    Holds the provider name, the current lifecycle status, the outcome
    (``result`` on success, ``error`` on failure), start/completion
    timestamps, and a cancellation flag exposed through :meth:`is_cancelled`.
    """

    # Statuses after which the flow is over; cancel() must not rewrite them.
    _TERMINAL_STATUSES = ("success", "failed", "cancelled")

    def __init__(self, provider: str):
        self.provider = provider
        self.status: AuthStatus = "pending"
        self.error: Optional[str] = None
        # Thread running the flow; None until assigned (presumably by the
        # code that starts the flow -- not shown in this module).
        self.thread: Optional[threading.Thread] = None
        self.result: Optional[Any] = None
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        # Set by cancel(); read back through is_cancelled().
        self._cancel_event = threading.Event()
        self.server = None  # Store the local server instance for forced shutdown

    def is_cancelled(self) -> bool:
        """Return True once cancellation has been requested via cancel()."""
        return self._cancel_event.is_set()

    def cancel(self) -> None:
        """Request cancellation of the authorization flow.

        The cancellation event is always set. The status becomes
        ``"cancelled"`` and ``completed_at`` is stamped only while the flow
        is still running (``pending`` / ``authorizing``). A flow that has
        already finished keeps its recorded outcome and completion time, so
        calling this more than once is harmless.
        """
        self._cancel_event.set()
        if self.status in self._TERMINAL_STATUSES:
            # Do not turn a finished "success"/"failed" into "cancelled", and
            # do not move the completion time on a repeated cancel.
            return
        self.status = "cancelled"
        self.completed_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to dictionary for API response.

        Timestamps are rendered with ``isoformat()``; ``completed_at`` is
        None while the flow has not completed. ``thread``, ``result`` and
        ``server`` are not included.
        """
        return {
            "provider": self.provider,
            "status": self.status,
            "error": self.error,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }
class OAuthStateManager:
    """Registry holding the most recent OAuth flow state for each provider.

    Reads and writes of the provider -> state mapping happen under an
    internal lock, so the mapping can be used from multiple threads.
    """

    def __init__(self):
        # provider name -> most recently created OAuthState for that provider
        self._states: Dict[str, OAuthState] = {}
        self._lock = threading.Lock()

    def create_state(self, provider: str) -> OAuthState:
        """Start tracking a new flow for ``provider`` and return its state.

        If the provider's previous flow is still ``pending`` or
        ``authorizing``, it is cancelled first. The new state then replaces
        it in the registry.
        """
        with self._lock:
            previous = self._states.get(provider)
            if previous is not None and previous.status in ("pending", "authorizing"):
                previous.cancel()
                logger.info("Cancelled previous %s authorization", provider)
            state = OAuthState(provider)
            self._states[provider] = state
            return state

    def get_state(self, provider: str) -> Optional[OAuthState]:
        """Return the current state for ``provider``, or None if untracked."""
        with self._lock:
            return self._states.get(provider)

    def update_status(
        self,
        provider: str,
        status: AuthStatus,
        error: Optional[str] = None,
        result: Optional[Any] = None
    ):
        """Record a new status for the provider's current flow.

        ``error`` and ``result`` are always overwritten (with None when
        omitted). Moving to ``success``, ``failed`` or ``cancelled`` stamps
        ``completed_at``. If no flow is tracked for ``provider``, nothing is
        changed and a warning is logged instead of reporting an update.

        Args:
            provider: Provider whose flow is being updated.
            status: New lifecycle status.
            error: Error message to store, if any.
            result: Result payload to store, if any.
        """
        with self._lock:
            state = self._states.get(provider)
            if state is None:
                # Previously this case was either silent or logged as a
                # successful update; make the dropped update visible.
                logger.warning(
                    "Ignored %s OAuth status update to %s: "
                    "no authorization is tracked for this provider",
                    provider,
                    status,
                )
                return
            state.status = status
            state.error = error
            state.result = result
            if status in ("success", "failed", "cancelled"):
                state.completed_at = datetime.now()
        logger.info("Updated %s OAuth status to %s", provider, status)
# Module-level singleton: every importer of this module gets this same
# manager, so all OAuth flows are tracked in one registry.
oauth_state_manager = OAuthStateManager()