diff --git a/backend/app/controller/codex_controller.py b/backend/app/controller/codex_controller.py deleted file mode 100644 index 67147c684..000000000 --- a/backend/app/controller/codex_controller.py +++ /dev/null @@ -1,234 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -import logging -import time - -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from app.utils.codex_oauth import CodexOAuthManager -from app.utils.oauth_state_manager import oauth_state_manager - - -class CodexTokenRequest(BaseModel): - r"""Request model for saving Codex/OpenAI API token.""" - - access_token: str - expires_in: int | None = None - - -logger = logging.getLogger("codex_controller") -router = APIRouter() - - -@router.post("/codex/connect", name="connect codex") -async def connect_codex(): - r"""Connect to Codex/OpenAI via OAuth PKCE flow. - - Initiates or completes the Codex OAuth authorization flow - to obtain an OpenAI API key. - - Returns: - Connection result with access token and provider info - """ - try: - if CodexOAuthManager.is_authenticated(): - if CodexOAuthManager.is_token_expired(): - # Try refreshing first - if CodexOAuthManager.refresh_token_if_needed(): - return { - "success": True, - "message": "Codex token refreshed successfully", - "toolkit_name": "CodexOAuthManager", - "access_token": CodexOAuthManager.get_access_token(), - "provider_name": "openai", - "endpoint_url": "https://api.openai.com/v1", - } - # Refresh failed, start new auth - logger.info( - "Codex token expired and refresh failed, starting re-auth" - ) - CodexOAuthManager.start_background_auth() - return { - "success": False, - "status": "authorizing", - "message": "Token expired. Browser should" - " open for re-authorization.", - "toolkit_name": "CodexOAuthManager", - "requires_auth": True, - } - - return { - "success": True, - "message": "Codex/OpenAI is already authenticated", - "toolkit_name": "CodexOAuthManager", - "access_token": CodexOAuthManager.get_access_token(), - "provider_name": "openai", - "endpoint_url": "https://api.openai.com/v1", - } - else: - logger.info("No Codex credentials found, starting OAuth flow") - CodexOAuthManager.start_background_auth() - return { - "success": False, - "status": "authorizing", - "message": "Authorization required. Browser" - " should open automatically.", - "toolkit_name": "CodexOAuthManager", - "requires_auth": True, - } - except Exception as e: - logger.error(f"Failed to connect Codex: {e}") - raise HTTPException( - status_code=500, detail=f"Failed to connect Codex: {str(e)}" - ) - - -@router.post("/codex/disconnect", name="disconnect codex") -async def disconnect_codex(): - r"""Disconnect Codex/OpenAI and clean up authentication data. - - Cancels any active OAuth flow and clears stored tokens. - - Returns: - Disconnection result - """ - try: - # Cancel any active OAuth flow - state = oauth_state_manager.get_state("codex") - if state and state.status in ["pending", "authorizing"]: - state.cancel() - if hasattr(state, "server") and state.server: - try: - state.server.shutdown() - except Exception: - pass - oauth_state_manager._states.pop("codex", None) - - success = CodexOAuthManager.clear_token() - - if success: - return { - "success": True, - "message": ( - "Successfully disconnected Codex" - " and cleaned up" - " authentication tokens" - ), - } - else: - return { - "success": True, - "message": "Disconnected Codex (no tokens found to clean up)", - } - except Exception as e: - logger.error(f"Failed to disconnect Codex: {e}") - raise HTTPException( - status_code=500, detail=f"Failed to disconnect Codex: {str(e)}" - ) - - -@router.post("/codex/save-token", name="save codex token") -async def save_codex_token(token_request: CodexTokenRequest): - r"""Save Codex/OpenAI API token (manual API key entry fallback). - - Args: - token_request: Token data containing access_token - and optionally expires_in - - Returns: - Save result - """ - try: - token_data = token_request.model_dump(exclude_none=True) - token_data["manual"] = True - - success = CodexOAuthManager.save_token(token_data) - - if success: - return { - "success": True, - "message": "Codex token saved successfully", - } - else: - raise HTTPException( - status_code=500, detail="Failed to save Codex token" - ) - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to save Codex token: {e}") - raise HTTPException( - status_code=500, detail=f"Failed to save token: {str(e)}" - ) - - -@router.get("/codex/status", name="get codex status") -async def get_codex_status(): - r"""Get current Codex/OpenAI authentication status and token info. - - Returns: - Status information including authentication state and token expiry - """ - try: - is_authenticated = CodexOAuthManager.is_authenticated() - - if not is_authenticated: - return { - "authenticated": False, - "status": "not_configured", - "message": "Codex not configured. OAuth or API key required.", - } - - token_info = CodexOAuthManager.get_token_info() - is_expired = CodexOAuthManager.is_token_expired() - is_expiring_soon = CodexOAuthManager.is_token_expiring_soon() - - result = { - "authenticated": True, - "status": "expired" - if is_expired - else ("expiring_soon" if is_expiring_soon else "valid"), - } - - if token_info: - if token_info.get("expires_at"): - current_time = int(time.time()) - expires_at = token_info["expires_at"] - seconds_remaining = max(0, expires_at - current_time) - result["expires_at"] = expires_at - result["seconds_remaining"] = seconds_remaining - - if token_info.get("saved_at"): - result["saved_at"] = token_info["saved_at"] - - if token_info.get("manual"): - result["manual"] = True - - if is_expired: - result["message"] = "Token has expired. Please re-authenticate." - elif is_expiring_soon: - result["message"] = ( - "Token is expiring soon. Consider re-authenticating." - ) - else: - result["message"] = "Codex/OpenAI is connected and token is valid." - - return result - except Exception as e: - logger.error(f"Failed to get Codex status: {e}") - raise HTTPException( - status_code=500, detail=f"Failed to get status: {str(e)}" - ) diff --git a/backend/app/router.py b/backend/app/router.py index 22c23369c..f257dde08 100644 --- a/backend/app/router.py +++ b/backend/app/router.py @@ -23,7 +23,6 @@ from fastapi import FastAPI from app.controller import ( chat_controller, - codex_controller, health_controller, model_controller, task_controller, @@ -72,11 +71,6 @@ def register_routers(app: FastAPI, prefix: str = "") -> None: "tags": ["tool"], "description": "Tool installation and management", }, - { - "router": codex_controller.router, - "tags": ["codex"], - "description": "Codex OAuth provider authentication", - }, ] for config in routers_config: diff --git a/backend/app/utils/codex_oauth.py b/backend/app/utils/codex_oauth.py deleted file mode 100644 index ed2e96120..000000000 --- a/backend/app/utils/codex_oauth.py +++ /dev/null @@ -1,548 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -r"""OpenAI Codex OAuth manager. - -Handles Authorization Code + PKCE flow using Codex CLI's public client_id. -The resulting access token is stored in an encrypted file and used -independently of the OPENAI_API_KEY environment variable. -""" - -import base64 -import getpass -import hashlib -import json -import logging -import os -import platform -import secrets -import socket -import stat -import threading -import time -import webbrowser -from html import escape as html_escape -from http.server import BaseHTTPRequestHandler, HTTPServer -from urllib.parse import parse_qs, urlencode, urlparse - -import requests -from cryptography.fernet import Fernet, InvalidToken -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from filelock import FileLock - -from app.utils.oauth_state_manager import oauth_state_manager - -logger = logging.getLogger("codex_oauth") - -# OpenAI / Codex OAuth constants -# Fixed public client_id from the Codex CLI (not a secret). -CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" -CODEX_AUTH_URL = "https://auth.openai.com/oauth/authorize" -CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token" -# Fixed callback port used by Codex CLI -CODEX_CALLBACK_PORT = 1455 - -# Token storage path -CODEX_TOKEN_DIR = os.path.join( - os.path.expanduser("~"), ".eigent", "tokens", "codex" -) -CODEX_TOKEN_PATH = os.path.join(CODEX_TOKEN_DIR, "codex_token.enc") - -# Token lifetime defaults (seconds) -CODEX_TOKEN_DEFAULT_LIFETIME = 3600 # 1 hour -CODEX_TOKEN_REFRESH_THRESHOLD = 300 # 5 minutes before expiry - - -def _get_machine_identifier() -> bytes: - r"""Get a machine-specific identifier for key derivation. - - Combines multiple machine-specific values to create a stable identifier - that is unique to this machine but consistent across restarts. - - Returns: - Machine identifier as bytes. - """ - components = [ - getpass.getuser(), - socket.gethostname(), - platform.node(), - # Add home directory path for additional uniqueness - os.path.expanduser("~"), - ] - - # Try to get machine-id on Linux - machine_id_paths = [ - "/etc/machine-id", - "/var/lib/dbus/machine-id", - ] - for path in machine_id_paths: - try: - with open(path) as f: - components.append(f.read().strip()) - break - except (FileNotFoundError, PermissionError): - continue - - return "|".join(components).encode("utf-8") - - -def _derive_encryption_key() -> bytes: - r"""Derive an encryption key from machine-specific identifiers. - - Uses PBKDF2 to derive a Fernet-compatible key from machine identifiers. - This ties the encryption to the specific machine without storing a key file. - - Returns: - The Fernet encryption key as bytes. - """ - # Fixed salt for this application (not secret, just ensures uniqueness) - # The security comes from the machine-specific identifier - app_salt = b"eigent-codex-token-encryption-v1" - - machine_id = _get_machine_identifier() - - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=app_salt, - iterations=100_000, - ) - - # Derive a 32-byte key and encode it as base64 for Fernet - derived_key = kdf.derive(machine_id) - return base64.urlsafe_b64encode(derived_key) - - -def _encrypt_token_data(token_data: dict) -> bytes: - r"""Encrypt token data using Fernet symmetric encryption. - - Args: - token_data: Dictionary containing token information. - - Returns: - Encrypted bytes. - """ - key = _derive_encryption_key() - fernet = Fernet(key) - json_bytes = json.dumps(token_data).encode("utf-8") - return fernet.encrypt(json_bytes) - - -def _decrypt_token_data(encrypted_data: bytes) -> dict | None: - r"""Decrypt token data. - - Args: - encrypted_data: Encrypted bytes from file. - - Returns: - Decrypted token dictionary, or None if decryption fails. - """ - try: - key = _derive_encryption_key() - fernet = Fernet(key) - decrypted = fernet.decrypt(encrypted_data) - return json.loads(decrypted.decode("utf-8")) - except (InvalidToken, json.JSONDecodeError) as e: - logger.warning("Failed to decrypt token data: %s", e) - return None - - -def _generate_pkce_pair() -> tuple[str, str]: - r"""Generate a PKCE code_verifier and S256 code_challenge. - - Returns: - Tuple of (code_verifier, code_challenge). - """ - code_verifier = secrets.token_urlsafe(64) - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - code_challenge = ( - base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") - ) - return code_verifier, code_challenge - - -class _CallbackHandler(BaseHTTPRequestHandler): - r"""HTTP handler that captures the OAuth callback code.""" - - def do_GET(self): - parsed = urlparse(self.path) - params = parse_qs(parsed.query) - - # Validate state parameter to prevent CSRF attacks - received_state = params.get("state", [None])[0] - expected_state = getattr(self.server, "expected_state", None) - - if expected_state and received_state != expected_state: - self.server.auth_error = "state_mismatch: Invalid state parameter" - self.send_response(400) - self.send_header("Content-Type", "text/html") - self.end_headers() - self.wfile.write( - b"

