diff --git a/backend/app/controller/codex_controller.py b/backend/app/controller/codex_controller.py new file mode 100644 index 000000000..67147c684 --- /dev/null +++ b/backend/app/controller/codex_controller.py @@ -0,0 +1,234 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +import logging +import time + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.utils.codex_oauth import CodexOAuthManager +from app.utils.oauth_state_manager import oauth_state_manager + + +class CodexTokenRequest(BaseModel): + r"""Request model for saving Codex/OpenAI API token.""" + + access_token: str + expires_in: int | None = None + + +logger = logging.getLogger("codex_controller") +router = APIRouter() + + +@router.post("/codex/connect", name="connect codex") +async def connect_codex(): + r"""Connect to Codex/OpenAI via OAuth PKCE flow. + + Initiates or completes the Codex OAuth authorization flow + to obtain an OpenAI API key. + + Returns: + Connection result with access token and provider info + """ + try: + if CodexOAuthManager.is_authenticated(): + if CodexOAuthManager.is_token_expired(): + # Try refreshing first + if CodexOAuthManager.refresh_token_if_needed(): + return { + "success": True, + "message": "Codex token refreshed successfully", + "toolkit_name": "CodexOAuthManager", + "access_token": CodexOAuthManager.get_access_token(), + "provider_name": "openai", + "endpoint_url": "https://api.openai.com/v1", + } + # Refresh failed, start new auth + logger.info( + "Codex token expired and refresh failed, starting re-auth" + ) + CodexOAuthManager.start_background_auth() + return { + "success": False, + "status": "authorizing", + "message": "Token expired. Browser should" + " open for re-authorization.", + "toolkit_name": "CodexOAuthManager", + "requires_auth": True, + } + + return { + "success": True, + "message": "Codex/OpenAI is already authenticated", + "toolkit_name": "CodexOAuthManager", + "access_token": CodexOAuthManager.get_access_token(), + "provider_name": "openai", + "endpoint_url": "https://api.openai.com/v1", + } + else: + logger.info("No Codex credentials found, starting OAuth flow") + CodexOAuthManager.start_background_auth() + return { + "success": False, + "status": "authorizing", + "message": "Authorization required. Browser" + " should open automatically.", + "toolkit_name": "CodexOAuthManager", + "requires_auth": True, + } + except Exception as e: + logger.error(f"Failed to connect Codex: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to connect Codex: {str(e)}" + ) + + +@router.post("/codex/disconnect", name="disconnect codex") +async def disconnect_codex(): + r"""Disconnect Codex/OpenAI and clean up authentication data. + + Cancels any active OAuth flow and clears stored tokens. + + Returns: + Disconnection result + """ + try: + # Cancel any active OAuth flow + state = oauth_state_manager.get_state("codex") + if state and state.status in ["pending", "authorizing"]: + state.cancel() + if hasattr(state, "server") and state.server: + try: + state.server.shutdown() + except Exception: + pass + oauth_state_manager._states.pop("codex", None) + + success = CodexOAuthManager.clear_token() + + if success: + return { + "success": True, + "message": ( + "Successfully disconnected Codex" + " and cleaned up" + " authentication tokens" + ), + } + else: + return { + "success": True, + "message": "Disconnected Codex (no tokens found to clean up)", + } + except Exception as e: + logger.error(f"Failed to disconnect Codex: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to disconnect Codex: {str(e)}" + ) + + +@router.post("/codex/save-token", name="save codex token") +async def save_codex_token(token_request: CodexTokenRequest): + r"""Save Codex/OpenAI API token (manual API key entry fallback). + + Args: + token_request: Token data containing access_token + and optionally expires_in + + Returns: + Save result + """ + try: + token_data = token_request.model_dump(exclude_none=True) + token_data["manual"] = True + + success = CodexOAuthManager.save_token(token_data) + + if success: + return { + "success": True, + "message": "Codex token saved successfully", + } + else: + raise HTTPException( + status_code=500, detail="Failed to save Codex token" + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to save Codex token: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to save token: {str(e)}" + ) + + +@router.get("/codex/status", name="get codex status") +async def get_codex_status(): + r"""Get current Codex/OpenAI authentication status and token info. + + Returns: + Status information including authentication state and token expiry + """ + try: + is_authenticated = CodexOAuthManager.is_authenticated() + + if not is_authenticated: + return { + "authenticated": False, + "status": "not_configured", + "message": "Codex not configured. OAuth or API key required.", + } + + token_info = CodexOAuthManager.get_token_info() + is_expired = CodexOAuthManager.is_token_expired() + is_expiring_soon = CodexOAuthManager.is_token_expiring_soon() + + result = { + "authenticated": True, + "status": "expired" + if is_expired + else ("expiring_soon" if is_expiring_soon else "valid"), + } + + if token_info: + if token_info.get("expires_at"): + current_time = int(time.time()) + expires_at = token_info["expires_at"] + seconds_remaining = max(0, expires_at - current_time) + result["expires_at"] = expires_at + result["seconds_remaining"] = seconds_remaining + + if token_info.get("saved_at"): + result["saved_at"] = token_info["saved_at"] + + if token_info.get("manual"): + result["manual"] = True + + if is_expired: + result["message"] = "Token has expired. Please re-authenticate." + elif is_expiring_soon: + result["message"] = ( + "Token is expiring soon. Consider re-authenticating." + ) + else: + result["message"] = "Codex/OpenAI is connected and token is valid." + + return result + except Exception as e: + logger.error(f"Failed to get Codex status: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to get status: {str(e)}" + ) diff --git a/backend/app/router.py b/backend/app/router.py index f257dde08..22c23369c 100644 --- a/backend/app/router.py +++ b/backend/app/router.py @@ -23,6 +23,7 @@ from fastapi import FastAPI from app.controller import ( chat_controller, + codex_controller, health_controller, model_controller, task_controller, @@ -71,6 +72,11 @@ def register_routers(app: FastAPI, prefix: str = "") -> None: "tags": ["tool"], "description": "Tool installation and management", }, + { + "router": codex_controller.router, + "tags": ["codex"], + "description": "Codex OAuth provider authentication", + }, ] for config in routers_config: diff --git a/backend/app/utils/codex_oauth.py b/backend/app/utils/codex_oauth.py new file mode 100644 index 000000000..ed2e96120 --- /dev/null +++ b/backend/app/utils/codex_oauth.py @@ -0,0 +1,548 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +r"""OpenAI Codex OAuth manager. + +Handles Authorization Code + PKCE flow using Codex CLI's public client_id. +The resulting access token is stored in an encrypted file and used +independently of the OPENAI_API_KEY environment variable. +""" + +import base64 +import getpass +import hashlib +import json +import logging +import os +import platform +import secrets +import socket +import stat +import threading +import time +import webbrowser +from html import escape as html_escape +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import parse_qs, urlencode, urlparse + +import requests +from cryptography.fernet import Fernet, InvalidToken +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from filelock import FileLock + +from app.utils.oauth_state_manager import oauth_state_manager + +logger = logging.getLogger("codex_oauth") + +# OpenAI / Codex OAuth constants +# Fixed public client_id from the Codex CLI (not a secret). +CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +CODEX_AUTH_URL = "https://auth.openai.com/oauth/authorize" +CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token" +# Fixed callback port used by Codex CLI +CODEX_CALLBACK_PORT = 1455 + +# Token storage path +CODEX_TOKEN_DIR = os.path.join( + os.path.expanduser("~"), ".eigent", "tokens", "codex" +) +CODEX_TOKEN_PATH = os.path.join(CODEX_TOKEN_DIR, "codex_token.enc") + +# Token lifetime defaults (seconds) +CODEX_TOKEN_DEFAULT_LIFETIME = 3600 # 1 hour +CODEX_TOKEN_REFRESH_THRESHOLD = 300 # 5 minutes before expiry + + +def _get_machine_identifier() -> bytes: + r"""Get a machine-specific identifier for key derivation. + + Combines multiple machine-specific values to create a stable identifier + that is unique to this machine but consistent across restarts. + + Returns: + Machine identifier as bytes. + """ + components = [ + getpass.getuser(), + socket.gethostname(), + platform.node(), + # Add home directory path for additional uniqueness + os.path.expanduser("~"), + ] + + # Try to get machine-id on Linux + machine_id_paths = [ + "/etc/machine-id", + "/var/lib/dbus/machine-id", + ] + for path in machine_id_paths: + try: + with open(path) as f: + components.append(f.read().strip()) + break + except (FileNotFoundError, PermissionError): + continue + + return "|".join(components).encode("utf-8") + + +def _derive_encryption_key() -> bytes: + r"""Derive an encryption key from machine-specific identifiers. + + Uses PBKDF2 to derive a Fernet-compatible key from machine identifiers. + This ties the encryption to the specific machine without storing a key file. + + Returns: + The Fernet encryption key as bytes. + """ + # Fixed salt for this application (not secret, just ensures uniqueness) + # The security comes from the machine-specific identifier + app_salt = b"eigent-codex-token-encryption-v1" + + machine_id = _get_machine_identifier() + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=app_salt, + iterations=100_000, + ) + + # Derive a 32-byte key and encode it as base64 for Fernet + derived_key = kdf.derive(machine_id) + return base64.urlsafe_b64encode(derived_key) + + +def _encrypt_token_data(token_data: dict) -> bytes: + r"""Encrypt token data using Fernet symmetric encryption. + + Args: + token_data: Dictionary containing token information. + + Returns: + Encrypted bytes. + """ + key = _derive_encryption_key() + fernet = Fernet(key) + json_bytes = json.dumps(token_data).encode("utf-8") + return fernet.encrypt(json_bytes) + + +def _decrypt_token_data(encrypted_data: bytes) -> dict | None: + r"""Decrypt token data. + + Args: + encrypted_data: Encrypted bytes from file. + + Returns: + Decrypted token dictionary, or None if decryption fails. + """ + try: + key = _derive_encryption_key() + fernet = Fernet(key) + decrypted = fernet.decrypt(encrypted_data) + return json.loads(decrypted.decode("utf-8")) + except (InvalidToken, json.JSONDecodeError) as e: + logger.warning("Failed to decrypt token data: %s", e) + return None + + +def _generate_pkce_pair() -> tuple[str, str]: + r"""Generate a PKCE code_verifier and S256 code_challenge. + + Returns: + Tuple of (code_verifier, code_challenge). + """ + code_verifier = secrets.token_urlsafe(64) + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = ( + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + ) + return code_verifier, code_challenge + + +class _CallbackHandler(BaseHTTPRequestHandler): + r"""HTTP handler that captures the OAuth callback code.""" + + def do_GET(self): + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + + # Validate state parameter to prevent CSRF attacks + received_state = params.get("state", [None])[0] + expected_state = getattr(self.server, "expected_state", None) + + if expected_state and received_state != expected_state: + self.server.auth_error = "state_mismatch: Invalid state parameter" + self.send_response(400) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"

