mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-04-28 03:30:06 +00:00
revert: codex feature (#1240)
This commit is contained in:
parent
3e57ffdf79
commit
5e3f46a32b
10 changed files with 273 additions and 1946 deletions
|
|
@ -1,234 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils.codex_oauth import CodexOAuthManager
|
||||
from app.utils.oauth_state_manager import oauth_state_manager
|
||||
|
||||
|
||||
class CodexTokenRequest(BaseModel):
|
||||
r"""Request model for saving Codex/OpenAI API token."""
|
||||
|
||||
access_token: str
|
||||
expires_in: int | None = None
|
||||
|
||||
|
||||
logger = logging.getLogger("codex_controller")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/codex/connect", name="connect codex")
|
||||
async def connect_codex():
|
||||
r"""Connect to Codex/OpenAI via OAuth PKCE flow.
|
||||
|
||||
Initiates or completes the Codex OAuth authorization flow
|
||||
to obtain an OpenAI API key.
|
||||
|
||||
Returns:
|
||||
Connection result with access token and provider info
|
||||
"""
|
||||
try:
|
||||
if CodexOAuthManager.is_authenticated():
|
||||
if CodexOAuthManager.is_token_expired():
|
||||
# Try refreshing first
|
||||
if CodexOAuthManager.refresh_token_if_needed():
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Codex token refreshed successfully",
|
||||
"toolkit_name": "CodexOAuthManager",
|
||||
"access_token": CodexOAuthManager.get_access_token(),
|
||||
"provider_name": "openai",
|
||||
"endpoint_url": "https://api.openai.com/v1",
|
||||
}
|
||||
# Refresh failed, start new auth
|
||||
logger.info(
|
||||
"Codex token expired and refresh failed, starting re-auth"
|
||||
)
|
||||
CodexOAuthManager.start_background_auth()
|
||||
return {
|
||||
"success": False,
|
||||
"status": "authorizing",
|
||||
"message": "Token expired. Browser should"
|
||||
" open for re-authorization.",
|
||||
"toolkit_name": "CodexOAuthManager",
|
||||
"requires_auth": True,
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Codex/OpenAI is already authenticated",
|
||||
"toolkit_name": "CodexOAuthManager",
|
||||
"access_token": CodexOAuthManager.get_access_token(),
|
||||
"provider_name": "openai",
|
||||
"endpoint_url": "https://api.openai.com/v1",
|
||||
}
|
||||
else:
|
||||
logger.info("No Codex credentials found, starting OAuth flow")
|
||||
CodexOAuthManager.start_background_auth()
|
||||
return {
|
||||
"success": False,
|
||||
"status": "authorizing",
|
||||
"message": "Authorization required. Browser"
|
||||
" should open automatically.",
|
||||
"toolkit_name": "CodexOAuthManager",
|
||||
"requires_auth": True,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect Codex: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to connect Codex: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/codex/disconnect", name="disconnect codex")
|
||||
async def disconnect_codex():
|
||||
r"""Disconnect Codex/OpenAI and clean up authentication data.
|
||||
|
||||
Cancels any active OAuth flow and clears stored tokens.
|
||||
|
||||
Returns:
|
||||
Disconnection result
|
||||
"""
|
||||
try:
|
||||
# Cancel any active OAuth flow
|
||||
state = oauth_state_manager.get_state("codex")
|
||||
if state and state.status in ["pending", "authorizing"]:
|
||||
state.cancel()
|
||||
if hasattr(state, "server") and state.server:
|
||||
try:
|
||||
state.server.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
oauth_state_manager._states.pop("codex", None)
|
||||
|
||||
success = CodexOAuthManager.clear_token()
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": (
|
||||
"Successfully disconnected Codex"
|
||||
" and cleaned up"
|
||||
" authentication tokens"
|
||||
),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Disconnected Codex (no tokens found to clean up)",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disconnect Codex: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to disconnect Codex: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/codex/save-token", name="save codex token")
|
||||
async def save_codex_token(token_request: CodexTokenRequest):
|
||||
r"""Save Codex/OpenAI API token (manual API key entry fallback).
|
||||
|
||||
Args:
|
||||
token_request: Token data containing access_token
|
||||
and optionally expires_in
|
||||
|
||||
Returns:
|
||||
Save result
|
||||
"""
|
||||
try:
|
||||
token_data = token_request.model_dump(exclude_none=True)
|
||||
token_data["manual"] = True
|
||||
|
||||
success = CodexOAuthManager.save_token(token_data)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Codex token saved successfully",
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to save Codex token"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save Codex token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to save token: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/codex/status", name="get codex status")
|
||||
async def get_codex_status():
|
||||
r"""Get current Codex/OpenAI authentication status and token info.
|
||||
|
||||
Returns:
|
||||
Status information including authentication state and token expiry
|
||||
"""
|
||||
try:
|
||||
is_authenticated = CodexOAuthManager.is_authenticated()
|
||||
|
||||
if not is_authenticated:
|
||||
return {
|
||||
"authenticated": False,
|
||||
"status": "not_configured",
|
||||
"message": "Codex not configured. OAuth or API key required.",
|
||||
}
|
||||
|
||||
token_info = CodexOAuthManager.get_token_info()
|
||||
is_expired = CodexOAuthManager.is_token_expired()
|
||||
is_expiring_soon = CodexOAuthManager.is_token_expiring_soon()
|
||||
|
||||
result = {
|
||||
"authenticated": True,
|
||||
"status": "expired"
|
||||
if is_expired
|
||||
else ("expiring_soon" if is_expiring_soon else "valid"),
|
||||
}
|
||||
|
||||
if token_info:
|
||||
if token_info.get("expires_at"):
|
||||
current_time = int(time.time())
|
||||
expires_at = token_info["expires_at"]
|
||||
seconds_remaining = max(0, expires_at - current_time)
|
||||
result["expires_at"] = expires_at
|
||||
result["seconds_remaining"] = seconds_remaining
|
||||
|
||||
if token_info.get("saved_at"):
|
||||
result["saved_at"] = token_info["saved_at"]
|
||||
|
||||
if token_info.get("manual"):
|
||||
result["manual"] = True
|
||||
|
||||
if is_expired:
|
||||
result["message"] = "Token has expired. Please re-authenticate."
|
||||
elif is_expiring_soon:
|
||||
result["message"] = (
|
||||
"Token is expiring soon. Consider re-authenticating."
|
||||
)
|
||||
else:
|
||||
result["message"] = "Codex/OpenAI is connected and token is valid."
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Codex status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get status: {str(e)}"
|
||||
)
|
||||
|
|
@ -23,7 +23,6 @@ from fastapi import FastAPI
|
|||
|
||||
from app.controller import (
|
||||
chat_controller,
|
||||
codex_controller,
|
||||
health_controller,
|
||||
model_controller,
|
||||
task_controller,
|
||||
|
|
@ -72,11 +71,6 @@ def register_routers(app: FastAPI, prefix: str = "") -> None:
|
|||
"tags": ["tool"],
|
||||
"description": "Tool installation and management",
|
||||
},
|
||||
{
|
||||
"router": codex_controller.router,
|
||||
"tags": ["codex"],
|
||||
"description": "Codex OAuth provider authentication",
|
||||
},
|
||||
]
|
||||
|
||||
for config in routers_config:
|
||||
|
|
|
|||
|
|
@ -1,548 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
r"""OpenAI Codex OAuth manager.
|
||||
|
||||
Handles Authorization Code + PKCE flow using Codex CLI's public client_id.
|
||||
The resulting access token is stored in an encrypted file and used
|
||||
independently of the OPENAI_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import getpass
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import secrets
|
||||
import socket
|
||||
import stat
|
||||
import threading
|
||||
import time
|
||||
import webbrowser
|
||||
from html import escape as html_escape
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from filelock import FileLock
|
||||
|
||||
from app.utils.oauth_state_manager import oauth_state_manager
|
||||
|
||||
logger = logging.getLogger("codex_oauth")
|
||||
|
||||
# OpenAI / Codex OAuth constants
|
||||
# Fixed public client_id from the Codex CLI (not a secret).
|
||||
CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_AUTH_URL = "https://auth.openai.com/oauth/authorize"
|
||||
CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
# Fixed callback port used by Codex CLI
|
||||
CODEX_CALLBACK_PORT = 1455
|
||||
|
||||
# Token storage path
|
||||
CODEX_TOKEN_DIR = os.path.join(
|
||||
os.path.expanduser("~"), ".eigent", "tokens", "codex"
|
||||
)
|
||||
CODEX_TOKEN_PATH = os.path.join(CODEX_TOKEN_DIR, "codex_token.enc")
|
||||
|
||||
# Token lifetime defaults (seconds)
|
||||
CODEX_TOKEN_DEFAULT_LIFETIME = 3600 # 1 hour
|
||||
CODEX_TOKEN_REFRESH_THRESHOLD = 300 # 5 minutes before expiry
|
||||
|
||||
|
||||
def _get_machine_identifier() -> bytes:
|
||||
r"""Get a machine-specific identifier for key derivation.
|
||||
|
||||
Combines multiple machine-specific values to create a stable identifier
|
||||
that is unique to this machine but consistent across restarts.
|
||||
|
||||
Returns:
|
||||
Machine identifier as bytes.
|
||||
"""
|
||||
components = [
|
||||
getpass.getuser(),
|
||||
socket.gethostname(),
|
||||
platform.node(),
|
||||
# Add home directory path for additional uniqueness
|
||||
os.path.expanduser("~"),
|
||||
]
|
||||
|
||||
# Try to get machine-id on Linux
|
||||
machine_id_paths = [
|
||||
"/etc/machine-id",
|
||||
"/var/lib/dbus/machine-id",
|
||||
]
|
||||
for path in machine_id_paths:
|
||||
try:
|
||||
with open(path) as f:
|
||||
components.append(f.read().strip())
|
||||
break
|
||||
except (FileNotFoundError, PermissionError):
|
||||
continue
|
||||
|
||||
return "|".join(components).encode("utf-8")
|
||||
|
||||
|
||||
def _derive_encryption_key() -> bytes:
|
||||
r"""Derive an encryption key from machine-specific identifiers.
|
||||
|
||||
Uses PBKDF2 to derive a Fernet-compatible key from machine identifiers.
|
||||
This ties the encryption to the specific machine without storing a key file.
|
||||
|
||||
Returns:
|
||||
The Fernet encryption key as bytes.
|
||||
"""
|
||||
# Fixed salt for this application (not secret, just ensures uniqueness)
|
||||
# The security comes from the machine-specific identifier
|
||||
app_salt = b"eigent-codex-token-encryption-v1"
|
||||
|
||||
machine_id = _get_machine_identifier()
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=app_salt,
|
||||
iterations=100_000,
|
||||
)
|
||||
|
||||
# Derive a 32-byte key and encode it as base64 for Fernet
|
||||
derived_key = kdf.derive(machine_id)
|
||||
return base64.urlsafe_b64encode(derived_key)
|
||||
|
||||
|
||||
def _encrypt_token_data(token_data: dict) -> bytes:
|
||||
r"""Encrypt token data using Fernet symmetric encryption.
|
||||
|
||||
Args:
|
||||
token_data: Dictionary containing token information.
|
||||
|
||||
Returns:
|
||||
Encrypted bytes.
|
||||
"""
|
||||
key = _derive_encryption_key()
|
||||
fernet = Fernet(key)
|
||||
json_bytes = json.dumps(token_data).encode("utf-8")
|
||||
return fernet.encrypt(json_bytes)
|
||||
|
||||
|
||||
def _decrypt_token_data(encrypted_data: bytes) -> dict | None:
|
||||
r"""Decrypt token data.
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted bytes from file.
|
||||
|
||||
Returns:
|
||||
Decrypted token dictionary, or None if decryption fails.
|
||||
"""
|
||||
try:
|
||||
key = _derive_encryption_key()
|
||||
fernet = Fernet(key)
|
||||
decrypted = fernet.decrypt(encrypted_data)
|
||||
return json.loads(decrypted.decode("utf-8"))
|
||||
except (InvalidToken, json.JSONDecodeError) as e:
|
||||
logger.warning("Failed to decrypt token data: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_pkce_pair() -> tuple[str, str]:
|
||||
r"""Generate a PKCE code_verifier and S256 code_challenge.
|
||||
|
||||
Returns:
|
||||
Tuple of (code_verifier, code_challenge).
|
||||
"""
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
)
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
||||
r"""HTTP handler that captures the OAuth callback code."""
|
||||
|
||||
def do_GET(self):
|
||||
parsed = urlparse(self.path)
|
||||
params = parse_qs(parsed.query)
|
||||
|
||||
# Validate state parameter to prevent CSRF attacks
|
||||
received_state = params.get("state", [None])[0]
|
||||
expected_state = getattr(self.server, "expected_state", None)
|
||||
|
||||
if expected_state and received_state != expected_state:
|
||||
self.server.auth_error = "state_mismatch: Invalid state parameter"
|
||||
self.send_response(400)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
b"<html><body><h1>Authorization failed</h1>"
|
||||
b"<p>Invalid state parameter. Possible CSRF attack.</p>"
|
||||
b"</body></html>"
|
||||
)
|
||||
return
|
||||
|
||||
if "code" in params:
|
||||
self.server.auth_code = params["code"][0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
b"<html><body>"
|
||||
b"<h1>Authorization successful!</h1>"
|
||||
b"<p>You can close this window and return to Eigent.</p>"
|
||||
b"</body></html>"
|
||||
)
|
||||
elif "error" in params:
|
||||
error = params.get("error", ["unknown"])[0]
|
||||
desc = params.get("error_description", [""])[0]
|
||||
self.server.auth_error = f"{error}: {desc}"
|
||||
self.send_response(400)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
# Escape HTML to prevent XSS from query parameters
|
||||
self.wfile.write(
|
||||
f"<html><body><h1>Authorization failed</h1>"
|
||||
f"<p>{html_escape(error)}: {html_escape(desc)}</p>"
|
||||
f"</body></html>".encode()
|
||||
)
|
||||
else:
|
||||
self.send_response(400)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
b"<html><body><h1>"
|
||||
b"Missing authorization code"
|
||||
b"</h1></body></html>"
|
||||
)
|
||||
|
||||
def log_message(self, format, *args):
|
||||
logger.debug("Codex callback server: %s", format % args)
|
||||
|
||||
|
||||
class CodexOAuthManager:
|
||||
r"""Manages OpenAI Codex OAuth token lifecycle."""
|
||||
|
||||
@staticmethod
|
||||
def _token_path() -> str:
|
||||
return CODEX_TOKEN_PATH
|
||||
|
||||
@classmethod
|
||||
def save_token(cls, token_data: dict) -> bool:
|
||||
r"""Save token data to disk with encryption.
|
||||
|
||||
Args:
|
||||
token_data: Dictionary with at least ``access_token``.
|
||||
|
||||
Returns:
|
||||
True on success.
|
||||
"""
|
||||
path = cls._token_path()
|
||||
token_data = token_data.copy()
|
||||
try:
|
||||
if "saved_at" not in token_data:
|
||||
token_data["saved_at"] = int(time.time())
|
||||
|
||||
# Compute absolute expiry from the relative expires_in value
|
||||
# (if present), then discard expires_in so we only store the
|
||||
# absolute timestamp.
|
||||
if "expires_at" not in token_data:
|
||||
expires_in = token_data.pop(
|
||||
"expires_in", CODEX_TOKEN_DEFAULT_LIFETIME
|
||||
)
|
||||
token_data["expires_at"] = token_data["saved_at"] + expires_in
|
||||
else:
|
||||
token_data.pop("expires_in", None)
|
||||
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
lock = FileLock(path + ".lock")
|
||||
|
||||
# Encrypt token data before saving
|
||||
encrypted_data = _encrypt_token_data(token_data)
|
||||
|
||||
with lock, open(path, "wb") as f:
|
||||
f.write(encrypted_data)
|
||||
|
||||
# Set restrictive permissions on token file
|
||||
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
logger.info("Saved encrypted Codex token to %s", path)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to save Codex token: %s", e)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def load_token(cls) -> dict | None:
|
||||
r"""Load and decrypt token data from disk."""
|
||||
path = cls._token_path()
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
lock = FileLock(path + ".lock")
|
||||
with lock, open(path, "rb") as f:
|
||||
encrypted_data = f.read()
|
||||
return _decrypt_token_data(encrypted_data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load token: %s", e)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def clear_token(cls) -> bool:
|
||||
r"""Remove stored token file."""
|
||||
path = cls._token_path()
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
logger.info("Removed Codex token file: %s", path)
|
||||
|
||||
token_dir = os.path.dirname(path)
|
||||
if os.path.exists(token_dir) and not os.listdir(token_dir):
|
||||
os.rmdir(token_dir)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to clear Codex token: %s", e)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_authenticated(cls) -> bool:
|
||||
r"""Return True if a Codex OAuth token is available."""
|
||||
token = cls.load_token()
|
||||
return bool(token and token.get("access_token"))
|
||||
|
||||
@classmethod
|
||||
def is_token_expired(cls) -> bool:
|
||||
r"""Return True if the stored token has expired."""
|
||||
token = cls.load_token()
|
||||
if not token:
|
||||
return False
|
||||
expires_at = token.get("expires_at")
|
||||
if not expires_at:
|
||||
return False
|
||||
return int(time.time()) >= expires_at
|
||||
|
||||
@classmethod
|
||||
def is_token_expiring_soon(cls) -> bool:
|
||||
r"""Return True if the token expires within the refresh threshold."""
|
||||
token = cls.load_token()
|
||||
if not token:
|
||||
return False
|
||||
expires_at = token.get("expires_at")
|
||||
if not expires_at:
|
||||
return False
|
||||
return (expires_at - int(time.time())) < CODEX_TOKEN_REFRESH_THRESHOLD
|
||||
|
||||
@classmethod
|
||||
def get_access_token(cls) -> str | None:
|
||||
r"""Return the current Codex OAuth access token."""
|
||||
token = cls.load_token()
|
||||
if token and token.get("access_token"):
|
||||
return token["access_token"]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_token_info(cls) -> dict | None:
|
||||
r"""Return stored token metadata."""
|
||||
return cls.load_token()
|
||||
|
||||
@classmethod
|
||||
def refresh_token_if_needed(cls) -> bool:
|
||||
r"""Attempt to refresh the token if it has a refresh_token.
|
||||
|
||||
Returns:
|
||||
True if refreshed or not needed, False on failure.
|
||||
"""
|
||||
token = cls.load_token()
|
||||
if not token:
|
||||
return False
|
||||
|
||||
if not cls.is_token_expiring_soon():
|
||||
return True
|
||||
|
||||
refresh_token = token.get("refresh_token")
|
||||
if not refresh_token:
|
||||
return False
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
CODEX_TOKEN_URL,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": CODEX_CLIENT_ID,
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
new_token = resp.json()
|
||||
|
||||
# Merge with existing data
|
||||
token.update(
|
||||
{
|
||||
"access_token": new_token["access_token"],
|
||||
"expires_in": new_token.get(
|
||||
"expires_in", CODEX_TOKEN_DEFAULT_LIFETIME
|
||||
),
|
||||
"saved_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
if new_token.get("refresh_token"):
|
||||
token["refresh_token"] = new_token["refresh_token"]
|
||||
token.pop("expires_at", None)
|
||||
|
||||
return cls.save_token(token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to refresh Codex token: %s", e)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background OAuth flow
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def start_background_auth(cls) -> str:
|
||||
r"""Launch the PKCE OAuth flow in a background thread.
|
||||
|
||||
Returns:
|
||||
``"authorizing"`` immediately.
|
||||
"""
|
||||
# Cancel any existing flow
|
||||
old_state = oauth_state_manager.get_state("codex")
|
||||
if old_state and old_state.status in ["pending", "authorizing"]:
|
||||
old_state.cancel()
|
||||
if hasattr(old_state, "server") and old_state.server:
|
||||
try:
|
||||
old_state.server.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
state = oauth_state_manager.create_state("codex")
|
||||
|
||||
def _auth_flow():
|
||||
try:
|
||||
state.status = "authorizing"
|
||||
oauth_state_manager.update_status("codex", "authorizing")
|
||||
|
||||
code_verifier, code_challenge = _generate_pkce_pair()
|
||||
|
||||
# Generate state parameter to prevent CSRF attacks
|
||||
oauth_state = secrets.token_urlsafe(32)
|
||||
|
||||
# Start localhost callback server on fixed port 1455 (Codex standard)
|
||||
server = HTTPServer(
|
||||
("127.0.0.1", CODEX_CALLBACK_PORT), _CallbackHandler
|
||||
)
|
||||
server.auth_code = None
|
||||
server.auth_error = None
|
||||
server.expected_state = oauth_state
|
||||
state.server = server
|
||||
|
||||
redirect_uri = (
|
||||
f"http://localhost:{CODEX_CALLBACK_PORT}/auth/callback"
|
||||
)
|
||||
|
||||
params = urlencode(
|
||||
{
|
||||
"response_type": "code",
|
||||
"client_id": CODEX_CLIENT_ID,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": "openid profile email offline_access",
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": oauth_state,
|
||||
"id_token_add_organizations": "true",
|
||||
"codex_cli_simplified_flow": "true",
|
||||
}
|
||||
)
|
||||
auth_url = f"{CODEX_AUTH_URL}?{params}"
|
||||
|
||||
if state.is_cancelled():
|
||||
server.server_close()
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Opening browser for Codex OAuth on port %d",
|
||||
CODEX_CALLBACK_PORT,
|
||||
)
|
||||
webbrowser.open(auth_url)
|
||||
|
||||
# Wait for the callback (single request)
|
||||
server.handle_request()
|
||||
|
||||
if state.is_cancelled():
|
||||
server.server_close()
|
||||
return
|
||||
|
||||
if server.auth_error:
|
||||
raise ValueError(server.auth_error)
|
||||
|
||||
if not server.auth_code:
|
||||
raise ValueError("No authorization code received")
|
||||
|
||||
auth_code = server.auth_code
|
||||
server.server_close()
|
||||
|
||||
# Exchange code for token
|
||||
token_resp = requests.post(
|
||||
CODEX_TOKEN_URL,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": CODEX_CLIENT_ID,
|
||||
"code": auth_code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded"
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
token_resp.raise_for_status()
|
||||
token_data = token_resp.json()
|
||||
|
||||
if state.is_cancelled():
|
||||
return
|
||||
|
||||
cls.save_token(token_data)
|
||||
|
||||
oauth_state_manager.update_status(
|
||||
"codex", "success", result=token_data
|
||||
)
|
||||
logger.info("Codex OAuth authorization successful")
|
||||
|
||||
except Exception as e:
|
||||
if state.is_cancelled():
|
||||
oauth_state_manager.update_status("codex", "cancelled")
|
||||
else:
|
||||
logger.error("Codex OAuth failed: %s", e)
|
||||
oauth_state_manager.update_status(
|
||||
"codex", "failed", error=str(e)
|
||||
)
|
||||
finally:
|
||||
state.server = None
|
||||
|
||||
thread = threading.Thread(
|
||||
target=_auth_flow,
|
||||
daemon=True,
|
||||
name=f"Codex-OAuth-{state.started_at.timestamp()}",
|
||||
)
|
||||
state.thread = thread
|
||||
thread.start()
|
||||
|
||||
logger.info("Started background Codex OAuth authorization")
|
||||
return "authorizing"
|
||||
|
|
@ -1,875 +0,0 @@
|
|||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||||
|
||||
r"""Tests for the Codex OAuth manager."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import stat
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.codex_oauth import (
|
||||
CODEX_CLIENT_ID,
|
||||
CODEX_TOKEN_DEFAULT_LIFETIME,
|
||||
CODEX_TOKEN_REFRESH_THRESHOLD,
|
||||
CodexOAuthManager,
|
||||
_CallbackHandler,
|
||||
_decrypt_token_data,
|
||||
_derive_encryption_key,
|
||||
_encrypt_token_data,
|
||||
_generate_pkce_pair,
|
||||
_get_machine_identifier,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_token_path() -> Generator[str, None, None]:
|
||||
r"""Create a temporary token path and patch CodexOAuthManager to use it."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
token_path = os.path.join(temp_dir, "codex_token.enc")
|
||||
with patch.object(
|
||||
CodexOAuthManager, "_token_path", return_value=token_path
|
||||
):
|
||||
yield token_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_env() -> Generator[None, None, None]:
|
||||
r"""Ensure OPENAI_API_KEY is cleaned up before and after test."""
|
||||
original = os.environ.pop("OPENAI_API_KEY", None)
|
||||
yield
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
if original is not None:
|
||||
os.environ["OPENAI_API_KEY"] = original
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mock_callback_request(path: str, expected_state: str | None = None):
|
||||
r"""Create a mock HTTP request for _CallbackHandler testing."""
|
||||
handler = MagicMock(spec=_CallbackHandler)
|
||||
handler.path = path
|
||||
handler.wfile = io.BytesIO()
|
||||
|
||||
# Track response
|
||||
handler.response_code = None
|
||||
handler.headers_sent = {}
|
||||
|
||||
def send_response(code):
|
||||
handler.response_code = code
|
||||
|
||||
def send_header(name, value):
|
||||
handler.headers_sent[name] = value
|
||||
|
||||
def end_headers():
|
||||
pass
|
||||
|
||||
handler.send_response = send_response
|
||||
handler.send_header = send_header
|
||||
handler.end_headers = end_headers
|
||||
|
||||
# Create mock server
|
||||
handler.server = MagicMock()
|
||||
handler.server.auth_code = None
|
||||
handler.server.auth_error = None
|
||||
handler.server.expected_state = expected_state
|
||||
|
||||
yield handler
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PKCE Generation Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPKCEGeneration:
|
||||
r"""Tests for PKCE code verifier and challenge generation."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_tuple_of_strings(self):
|
||||
"""PKCE pair should be a tuple of two strings."""
|
||||
verifier, challenge = _generate_pkce_pair()
|
||||
|
||||
assert isinstance(verifier, str)
|
||||
assert isinstance(challenge, str)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_verifier_meets_minimum_length(self):
|
||||
"""Code verifier should be at least 43 characters (RFC 7636)."""
|
||||
verifier, _ = _generate_pkce_pair()
|
||||
|
||||
assert len(verifier) >= 43
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_challenge_is_base64url_without_padding(self):
|
||||
"""Code challenge should be valid base64url without padding."""
|
||||
_, challenge = _generate_pkce_pair()
|
||||
|
||||
# Base64url should not contain +, /, or =
|
||||
assert "+" not in challenge
|
||||
assert "/" not in challenge
|
||||
assert "=" not in challenge
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generates_unique_values(self):
|
||||
"""Each call should generate cryptographically unique values."""
|
||||
pairs = [_generate_pkce_pair() for _ in range(10)]
|
||||
verifiers = [p[0] for p in pairs]
|
||||
challenges = [p[1] for p in pairs]
|
||||
|
||||
assert len(set(verifiers)) == 10
|
||||
assert len(set(challenges)) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Machine Identifier Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMachineIdentifier:
|
||||
r"""Tests for machine identifier generation."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_returns_bytes(self):
|
||||
"""Machine identifier should be bytes."""
|
||||
identifier = _get_machine_identifier()
|
||||
|
||||
assert isinstance(identifier, bytes)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_deterministic(self):
|
||||
"""Machine identifier should be consistent across calls."""
|
||||
assert _get_machine_identifier() == _get_machine_identifier()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_contains_multiple_components(self):
|
||||
"""Identifier should contain pipe-separated components."""
|
||||
identifier = _get_machine_identifier()
|
||||
decoded = identifier.decode("utf-8")
|
||||
components = decoded.split("|")
|
||||
|
||||
# Should have at least username, hostname, platform.node, home dir
|
||||
assert len(components) >= 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encryption Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEncryption:
|
||||
r"""Tests for token encryption and decryption."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_derive_key_returns_fernet_compatible_key(self):
|
||||
"""Derived key should be 44 bytes (Fernet format)."""
|
||||
key = _derive_encryption_key()
|
||||
|
||||
assert isinstance(key, bytes)
|
||||
assert len(key) == 44
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_derive_key_is_deterministic(self):
|
||||
"""Derived key should be consistent for same machine."""
|
||||
assert _derive_encryption_key() == _derive_encryption_key()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_encrypt_decrypt_roundtrip(self):
|
||||
"""Data should survive encryption and decryption."""
|
||||
original = {
|
||||
"access_token": "sk-test-token-123",
|
||||
"refresh_token": "rt-refresh-456",
|
||||
"expires_at": 1234567890,
|
||||
"scope": "openai.api.read",
|
||||
}
|
||||
|
||||
encrypted = _encrypt_token_data(original)
|
||||
decrypted = _decrypt_token_data(encrypted)
|
||||
|
||||
assert decrypted == original
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_encrypt_returns_bytes(self):
|
||||
"""Encrypted output should be bytes."""
|
||||
encrypted = _encrypt_token_data({"access_token": "test"})
|
||||
|
||||
assert isinstance(encrypted, bytes)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_encrypted_data_differs_from_input(self):
|
||||
"""Encrypted data should not contain plaintext."""
|
||||
token = "my_secret_token"
|
||||
encrypted = _encrypt_token_data({"access_token": token})
|
||||
|
||||
assert token.encode() not in encrypted
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_decrypt_invalid_data_returns_none(self):
|
||||
"""Decrypting garbage data should return None, not raise."""
|
||||
assert _decrypt_token_data(b"not-valid-fernet-data") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_decrypt_corrupted_data_returns_none(self):
|
||||
"""Decrypting tampered data should return None."""
|
||||
encrypted = _encrypt_token_data({"access_token": "test"})
|
||||
corrupted = encrypted[:-10] + b"tampered!!"
|
||||
|
||||
assert _decrypt_token_data(corrupted) is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_decrypt_empty_data_returns_none(self):
|
||||
"""Decrypting empty data should return None."""
|
||||
assert _decrypt_token_data(b"") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Callback Handler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCallbackHandler:
|
||||
r"""Tests for OAuth callback HTTP handler."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_captures_authorization_code(self):
|
||||
"""Handler should capture auth code from callback URL."""
|
||||
path = "/auth/callback?code=auth_code_123&state=valid_state"
|
||||
with mock_callback_request(
|
||||
path, expected_state="valid_state"
|
||||
) as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
assert handler.server.auth_code == "auth_code_123"
|
||||
assert handler.response_code == 200
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_captures_error_response(self):
|
||||
"""Handler should capture error from callback URL."""
|
||||
path = "/auth/callback?error=access_denied&error_description=User%20denied"
|
||||
with mock_callback_request(path) as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
assert handler.server.auth_error == "access_denied: User denied"
|
||||
assert handler.response_code == 400
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_handles_missing_code(self):
|
||||
"""Handler should return 400 when code is missing."""
|
||||
with mock_callback_request("/auth/callback?state=xyz") as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
assert handler.response_code == 400
|
||||
assert handler.server.auth_code is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_escapes_html_in_error(self):
|
||||
"""Handler should escape HTML in error messages to prevent XSS."""
|
||||
path = "/auth/callback?error=<script>&error_description=<img>"
|
||||
with mock_callback_request(path) as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
output = handler.wfile.getvalue().decode()
|
||||
assert "<script>" not in output
|
||||
assert "<script>" in output or "script" not in output
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_rejects_mismatched_state(self):
|
||||
"""Handler should reject callback with mismatched state (CSRF protection)."""
|
||||
path = "/auth/callback?code=auth_code_123&state=wrong_state"
|
||||
with mock_callback_request(
|
||||
path, expected_state="correct_state"
|
||||
) as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
assert handler.response_code == 400
|
||||
assert handler.server.auth_code is None
|
||||
assert "state_mismatch" in handler.server.auth_error
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_accepts_matching_state(self):
|
||||
"""Handler should accept callback with matching state."""
|
||||
path = "/auth/callback?code=auth_code_123&state=my_state_value"
|
||||
with mock_callback_request(
|
||||
path, expected_state="my_state_value"
|
||||
) as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
assert handler.response_code == 200
|
||||
assert handler.server.auth_code == "auth_code_123"
|
||||
assert handler.server.auth_error is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_skips_state_validation_when_not_expected(self):
|
||||
"""Handler should skip state validation if server has no expected_state."""
|
||||
path = "/auth/callback?code=auth_code_123"
|
||||
with mock_callback_request(path, expected_state=None) as handler:
|
||||
_CallbackHandler.do_GET(handler)
|
||||
|
||||
assert handler.response_code == 200
|
||||
assert handler.server.auth_code == "auth_code_123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token Operations Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenOperations:
|
||||
r"""Tests for CodexOAuthManager token save/load/clear."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_creates_directory_structure(self, clean_env):
|
||||
"""save_token should create parent directories."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
nested_path = os.path.join(temp_dir, "a", "b", "c", "token.enc")
|
||||
with patch.object(
|
||||
CodexOAuthManager, "_token_path", return_value=nested_path
|
||||
):
|
||||
result = CodexOAuthManager.save_token({"access_token": "test"})
|
||||
|
||||
assert result is True
|
||||
assert os.path.exists(nested_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_sets_restrictive_permissions(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""Token file should have owner-only read/write permissions."""
|
||||
CodexOAuthManager.save_token({"access_token": "secret"})
|
||||
|
||||
file_stat = os.stat(temp_token_path)
|
||||
mode = stat.S_IMODE(file_stat.st_mode)
|
||||
|
||||
assert mode == (stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_does_not_set_environment_variable(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""save_token should NOT set OPENAI_API_KEY environment variable."""
|
||||
CodexOAuthManager.save_token({"access_token": "sk-test-key"})
|
||||
|
||||
# Token should only be stored in file, not in env var
|
||||
assert "OPENAI_API_KEY" not in os.environ
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_computes_expires_at_from_expires_in(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""save_token should convert expires_in to absolute expires_at."""
|
||||
before = int(time.time())
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
)
|
||||
after = int(time.time())
|
||||
|
||||
loaded = CodexOAuthManager.load_token()
|
||||
|
||||
assert "expires_at" in loaded
|
||||
assert "expires_in" not in loaded
|
||||
assert before + 3600 <= loaded["expires_at"] <= after + 3600
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_uses_default_lifetime_when_no_expiry(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""save_token should use default lifetime if no expiry provided."""
|
||||
before = int(time.time())
|
||||
CodexOAuthManager.save_token({"access_token": "test"})
|
||||
|
||||
loaded = CodexOAuthManager.load_token()
|
||||
expected = before + CODEX_TOKEN_DEFAULT_LIFETIME
|
||||
|
||||
assert loaded["expires_at"] >= expected
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_save_preserves_existing_expires_at(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""save_token should preserve explicit expires_at."""
|
||||
explicit_time = 9999999999
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": explicit_time,
|
||||
}
|
||||
)
|
||||
|
||||
loaded = CodexOAuthManager.load_token()
|
||||
|
||||
assert loaded["expires_at"] == explicit_time
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_returns_none_when_no_file(self, temp_token_path):
|
||||
"""load_token should return None if file doesn't exist."""
|
||||
assert CodexOAuthManager.load_token() is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_returns_saved_data(self, temp_token_path, clean_env):
|
||||
"""load_token should return previously saved data."""
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "my-token",
|
||||
"refresh_token": "my-refresh",
|
||||
"scope": "openai.api.read",
|
||||
}
|
||||
)
|
||||
|
||||
loaded = CodexOAuthManager.load_token()
|
||||
|
||||
assert loaded["access_token"] == "my-token"
|
||||
assert loaded["refresh_token"] == "my-refresh"
|
||||
assert loaded["scope"] == "openai.api.read"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_clear_removes_token_file(self, temp_token_path, clean_env):
|
||||
"""clear_token should delete the token file."""
|
||||
CodexOAuthManager.save_token({"access_token": "test"})
|
||||
assert os.path.exists(temp_token_path)
|
||||
|
||||
result = CodexOAuthManager.clear_token()
|
||||
|
||||
assert result is True
|
||||
assert not os.path.exists(temp_token_path)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_clear_does_not_modify_environment_variable(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""clear_token should NOT modify OPENAI_API_KEY env var."""
|
||||
os.environ["OPENAI_API_KEY"] = "existing-key"
|
||||
|
||||
CodexOAuthManager.clear_token()
|
||||
|
||||
# Env var should remain untouched - it's managed separately
|
||||
assert os.environ.get("OPENAI_API_KEY") == "existing-key"
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_clear_succeeds_when_no_file(self, temp_token_path):
|
||||
"""clear_token should succeed even if no token file exists."""
|
||||
result = CodexOAuthManager.clear_token()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_token_info_returns_full_token_data(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""get_token_info should return complete stored token."""
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "token",
|
||||
"refresh_token": "refresh",
|
||||
}
|
||||
)
|
||||
|
||||
info = CodexOAuthManager.get_token_info()
|
||||
|
||||
assert info["access_token"] == "token"
|
||||
assert info["refresh_token"] == "refresh"
|
||||
assert "saved_at" in info
|
||||
assert "expires_at" in info
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authentication Status Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthenticationStatus:
|
||||
r"""Tests for authentication status checking."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_authenticated_true_with_token_file(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_authenticated should return True when token file exists."""
|
||||
CodexOAuthManager.save_token({"access_token": "file-token"})
|
||||
|
||||
assert CodexOAuthManager.is_authenticated() is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_authenticated_false_with_only_env_var(self, temp_token_path):
|
||||
"""is_authenticated should return False when only env var is set (no Codex OAuth token)."""
|
||||
os.environ["OPENAI_API_KEY"] = "env-token"
|
||||
|
||||
try:
|
||||
# Codex OAuth status should not be affected by generic OPENAI_API_KEY
|
||||
assert CodexOAuthManager.is_authenticated() is False
|
||||
finally:
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_authenticated_false_when_nothing_configured(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_authenticated should return False with no credentials."""
|
||||
assert CodexOAuthManager.is_authenticated() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_access_token_prefers_file_over_env(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""get_access_token should prefer file token over env var."""
|
||||
CodexOAuthManager.save_token({"access_token": "file-token"})
|
||||
os.environ["OPENAI_API_KEY"] = "env-token"
|
||||
|
||||
token = CodexOAuthManager.get_access_token()
|
||||
|
||||
assert token == "file-token"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_access_token_returns_none_without_oauth_token(
|
||||
self, temp_token_path
|
||||
):
|
||||
"""get_access_token should return None when no Codex OAuth token exists."""
|
||||
os.environ["OPENAI_API_KEY"] = "env-fallback"
|
||||
|
||||
try:
|
||||
# Should not fall back to env var; Codex OAuth token is separate
|
||||
assert CodexOAuthManager.get_access_token() is None
|
||||
finally:
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_access_token_returns_none_when_nothing(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""get_access_token should return None with no credentials."""
|
||||
assert CodexOAuthManager.get_access_token() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token Expiry Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenExpiry:
|
||||
r"""Tests for token expiration checking."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expired_false_when_no_token(self, temp_token_path):
|
||||
"""is_token_expired should return False if no token exists."""
|
||||
assert CodexOAuthManager.is_token_expired() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expired_false_when_token_valid(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_token_expired should return False for valid token."""
|
||||
future = int(time.time()) + 3600
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": future,
|
||||
}
|
||||
)
|
||||
|
||||
assert CodexOAuthManager.is_token_expired() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expired_true_when_token_expired(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_token_expired should return True for expired token."""
|
||||
past = int(time.time()) - 100
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": past,
|
||||
}
|
||||
)
|
||||
|
||||
assert CodexOAuthManager.is_token_expired() is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expired_false_when_no_expires_at(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_token_expired should return False if expires_at missing."""
|
||||
# Directly write token without expires_at
|
||||
token_data = {"access_token": "test", "saved_at": int(time.time())}
|
||||
encrypted = _encrypt_token_data(token_data)
|
||||
os.makedirs(os.path.dirname(temp_token_path), exist_ok=True)
|
||||
with open(temp_token_path, "wb") as f:
|
||||
f.write(encrypted)
|
||||
|
||||
assert CodexOAuthManager.is_token_expired() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expiring_soon_false_when_no_token(self, temp_token_path):
|
||||
"""is_token_expiring_soon should return False if no token."""
|
||||
assert CodexOAuthManager.is_token_expiring_soon() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expiring_soon_true_within_threshold(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_token_expiring_soon should return True within threshold."""
|
||||
soon = int(time.time()) + CODEX_TOKEN_REFRESH_THRESHOLD - 60
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": soon,
|
||||
}
|
||||
)
|
||||
|
||||
assert CodexOAuthManager.is_token_expiring_soon() is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_is_expiring_soon_false_with_plenty_of_time(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""is_token_expiring_soon should return False with ample time."""
|
||||
future = int(time.time()) + 3600
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": future,
|
||||
}
|
||||
)
|
||||
|
||||
assert CodexOAuthManager.is_token_expiring_soon() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token Refresh Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenRefresh:
|
||||
r"""Tests for token refresh functionality."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_returns_false_when_no_token(self, temp_token_path):
|
||||
"""refresh_token_if_needed should return False with no token."""
|
||||
assert CodexOAuthManager.refresh_token_if_needed() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_returns_true_when_not_expiring(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""refresh_token_if_needed should return True if not expiring."""
|
||||
future = int(time.time()) + 3600
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": future,
|
||||
}
|
||||
)
|
||||
|
||||
assert CodexOAuthManager.refresh_token_if_needed() is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_returns_false_without_refresh_token(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""refresh should return False if no refresh token available."""
|
||||
soon = int(time.time()) + CODEX_TOKEN_REFRESH_THRESHOLD - 60
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"expires_at": soon,
|
||||
}
|
||||
)
|
||||
|
||||
assert CodexOAuthManager.refresh_token_if_needed() is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_calls_token_endpoint(self, temp_token_path, clean_env):
|
||||
"""refresh should call OpenAI token endpoint with correct params."""
|
||||
soon = int(time.time()) + CODEX_TOKEN_REFRESH_THRESHOLD - 60
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "refresh-123",
|
||||
"expires_at": soon,
|
||||
}
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("requests.post", return_value=mock_response) as mock_post:
|
||||
result = CodexOAuthManager.refresh_token_if_needed()
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once()
|
||||
|
||||
call_kwargs = mock_post.call_args
|
||||
assert call_kwargs[1]["data"]["grant_type"] == "refresh_token"
|
||||
assert call_kwargs[1]["data"]["client_id"] == CODEX_CLIENT_ID
|
||||
assert call_kwargs[1]["data"]["refresh_token"] == "refresh-123"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_saves_new_token(self, temp_token_path, clean_env):
|
||||
"""refresh should save the new token after successful refresh."""
|
||||
soon = int(time.time()) + CODEX_TOKEN_REFRESH_THRESHOLD - 60
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "refresh-123",
|
||||
"expires_at": soon,
|
||||
}
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "brand-new-token",
|
||||
"expires_in": 7200,
|
||||
"refresh_token": "new-refresh-456",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch("requests.post", return_value=mock_response):
|
||||
CodexOAuthManager.refresh_token_if_needed()
|
||||
|
||||
loaded = CodexOAuthManager.load_token()
|
||||
assert loaded["access_token"] == "brand-new-token"
|
||||
assert loaded["refresh_token"] == "new-refresh-456"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_returns_false_on_api_error(
|
||||
self, temp_token_path, clean_env
|
||||
):
|
||||
"""refresh should return False if API call fails."""
|
||||
soon = int(time.time()) + CODEX_TOKEN_REFRESH_THRESHOLD - 60
|
||||
CodexOAuthManager.save_token(
|
||||
{
|
||||
"access_token": "test",
|
||||
"refresh_token": "refresh",
|
||||
"expires_at": soon,
|
||||
}
|
||||
)
|
||||
|
||||
with patch("requests.post", side_effect=Exception("Network error")):
|
||||
result = CodexOAuthManager.refresh_token_if_needed()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Background Auth Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBackgroundAuth:
|
||||
r"""Tests for background OAuth flow."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_start_returns_authorizing(self):
|
||||
"""start_background_auth should return 'authorizing'."""
|
||||
with patch("app.utils.codex_oauth.oauth_state_manager") as mock_mgr:
|
||||
mock_state = MagicMock()
|
||||
mock_state.status = "pending"
|
||||
mock_state.is_cancelled.return_value = False
|
||||
mock_mgr.get_state.return_value = None
|
||||
mock_mgr.create_state.return_value = mock_state
|
||||
|
||||
with (
|
||||
patch("webbrowser.open"),
|
||||
patch("app.utils.codex_oauth.HTTPServer") as mock_server,
|
||||
):
|
||||
mock_server.return_value.server_address = ("127.0.0.1", 9999)
|
||||
|
||||
result = CodexOAuthManager.start_background_auth()
|
||||
|
||||
assert result == "authorizing"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_start_cancels_existing_pending_flow(self):
|
||||
"""start_background_auth should cancel any existing flow."""
|
||||
with patch("app.utils.codex_oauth.oauth_state_manager") as mock_mgr:
|
||||
old_state = MagicMock()
|
||||
old_state.status = "authorizing"
|
||||
old_state.server = MagicMock()
|
||||
|
||||
new_state = MagicMock()
|
||||
new_state.status = "pending"
|
||||
new_state.is_cancelled.return_value = False
|
||||
|
||||
mock_mgr.get_state.return_value = old_state
|
||||
mock_mgr.create_state.return_value = new_state
|
||||
|
||||
with (
|
||||
patch("webbrowser.open"),
|
||||
patch("app.utils.codex_oauth.HTTPServer") as mock_server,
|
||||
):
|
||||
mock_server.return_value.server_address = ("127.0.0.1", 9999)
|
||||
|
||||
CodexOAuthManager.start_background_auth()
|
||||
|
||||
old_state.cancel.assert_called_once()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_start_creates_new_oauth_state(self):
|
||||
"""start_background_auth should create a new OAuth state."""
|
||||
with patch("app.utils.codex_oauth.oauth_state_manager") as mock_mgr:
|
||||
mock_state = MagicMock()
|
||||
mock_state.status = "pending"
|
||||
mock_state.is_cancelled.return_value = False
|
||||
mock_mgr.get_state.return_value = None
|
||||
mock_mgr.create_state.return_value = mock_state
|
||||
|
||||
with (
|
||||
patch("webbrowser.open"),
|
||||
patch("app.utils.codex_oauth.HTTPServer") as mock_server,
|
||||
):
|
||||
mock_server.return_value.server_address = ("127.0.0.1", 9999)
|
||||
|
||||
CodexOAuthManager.start_background_auth()
|
||||
|
||||
mock_mgr.create_state.assert_called_once_with("codex")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstants:
|
||||
r"""Tests for module constants."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_lifetime_is_one_hour(self):
|
||||
"""Default token lifetime should be 3600 seconds (1 hour)."""
|
||||
assert CODEX_TOKEN_DEFAULT_LIFETIME == 3600
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_refresh_threshold_is_five_minutes(self):
|
||||
"""Refresh threshold should be 300 seconds (5 minutes)."""
|
||||
assert CODEX_TOKEN_REFRESH_THRESHOLD == 300
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_client_id_is_codex_cli_public_id(self):
|
||||
"""Client ID should be the public Codex CLI client ID."""
|
||||
assert CODEX_CLIENT_ID == "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
Loading…
Add table
Add a link
Reference in a new issue