Authorization failed

" - b"

Invalid state parameter. Possible CSRF attack.

" - b"" - ) - return - - if "code" in params: - self.server.auth_code = params["code"][0] - self.send_response(200) - self.send_header("Content-Type", "text/html") - self.end_headers() - self.wfile.write( - b"" - b"

Authorization successful!

" - b"

You can close this window and return to Eigent.

" - b"" - ) - elif "error" in params: - error = params.get("error", ["unknown"])[0] - desc = params.get("error_description", [""])[0] - self.server.auth_error = f"{error}: {desc}" - self.send_response(400) - self.send_header("Content-Type", "text/html") - self.end_headers() - # Escape HTML to prevent XSS from query parameters - self.wfile.write( - f"

Authorization failed

" - f"

{html_escape(error)}: {html_escape(desc)}

" - f"".encode() - ) - else: - self.send_response(400) - self.send_header("Content-Type", "text/html") - self.end_headers() - self.wfile.write( - b"

" - b"Missing authorization code" - b"

" - ) - - def log_message(self, format, *args): - logger.debug("Codex callback server: %s", format % args) - - -class CodexOAuthManager: - r"""Manages OpenAI Codex OAuth token lifecycle.""" - - @staticmethod - def _token_path() -> str: - return CODEX_TOKEN_PATH - - @classmethod - def save_token(cls, token_data: dict) -> bool: - r"""Save token data to disk with encryption. - - Args: - token_data: Dictionary with at least ``access_token``. - - Returns: - True on success. - """ - path = cls._token_path() - token_data = token_data.copy() - try: - if "saved_at" not in token_data: - token_data["saved_at"] = int(time.time()) - - # Compute absolute expiry from the relative expires_in value - # (if present), then discard expires_in so we only store the - # absolute timestamp. - if "expires_at" not in token_data: - expires_in = token_data.pop( - "expires_in", CODEX_TOKEN_DEFAULT_LIFETIME - ) - token_data["expires_at"] = token_data["saved_at"] + expires_in - else: - token_data.pop("expires_in", None) - - os.makedirs(os.path.dirname(path), exist_ok=True) - lock = FileLock(path + ".lock") - - # Encrypt token data before saving - encrypted_data = _encrypt_token_data(token_data) - - with lock, open(path, "wb") as f: - f.write(encrypted_data) - - # Set restrictive permissions on token file - os.chmod(path, stat.S_IRUSR | stat.S_IWUSR) - - logger.info("Saved encrypted Codex token to %s", path) - - return True - except Exception as e: - logger.error("Failed to save Codex token: %s", e) - return False - - @classmethod - def load_token(cls) -> dict | None: - r"""Load and decrypt token data from disk.""" - path = cls._token_path() - if os.path.exists(path): - try: - lock = FileLock(path + ".lock") - with lock, open(path, "rb") as f: - encrypted_data = f.read() - return _decrypt_token_data(encrypted_data) - except Exception as e: - logger.warning("Failed to load token: %s", e) - return None - - @classmethod - def clear_token(cls) -> bool: - r"""Remove stored token file.""" - path = cls._token_path() - try: - if os.path.exists(path): - os.remove(path) - logger.info("Removed Codex token file: %s", path) - - token_dir = os.path.dirname(path) - if os.path.exists(token_dir) and not os.listdir(token_dir): - os.rmdir(token_dir) - - return True - except Exception as e: - logger.error("Failed to clear Codex token: %s", e) - return False - - @classmethod - def is_authenticated(cls) -> bool: - r"""Return True if a Codex OAuth token is available.""" - token = cls.load_token() - return bool(token and token.get("access_token")) - - @classmethod - def is_token_expired(cls) -> bool: - r"""Return True if the stored token has expired.""" - token = cls.load_token() - if not token: - return False - expires_at = token.get("expires_at") - if not expires_at: - return False - return int(time.time()) >= expires_at - - @classmethod - def is_token_expiring_soon(cls) -> bool: - r"""Return True if the token expires within the refresh threshold.""" - token = cls.load_token() - if not token: - return False - expires_at = token.get("expires_at") - if not expires_at: - return False - return (expires_at - int(time.time())) < CODEX_TOKEN_REFRESH_THRESHOLD - - @classmethod - def get_access_token(cls) -> str | None: - r"""Return the current Codex OAuth access token.""" - token = cls.load_token() - if token and token.get("access_token"): - return token["access_token"] - return None - - @classmethod - def get_token_info(cls) -> dict | None: - r"""Return stored token metadata.""" - return cls.load_token() - - @classmethod - def refresh_token_if_needed(cls) -> bool: - r"""Attempt to refresh the token if it has a refresh_token. - - Returns: - True if refreshed or not needed, False on failure. - """ - token = cls.load_token() - if not token: - return False - - if not cls.is_token_expiring_soon(): - return True - - refresh_token = token.get("refresh_token") - if not refresh_token: - return False - - try: - resp = requests.post( - CODEX_TOKEN_URL, - data={ - "grant_type": "refresh_token", - "client_id": CODEX_CLIENT_ID, - "refresh_token": refresh_token, - }, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - timeout=30, - ) - resp.raise_for_status() - new_token = resp.json() - - # Merge with existing data - token.update( - { - "access_token": new_token["access_token"], - "expires_in": new_token.get( - "expires_in", CODEX_TOKEN_DEFAULT_LIFETIME - ), - "saved_at": int(time.time()), - } - ) - if new_token.get("refresh_token"): - token["refresh_token"] = new_token["refresh_token"] - token.pop("expires_at", None) - - return cls.save_token(token) - except Exception as e: - logger.error("Failed to refresh Codex token: %s", e) - return False - - # ------------------------------------------------------------------ - # Background OAuth flow - # ------------------------------------------------------------------ - - @classmethod - def start_background_auth(cls) -> str: - r"""Launch the PKCE OAuth flow in a background thread. - - Returns: - ``"authorizing"`` immediately. - """ - # Cancel any existing flow - old_state = oauth_state_manager.get_state("codex") - if old_state and old_state.status in ["pending", "authorizing"]: - old_state.cancel() - if hasattr(old_state, "server") and old_state.server: - try: - old_state.server.shutdown() - except Exception: - pass - - state = oauth_state_manager.create_state("codex") - - def _auth_flow(): - try: - state.status = "authorizing" - oauth_state_manager.update_status("codex", "authorizing") - - code_verifier, code_challenge = _generate_pkce_pair() - - # Generate state parameter to prevent CSRF attacks - oauth_state = secrets.token_urlsafe(32) - - # Start localhost callback server on fixed port 1455 (Codex standard) - server = HTTPServer( - ("127.0.0.1", CODEX_CALLBACK_PORT), _CallbackHandler - ) - server.auth_code = None - server.auth_error = None - server.expected_state = oauth_state - state.server = server - - redirect_uri = ( - f"http://localhost:{CODEX_CALLBACK_PORT}/auth/callback" - ) - - params = urlencode( - { - "response_type": "code", - "client_id": CODEX_CLIENT_ID, - "redirect_uri": redirect_uri, - "scope": "openid profile email offline_access", - "code_challenge": code_challenge, - "code_challenge_method": "S256", - "state": oauth_state, - "id_token_add_organizations": "true", - "codex_cli_simplified_flow": "true", - } - ) - auth_url = f"{CODEX_AUTH_URL}?{params}" - - if state.is_cancelled(): - server.server_close() - return - - logger.info( - "Opening browser for Codex OAuth on port %d", - CODEX_CALLBACK_PORT, - ) - webbrowser.open(auth_url) - - # Wait for the callback (single request) - server.handle_request() - - if state.is_cancelled(): - server.server_close() - return - - if server.auth_error: - raise ValueError(server.auth_error) - - if not server.auth_code: - raise ValueError("No authorization code received") - - auth_code = server.auth_code - server.server_close() - - # Exchange code for token - token_resp = requests.post( - CODEX_TOKEN_URL, - data={ - "grant_type": "authorization_code", - "client_id": CODEX_CLIENT_ID, - "code": auth_code, - "redirect_uri": redirect_uri, - "code_verifier": code_verifier, - }, - headers={ - "Content-Type": "application/x-www-form-urlencoded" - }, - timeout=30, - ) - token_resp.raise_for_status() - token_data = token_resp.json() - - if state.is_cancelled(): - return - - cls.save_token(token_data) - - oauth_state_manager.update_status( - "codex", "success", result=token_data - ) - logger.info("Codex OAuth authorization successful") - - except Exception as e: - if state.is_cancelled(): - oauth_state_manager.update_status("codex", "cancelled") - else: - logger.error("Codex OAuth failed: %s", e) - oauth_state_manager.update_status( - "codex", "failed", error=str(e) - ) - finally: - state.server = None - - thread = threading.Thread( - target=_auth_flow, - daemon=True, - name=f"Codex-OAuth-{state.started_at.timestamp()}", - ) - state.thread = thread - thread.start() - - logger.info("Started background Codex OAuth authorization") - return "authorizing" diff --git a/backend/tests/app/utils/test_codex_oauth.py b/backend/tests/app/utils/test_codex_oauth.py deleted file mode 100644 index 004283c98..000000000 --- a/backend/tests/app/utils/test_codex_oauth.py +++ /dev/null @@ -1,875 +0,0 @@ -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= - -r"""Tests for the Codex OAuth manager.""" - -import io -import os -import stat -import tempfile -import time -from collections.abc import Generator -from contextlib import contextmanager -from unittest.mock import MagicMock, patch - -import pytest - -from app.utils.codex_oauth import ( - CODEX_CLIENT_ID, - CODEX_TOKEN_DEFAULT_LIFETIME, - CODEX_TOKEN_REFRESH_THRESHOLD, - CodexOAuthManager, - _CallbackHandler, - _decrypt_token_data, - _derive_encryption_key, - _encrypt_token_data, - _generate_pkce_pair, - _get_machine_identifier, -) - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def temp_token_path() -> Generator[str, None, None]: - r"""Create a temporary token path and patch CodexOAuthManager to use it.""" - with tempfile.TemporaryDirectory() as temp_dir: - token_path = os.path.join(temp_dir, "codex_token.enc") - with patch.object( - CodexOAuthManager, "_token_path", return_value=token_path - ): - yield token_path - - -@pytest.fixture -def clean_env() -> Generator[None, None, None]: - r"""Ensure OPENAI_API_KEY is cleaned up before and after test.""" - original = os.environ.pop("OPENAI_API_KEY", None) - yield - os.environ.pop("OPENAI_API_KEY", None) - if original is not None: - os.environ["OPENAI_API_KEY"] = original - - -@contextmanager -def mock_callback_request(path: str, expected_state: str | None = None): - r"""Create a mock HTTP request for _CallbackHandler testing.""" - handler = MagicMock(spec=_CallbackHandler) - handler.path = path - handler.wfile = io.BytesIO() - - # Track response - handler.response_code = None - handler.headers_sent = {} - - def send_response(code): - handler.response_code = code - - def send_header(name, value): - handler.headers_sent[name] = value - - def end_headers(): - pass - - handler.send_response = send_response - handler.send_header = send_header - handler.end_headers = end_headers - - # Create mock server - handler.server = MagicMock() - handler.server.auth_code = None - handler.server.auth_error = None - handler.server.expected_state = expected_state - - yield handler - - -# --------------------------------------------------------------------------- -# PKCE Generation Tests -# --------------------------------------------------------------------------- - - -class TestPKCEGeneration: - r"""Tests for PKCE code verifier and challenge generation.""" - - @pytest.mark.unit - def test_returns_tuple_of_strings(self): - """PKCE pair should be a tuple of two strings.""" - verifier, challenge = _generate_pkce_pair() - - assert isinstance(verifier, str) - assert isinstance(challenge, str) - - @pytest.mark.unit - def test_verifier_meets_minimum_length(self): - """Code verifier should be at least 43 characters (RFC 7636).""" - verifier, _ = _generate_pkce_pair() - - assert len(verifier) >= 43 - - @pytest.mark.unit - def test_challenge_is_base64url_without_padding(self): - """Code challenge should be valid base64url without padding.""" - _, challenge = _generate_pkce_pair() - - # Base64url should not contain +, /, or = - assert "+" not in challenge - assert "/" not in challenge - assert "=" not in challenge - - @pytest.mark.unit - def test_generates_unique_values(self): - """Each call should generate cryptographically unique values.""" - pairs = [_generate_pkce_pair() for _ in range(10)] - verifiers = [p[0] for p in pairs] - challenges = [p[1] for p in pairs] - - assert len(set(verifiers)) == 10 - assert len(set(challenges)) == 10 - - -# --------------------------------------------------------------------------- -# Machine Identifier Tests -# --------------------------------------------------------------------------- - - -class TestMachineIdentifier: - r"""Tests for machine identifier generation.""" - - @pytest.mark.unit - def test_returns_bytes(self): - """Machine identifier should be bytes.""" - identifier = _get_machine_identifier() - - assert isinstance(identifier, bytes) - - @pytest.mark.unit - def test_is_deterministic(self): - """Machine identifier should be consistent across calls.""" - assert _get_machine_identifier() == _get_machine_identifier() - - @pytest.mark.unit - def test_contains_multiple_components(self): - """Identifier should contain pipe-separated components.""" - identifier = _get_machine_identifier() - decoded = identifier.decode("utf-8") - components = decoded.split("|") - - # Should have at least username, hostname, platform.node, home dir - assert len(components) >= 4 - - -# --------------------------------------------------------------------------- -# Encryption Tests -# --------------------------------------------------------------------------- - - -class TestEncryption: - r"""Tests for token encryption and decryption.""" - - @pytest.mark.unit - def test_derive_key_returns_fernet_compatible_key(self): - """Derived key should be 44 bytes (Fernet format).""" - key = _derive_encryption_key() - - assert isinstance(key, bytes) - assert len(key) == 44 - - @pytest.mark.unit - def test_derive_key_is_deterministic(self): - """Derived key should be consistent for same machine.""" - assert _derive_encryption_key() == _derive_encryption_key() - - @pytest.mark.unit - def test_encrypt_decrypt_roundtrip(self): - """Data should survive encryption and decryption.""" - original = { - "access_token": "sk-test-token-123", - "refresh_token": "rt-refresh-456", - "expires_at": 1234567890, - "scope": "openai.api.read", - } - - encrypted = _encrypt_token_data(original) - decrypted = _decrypt_token_data(encrypted) - - assert decrypted == original - - @pytest.mark.unit - def test_encrypt_returns_bytes(self): - """Encrypted output should be bytes.""" - encrypted = _encrypt_token_data({"access_token": "test"}) - - assert isinstance(encrypted, bytes) - assert len(encrypted) > 0 - - @pytest.mark.unit - def test_encrypted_data_differs_from_input(self): - """Encrypted data should not contain plaintext.""" - token = "my_secret_token" - encrypted = _encrypt_token_data({"access_token": token}) - - assert token.encode() not in encrypted - - @pytest.mark.unit - def test_decrypt_invalid_data_returns_none(self): - """Decrypting garbage data should return None, not raise.""" - assert _decrypt_token_data(b"not-valid-fernet-data") is None - - @pytest.mark.unit - def test_decrypt_corrupted_data_returns_none(self): - """Decrypting tampered data should return None.""" - encrypted = _encrypt_token_data({"access_token": "test"}) - corrupted = encrypted[:-10] + b"tampered!!" - - assert _decrypt_token_data(corrupted) is None - - @pytest.mark.unit - def test_decrypt_empty_data_returns_none(self): - """Decrypting empty data should return None.""" - assert _decrypt_token_data(b"") is None - - -# --------------------------------------------------------------------------- -# Callback Handler Tests -# --------------------------------------------------------------------------- - - -class TestCallbackHandler: - r"""Tests for OAuth callback HTTP handler.""" - - @pytest.mark.unit - def test_captures_authorization_code(self): - """Handler should capture auth code from callback URL.""" - path = "/auth/callback?code=auth_code_123&state=valid_state" - with mock_callback_request( - path, expected_state="valid_state" - ) as handler: - _CallbackHandler.do_GET(handler) - - assert handler.server.auth_code == "auth_code_123" - assert handler.response_code == 200 - - @pytest.mark.unit - def test_captures_error_response(self): - """Handler should capture error from callback URL.""" - path = "/auth/callback?error=access_denied&error_description=User%20denied" - with mock_callback_request(path) as handler: - _CallbackHandler.do_GET(handler) - - assert handler.server.auth_error == "access_denied: User denied" - assert handler.response_code == 400 - - @pytest.mark.unit - def test_handles_missing_code(self): - """Handler should return 400 when code is missing.""" - with mock_callback_request("/auth/callback?state=xyz") as handler: - _CallbackHandler.do_GET(handler) - - assert handler.response_code == 400 - assert handler.server.auth_code is None - - @pytest.mark.unit - def test_escapes_html_in_error(self): - """Handler should escape HTML in error messages to prevent XSS.""" - path = "/auth/callback?error=