Authorization failed

" + b"

Invalid state parameter. Possible CSRF attack.

" + b"" + ) + return + + if "code" in params: + self.server.auth_code = params["code"][0] + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"" + b"

Authorization successful!

" + b"

You can close this window and return to Eigent.

" + b"" + ) + elif "error" in params: + error = params.get("error", ["unknown"])[0] + desc = params.get("error_description", [""])[0] + self.server.auth_error = f"{error}: {desc}" + self.send_response(400) + self.send_header("Content-Type", "text/html") + self.end_headers() + # Escape HTML to prevent XSS from query parameters + self.wfile.write( + f"

Authorization failed

" + f"

{html_escape(error)}: {html_escape(desc)}

" + f"".encode() + ) + else: + self.send_response(400) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"

" + b"Missing authorization code" + b"

" + ) + + def log_message(self, format, *args): + logger.debug("Codex callback server: %s", format % args) + + +class CodexOAuthManager: + r"""Manages OpenAI Codex OAuth token lifecycle.""" + + @staticmethod + def _token_path() -> str: + return CODEX_TOKEN_PATH + + @classmethod + def save_token(cls, token_data: dict) -> bool: + r"""Save token data to disk with encryption. + + Args: + token_data: Dictionary with at least ``access_token``. + + Returns: + True on success. + """ + path = cls._token_path() + token_data = token_data.copy() + try: + if "saved_at" not in token_data: + token_data["saved_at"] = int(time.time()) + + # Compute absolute expiry from the relative expires_in value + # (if present), then discard expires_in so we only store the + # absolute timestamp. + if "expires_at" not in token_data: + expires_in = token_data.pop( + "expires_in", CODEX_TOKEN_DEFAULT_LIFETIME + ) + token_data["expires_at"] = token_data["saved_at"] + expires_in + else: + token_data.pop("expires_in", None) + + os.makedirs(os.path.dirname(path), exist_ok=True) + lock = FileLock(path + ".lock") + + # Encrypt token data before saving + encrypted_data = _encrypt_token_data(token_data) + + with lock, open(path, "wb") as f: + f.write(encrypted_data) + + # Set restrictive permissions on token file + os.chmod(path, stat.S_IRUSR | stat.S_IWUSR) + + logger.info("Saved encrypted Codex token to %s", path) + + return True + except Exception as e: + logger.error("Failed to save Codex token: %s", e) + return False + + @classmethod + def load_token(cls) -> dict | None: + r"""Load and decrypt token data from disk.""" + path = cls._token_path() + if os.path.exists(path): + try: + lock = FileLock(path + ".lock") + with lock, open(path, "rb") as f: + encrypted_data = f.read() + return _decrypt_token_data(encrypted_data) + except Exception as e: + logger.warning("Failed to load token: %s", e) + return None + + @classmethod + def clear_token(cls) -> bool: + r"""Remove stored token file.""" + path = cls._token_path() + try: + if os.path.exists(path): + os.remove(path) + logger.info("Removed Codex token file: %s", path) + + token_dir = os.path.dirname(path) + if os.path.exists(token_dir) and not os.listdir(token_dir): + os.rmdir(token_dir) + + return True + except Exception as e: + logger.error("Failed to clear Codex token: %s", e) + return False + + @classmethod + def is_authenticated(cls) -> bool: + r"""Return True if a Codex OAuth token is available.""" + token = cls.load_token() + return bool(token and token.get("access_token")) + + @classmethod + def is_token_expired(cls) -> bool: + r"""Return True if the stored token has expired.""" + token = cls.load_token() + if not token: + return False + expires_at = token.get("expires_at") + if not expires_at: + return False + return int(time.time()) >= expires_at + + @classmethod + def is_token_expiring_soon(cls) -> bool: + r"""Return True if the token expires within the refresh threshold.""" + token = cls.load_token() + if not token: + return False + expires_at = token.get("expires_at") + if not expires_at: + return False + return (expires_at - int(time.time())) < CODEX_TOKEN_REFRESH_THRESHOLD + + @classmethod + def get_access_token(cls) -> str | None: + r"""Return the current Codex OAuth access token.""" + token = cls.load_token() + if token and token.get("access_token"): + return token["access_token"] + return None + + @classmethod + def get_token_info(cls) -> dict | None: + r"""Return stored token metadata.""" + return cls.load_token() + + @classmethod + def refresh_token_if_needed(cls) -> bool: + r"""Attempt to refresh the token if it has a refresh_token. + + Returns: + True if refreshed or not needed, False on failure. + """ + token = cls.load_token() + if not token: + return False + + if not cls.is_token_expiring_soon(): + return True + + refresh_token = token.get("refresh_token") + if not refresh_token: + return False + + try: + resp = requests.post( + CODEX_TOKEN_URL, + data={ + "grant_type": "refresh_token", + "client_id": CODEX_CLIENT_ID, + "refresh_token": refresh_token, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30, + ) + resp.raise_for_status() + new_token = resp.json() + + # Merge with existing data + token.update( + { + "access_token": new_token["access_token"], + "expires_in": new_token.get( + "expires_in", CODEX_TOKEN_DEFAULT_LIFETIME + ), + "saved_at": int(time.time()), + } + ) + if new_token.get("refresh_token"): + token["refresh_token"] = new_token["refresh_token"] + token.pop("expires_at", None) + + return cls.save_token(token) + except Exception as e: + logger.error("Failed to refresh Codex token: %s", e) + return False + + # ------------------------------------------------------------------ + # Background OAuth flow + # ------------------------------------------------------------------ + + @classmethod + def start_background_auth(cls) -> str: + r"""Launch the PKCE OAuth flow in a background thread. + + Returns: + ``"authorizing"`` immediately. + """ + # Cancel any existing flow + old_state = oauth_state_manager.get_state("codex") + if old_state and old_state.status in ["pending", "authorizing"]: + old_state.cancel() + if hasattr(old_state, "server") and old_state.server: + try: + old_state.server.shutdown() + except Exception: + pass + + state = oauth_state_manager.create_state("codex") + + def _auth_flow(): + try: + state.status = "authorizing" + oauth_state_manager.update_status("codex", "authorizing") + + code_verifier, code_challenge = _generate_pkce_pair() + + # Generate state parameter to prevent CSRF attacks + oauth_state = secrets.token_urlsafe(32) + + # Start localhost callback server on fixed port 1455 (Codex standard) + server = HTTPServer( + ("127.0.0.1", CODEX_CALLBACK_PORT), _CallbackHandler + ) + server.auth_code = None + server.auth_error = None + server.expected_state = oauth_state + state.server = server + + redirect_uri = ( + f"http://localhost:{CODEX_CALLBACK_PORT}/auth/callback" + ) + + params = urlencode( + { + "response_type": "code", + "client_id": CODEX_CLIENT_ID, + "redirect_uri": redirect_uri, + "scope": "openid profile email offline_access", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": oauth_state, + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + } + ) + auth_url = f"{CODEX_AUTH_URL}?{params}" + + if state.is_cancelled(): + server.server_close() + return + + logger.info( + "Opening browser for Codex OAuth on port %d", + CODEX_CALLBACK_PORT, + ) + webbrowser.open(auth_url) + + # Wait for the callback (single request) + server.handle_request() + + if state.is_cancelled(): + server.server_close() + return + + if server.auth_error: + raise ValueError(server.auth_error) + + if not server.auth_code: + raise ValueError("No authorization code received") + + auth_code = server.auth_code + server.server_close() + + # Exchange code for token + token_resp = requests.post( + CODEX_TOKEN_URL, + data={ + "grant_type": "authorization_code", + "client_id": CODEX_CLIENT_ID, + "code": auth_code, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded" + }, + timeout=30, + ) + token_resp.raise_for_status() + token_data = token_resp.json() + + if state.is_cancelled(): + return + + cls.save_token(token_data) + + oauth_state_manager.update_status( + "codex", "success", result=token_data + ) + logger.info("Codex OAuth authorization successful") + + except Exception as e: + if state.is_cancelled(): + oauth_state_manager.update_status("codex", "cancelled") + else: + logger.error("Codex OAuth failed: %s", e) + oauth_state_manager.update_status( + "codex", "failed", error=str(e) + ) + finally: + state.server = None + + thread = threading.Thread( + target=_auth_flow, + daemon=True, + name=f"Codex-OAuth-{state.started_at.timestamp()}", + ) + state.thread = thread + thread.start() + + logger.info("Started background Codex OAuth authorization") + return "authorizing" diff --git a/backend/tests/app/utils/test_codex_oauth.py b/backend/tests/app/utils/test_codex_oauth.py new file mode 100644 index 000000000..004283c98 --- /dev/null +++ b/backend/tests/app/utils/test_codex_oauth.py @@ -0,0 +1,875 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +r"""Tests for the Codex OAuth manager.""" + +import io +import os +import stat +import tempfile +import time +from collections.abc import Generator +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + +from app.utils.codex_oauth import ( + CODEX_CLIENT_ID, + CODEX_TOKEN_DEFAULT_LIFETIME, + CODEX_TOKEN_REFRESH_THRESHOLD, + CodexOAuthManager, + _CallbackHandler, + _decrypt_token_data, + _derive_encryption_key, + _encrypt_token_data, + _generate_pkce_pair, + _get_machine_identifier, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def temp_token_path() -> Generator[str, None, None]: + r"""Create a temporary token path and patch CodexOAuthManager to use it.""" + with tempfile.TemporaryDirectory() as temp_dir: + token_path = os.path.join(temp_dir, "codex_token.enc") + with patch.object( + CodexOAuthManager, "_token_path", return_value=token_path + ): + yield token_path + + +@pytest.fixture +def clean_env() -> Generator[None, None, None]: + r"""Ensure OPENAI_API_KEY is cleaned up before and after test.""" + original = os.environ.pop("OPENAI_API_KEY", None) + yield + os.environ.pop("OPENAI_API_KEY", None) + if original is not None: + os.environ["OPENAI_API_KEY"] = original + + +@contextmanager +def mock_callback_request(path: str, expected_state: str | None = None): + r"""Create a mock HTTP request for _CallbackHandler testing.""" + handler = MagicMock(spec=_CallbackHandler) + handler.path = path + handler.wfile = io.BytesIO() + + # Track response + handler.response_code = None + handler.headers_sent = {} + + def send_response(code): + handler.response_code = code + + def send_header(name, value): + handler.headers_sent[name] = value + + def end_headers(): + pass + + handler.send_response = send_response + handler.send_header = send_header + handler.end_headers = end_headers + + # Create mock server + handler.server = MagicMock() + handler.server.auth_code = None + handler.server.auth_error = None + handler.server.expected_state = expected_state + + yield handler + + +# --------------------------------------------------------------------------- +# PKCE Generation Tests +# --------------------------------------------------------------------------- + + +class TestPKCEGeneration: + r"""Tests for PKCE code verifier and challenge generation.""" + + @pytest.mark.unit + def test_returns_tuple_of_strings(self): + """PKCE pair should be a tuple of two strings.""" + verifier, challenge = _generate_pkce_pair() + + assert isinstance(verifier, str) + assert isinstance(challenge, str) + + @pytest.mark.unit + def test_verifier_meets_minimum_length(self): + """Code verifier should be at least 43 characters (RFC 7636).""" + verifier, _ = _generate_pkce_pair() + + assert len(verifier) >= 43 + + @pytest.mark.unit + def test_challenge_is_base64url_without_padding(self): + """Code challenge should be valid base64url without padding.""" + _, challenge = _generate_pkce_pair() + + # Base64url should not contain +, /, or = + assert "+" not in challenge + assert "/" not in challenge + assert "=" not in challenge + + @pytest.mark.unit + def test_generates_unique_values(self): + """Each call should generate cryptographically unique values.""" + pairs = [_generate_pkce_pair() for _ in range(10)] + verifiers = [p[0] for p in pairs] + challenges = [p[1] for p in pairs] + + assert len(set(verifiers)) == 10 + assert len(set(challenges)) == 10 + + +# --------------------------------------------------------------------------- +# Machine Identifier Tests +# --------------------------------------------------------------------------- + + +class TestMachineIdentifier: + r"""Tests for machine identifier generation.""" + + @pytest.mark.unit + def test_returns_bytes(self): + """Machine identifier should be bytes.""" + identifier = _get_machine_identifier() + + assert isinstance(identifier, bytes) + + @pytest.mark.unit + def test_is_deterministic(self): + """Machine identifier should be consistent across calls.""" + assert _get_machine_identifier() == _get_machine_identifier() + + @pytest.mark.unit + def test_contains_multiple_components(self): + """Identifier should contain pipe-separated components.""" + identifier = _get_machine_identifier() + decoded = identifier.decode("utf-8") + components = decoded.split("|") + + # Should have at least username, hostname, platform.node, home dir + assert len(components) >= 4 + + +# --------------------------------------------------------------------------- +# Encryption Tests +# --------------------------------------------------------------------------- + + +class TestEncryption: + r"""Tests for token encryption and decryption.""" + + @pytest.mark.unit + def test_derive_key_returns_fernet_compatible_key(self): + """Derived key should be 44 bytes (Fernet format).""" + key = _derive_encryption_key() + + assert isinstance(key, bytes) + assert len(key) == 44 + + @pytest.mark.unit + def test_derive_key_is_deterministic(self): + """Derived key should be consistent for same machine.""" + assert _derive_encryption_key() == _derive_encryption_key() + + @pytest.mark.unit + def test_encrypt_decrypt_roundtrip(self): + """Data should survive encryption and decryption.""" + original = { + "access_token": "sk-test-token-123", + "refresh_token": "rt-refresh-456", + "expires_at": 1234567890, + "scope": "openai.api.read", + } + + encrypted = _encrypt_token_data(original) + decrypted = _decrypt_token_data(encrypted) + + assert decrypted == original + + @pytest.mark.unit + def test_encrypt_returns_bytes(self): + """Encrypted output should be bytes.""" + encrypted = _encrypt_token_data({"access_token": "test"}) + + assert isinstance(encrypted, bytes) + assert len(encrypted) > 0 + + @pytest.mark.unit + def test_encrypted_data_differs_from_input(self): + """Encrypted data should not contain plaintext.""" + token = "my_secret_token" + encrypted = _encrypt_token_data({"access_token": token}) + + assert token.encode() not in encrypted + + @pytest.mark.unit + def test_decrypt_invalid_data_returns_none(self): + """Decrypting garbage data should return None, not raise.""" + assert _decrypt_token_data(b"not-valid-fernet-data") is None + + @pytest.mark.unit + def test_decrypt_corrupted_data_returns_none(self): + """Decrypting tampered data should return None.""" + encrypted = _encrypt_token_data({"access_token": "test"}) + corrupted = encrypted[:-10] + b"tampered!!" + + assert _decrypt_token_data(corrupted) is None + + @pytest.mark.unit + def test_decrypt_empty_data_returns_none(self): + """Decrypting empty data should return None.""" + assert _decrypt_token_data(b"") is None + + +# --------------------------------------------------------------------------- +# Callback Handler Tests +# --------------------------------------------------------------------------- + + +class TestCallbackHandler: + r"""Tests for OAuth callback HTTP handler.""" + + @pytest.mark.unit + def test_captures_authorization_code(self): + """Handler should capture auth code from callback URL.""" + path = "/auth/callback?code=auth_code_123&state=valid_state" + with mock_callback_request( + path, expected_state="valid_state" + ) as handler: + _CallbackHandler.do_GET(handler) + + assert handler.server.auth_code == "auth_code_123" + assert handler.response_code == 200 + + @pytest.mark.unit + def test_captures_error_response(self): + """Handler should capture error from callback URL.""" + path = "/auth/callback?error=access_denied&error_description=User%20denied" + with mock_callback_request(path) as handler: + _CallbackHandler.do_GET(handler) + + assert handler.server.auth_error == "access_denied: User denied" + assert handler.response_code == 400 + + @pytest.mark.unit + def test_handles_missing_code(self): + """Handler should return 400 when code is missing.""" + with mock_callback_request("/auth/callback?state=xyz") as handler: + _CallbackHandler.do_GET(handler) + + assert handler.response_code == 400 + assert handler.server.auth_code is None + + @pytest.mark.unit + def test_escapes_html_in_error(self): + """Handler should escape HTML in error messages to prevent XSS.""" + path = "/auth/callback?error=