agent-zero/plugins/_oauth/helpers/codex.py
Alessandro f67564a8ae Add Codex/ChatGPT account OAuth provider
Create a generic OAuth Connections plugin with Codex/ChatGPT Account as the first provider, using OpenAI's device-code flow to persist Codex-compatible account tokens.

Expose a loopback OpenAI-compatible wrapper for models, responses, and chat completions, and point LiteLLM at the container-local Agent Zero origin.

Add a dummy API-key extension and focused tests so the account-backed provider appears configured without requiring a user-entered key.

docs: add Codex plan OAuth callout

Highlight that Agent Zero can use an existing OpenAI Codex plan through the new OAuth flow.

Add the account-backed LLM plans image and surface the section from the README navigation, while pointing toward future Gemini CLI and Claude Code integrations.

Handle Codex account SSE chat chunks

Teach the Codex/ChatGPT account bridge to extract text from OpenAI-style SSE chat completion deltas and fall back to a normal output_text response when upstream only streams chunks.

Strip user-supplied stream kwargs before LiteLLM calls so Agent Zero owns streaming mode and custom parameters cannot pass stream twice.

Add targeted tests for streamed delta extraction and reconstructed responses.

update README.md with LLM plans mention
2026-04-28 16:14:53 +02:00

859 lines
27 KiB
Python

from __future__ import annotations
import base64
import hashlib
import json
import os
import secrets
import subprocess
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Iterable
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
import requests
from helpers import files
from plugins._oauth.helpers.config import codex_config
# File name of the persisted auth payload inside each candidate directory.
AUTH_FILENAME = "auth.json"
# Refresh when the access token's `exp` claim falls within this window of now.
ACCESS_EXPIRY_MARGIN = timedelta(minutes=5)
# Refresh when the recorded last refresh is at least this old.
REFRESH_INTERVAL = timedelta(minutes=55)
# Client version reported when none is configured or detected from the CLI.
FALLBACK_CODEX_VERSION = "0.124.0"
# Top-level keys that may carry error text in an OAuth error payload.
OAUTH_ERROR_KEYS = {"error", "error_description"}
# Lifetime assumed for a device code when the server sends no usable expiry.
DEVICE_CODE_TIMEOUT_SECONDS = 15 * 60
@dataclass(frozen=True)
class PkcePair:
    """Immutable PKCE verifier together with the challenge derived from it."""

    # Random secret kept by the client and sent during the code exchange.
    verifier: str
    # Hash of the verifier that is placed in the authorize URL.
    challenge: str
@dataclass(frozen=True)
class EffectiveAuth:
    """Immutable snapshot of the credentials resolved by ``load_auth``."""

    # Bearer token sent upstream.
    access_token: str
    # Account identifier sent in the ``chatgpt-account-id`` header.
    account_id: str
    # Optional values; empty strings when the auth file does not carry them.
    id_token: str = ""
    refresh_token: str = ""
    # Path of the auth file the values were read from.
    source_path: str = ""
    # ISO timestamp of the last refresh recorded in the auth file.
    last_refresh: str = ""
def generate_pkce() -> PkcePair:
    """Create a fresh PKCE pair: a random verifier and its SHA-256 challenge."""
    code_verifier = _base64url(secrets.token_bytes(64))
    digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
    return PkcePair(verifier=code_verifier, challenge=_base64url(digest))
def generate_state() -> str:
    """Return a random, URL-safe ``state`` value built from 32 random bytes."""
    random_bytes = secrets.token_bytes(32)
    return _base64url(random_bytes)
def build_authorize_url(redirect_uri: str, state: str, pkce: PkcePair) -> str:
    """Build the browser URL that starts the authorization-code + PKCE flow.

    Args:
        redirect_uri: Callback URL the issuer should redirect to.
        state: Opaque value echoed back by the issuer.
        pkce: Pair whose challenge is embedded in the request.

    Returns:
        The issuer's ``/oauth/authorize`` URL with an encoded query string.
    """
    cfg = codex_config()
    params = dict(
        response_type="code",
        client_id=cfg["client_id"],
        redirect_uri=redirect_uri,
        scope=" ".join(cfg["scopes"]),
        code_challenge=pkce.challenge,
        code_challenge_method="S256",
        id_token_add_organizations="true",
        codex_cli_simplified_flow="true",
        state=state,
        originator="codex_cli_rs",
    )
    workspace_id = cfg["forced_workspace_id"]
    if workspace_id:
        # Only sent when the configuration pins logins to one workspace.
        params["allowed_workspace_id"] = workspace_id
    issuer = cfg["issuer"]
    return f"{issuer}/oauth/authorize?{urlencode(params)}"
def exchange_code_for_tokens(
    code: str,
    redirect_uri: str,
    verifier: str,
) -> dict[str, str]:
    """Trade an authorization code for ID, access and refresh tokens.

    Args:
        code: Authorization code returned by the issuer.
        redirect_uri: Redirect URI that was used when requesting the code.
        verifier: PKCE verifier matching the challenge sent earlier.

    Returns:
        Dict with non-empty ``id_token``, ``access_token`` and
        ``refresh_token`` strings.

    Raises:
        RuntimeError: On a non-2xx status, a non-object payload, or when any
            of the three tokens is absent or empty.
    """
    cfg = codex_config()
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": cfg["client_id"],
        "code_verifier": verifier,
    }
    response = requests.post(
        cfg["token_url"],
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=30,
    )
    if not response.ok:
        raise RuntimeError(_token_error_message(response))

    payload = response.json()
    if not isinstance(payload, dict):
        raise RuntimeError("OAuth token endpoint returned a malformed response.")

    # Falsy values are normalised to "" so they are reported as missing.
    tokens = {
        name: str(payload.get(name) or "")
        for name in ("id_token", "access_token", "refresh_token")
    }
    absent = [name for name, token in tokens.items() if not token]
    if absent:
        raise RuntimeError(f"OAuth token response is missing: {', '.join(absent)}")
    return tokens
def obtain_api_key(id_token: str) -> str:
    """Exchange an ID token for an API key through the token endpoint.

    Args:
        id_token: ID token used as the subject of the token exchange.

    Returns:
        The ``access_token`` field of the response, as a string.

    Raises:
        RuntimeError: On a non-2xx status or when the payload is not an
            object carrying a truthy ``access_token``.
    """
    cfg = codex_config()
    form = {
        "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
        "client_id": cfg["client_id"],
        "requested_token": "openai-api-key",
        "subject_token": id_token,
        "subject_token_type": "urn:ietf:params:oauth:token-type:id_token",
    }
    response = requests.post(
        cfg["token_url"],
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=30,
    )
    if not response.ok:
        raise RuntimeError(f"API-key token exchange failed with status {response.status_code}.")

    payload = response.json()
    usable = isinstance(payload, dict) and payload.get("access_token")
    if not usable:
        raise RuntimeError("API-key token exchange returned a malformed response.")
    return str(payload["access_token"])
def complete_login(code: str, redirect_uri: str, verifier: str) -> EffectiveAuth:
    """Exchange the authorization code and persist the resulting tokens."""
    return persist_exchanged_tokens(
        exchange_code_for_tokens(code, redirect_uri, verifier)
    )
def persist_exchanged_tokens(tokens: dict[str, str]) -> EffectiveAuth:
    """Validate freshly exchanged tokens and write them to the auth file.

    Args:
        tokens: Dict with ``id_token``, ``access_token`` and ``refresh_token``.

    Returns:
        The credentials re-read from disk without triggering a refresh.

    Raises:
        RuntimeError: If the ID token carries no account id, or the account
            does not match the configured forced workspace.
    """
    id_token = tokens["id_token"]
    account_id = derive_account_id(id_token)
    if not account_id:
        raise RuntimeError("OAuth ID token did not include a ChatGPT account id.")

    required_workspace = codex_config()["forced_workspace_id"]
    if required_workspace and account_id != required_workspace:
        raise RuntimeError(
            f"Login is restricted to workspace id {required_workspace}."
        )

    # The API key is optional: a failed exchange is stored as null.
    try:
        api_key = obtain_api_key(id_token)
    except Exception:
        api_key = ""

    stored_tokens = {
        "id_token": id_token,
        "access_token": tokens["access_token"],
        "refresh_token": tokens["refresh_token"],
        "account_id": account_id,
    }
    auth_data = {
        "auth_mode": "chatgpt",
        "OPENAI_API_KEY": api_key or None,
        "tokens": stored_tokens,
        "last_refresh": utc_now_iso(),
    }
    write_auth_file(resolve_auth_write_path(), auth_data)
    return load_auth(ensure_fresh=False)
def request_device_code() -> dict[str, Any]:
    """Start the device-code flow and return what the UI needs to show.

    Returns:
        Dict with ``device_auth_id``, ``user_code``, ``interval``,
        ``expires_at`` (epoch seconds) and ``verification_url``.

    Raises:
        RuntimeError: On a non-2xx status, a non-object payload, or when the
            payload lacks the device id or the user code.
    """
    cfg = codex_config()
    issuer_root = cfg["issuer"].rstrip("/")
    response = requests.post(
        f"{issuer_root}/api/accounts/deviceauth/usercode",
        headers={"Content-Type": "application/json"},
        json={"client_id": cfg["client_id"]},
        timeout=30,
    )
    if not response.ok:
        raise RuntimeError(_token_error_message(response))

    payload = response.json()
    if not isinstance(payload, dict):
        raise RuntimeError("Device authorization returned a malformed response.")

    device_auth_id = _string(payload.get("device_auth_id"))
    # The code is read from either spelling of the key.
    user_code = _string(payload.get("user_code") or payload.get("usercode"))
    if not (device_auth_id and user_code):
        raise RuntimeError("Device authorization response did not include a code.")

    poll_interval = _safe_int(payload.get("interval"), 5)
    expiry = _device_expires_at(payload.get("expires_at"))
    return {
        "device_auth_id": device_auth_id,
        "user_code": user_code,
        "interval": poll_interval,
        "expires_at": expiry,
        "verification_url": f"{issuer_root}/codex/device",
    }
def poll_device_authorization(device_auth_id: str, user_code: str) -> dict[str, Any]:
    """Poll the device-code flow once and finish the login when approved.

    Args:
        device_auth_id: Identifier returned by ``request_device_code``.
        user_code: User code returned by ``request_device_code``.

    Returns:
        ``{"completed": False}`` on HTTP 403/404, otherwise
        ``{"completed": True, "account_id": ...}`` after tokens are persisted.

    Raises:
        RuntimeError: On any other non-2xx status or a malformed payload.
    """
    cfg = codex_config()
    issuer_root = cfg["issuer"].rstrip("/")
    response = requests.post(
        f"{issuer_root}/api/accounts/deviceauth/token",
        headers={"Content-Type": "application/json"},
        json={"device_auth_id": device_auth_id, "user_code": user_code},
        timeout=30,
    )
    # 403 and 404 are reported as "not completed" rather than as errors.
    if response.status_code in {403, 404}:
        return {"completed": False}
    if not response.ok:
        raise RuntimeError(_token_error_message(response))

    payload = response.json()
    if not isinstance(payload, dict):
        raise RuntimeError("Device authorization token response was malformed.")

    authorization_code = _string(payload.get("authorization_code"))
    code_verifier = _string(payload.get("code_verifier"))
    if not (authorization_code and code_verifier):
        raise RuntimeError("Device authorization response was missing token exchange data.")

    callback_url = f"{issuer_root}/deviceauth/callback"
    exchanged = exchange_code_for_tokens(authorization_code, callback_url, code_verifier)
    persisted = persist_exchanged_tokens(exchanged)
    return {"completed": True, "account_id": persisted.account_id}
def load_auth(*, ensure_fresh: bool = True) -> EffectiveAuth:
    """Load credentials from the auth file, refreshing them when due.

    Args:
        ensure_fresh: When true and a refresh token is stored, refresh the
            tokens if ``should_refresh`` says so and write the result back.

    Returns:
        The effective credentials.

    Raises:
        RuntimeError: If no access token or no account id is available; a
            failed refresh propagates from ``refresh_tokens``.
    """
    path, data = read_auth_file()
    have_mapping = isinstance(data, dict)

    stored = data.get("tokens") if have_mapping else {}
    if not isinstance(stored, dict):
        stored = {}

    access_token = _string(stored.get("access_token"))
    id_token = _string(stored.get("id_token"))
    refresh_token = _string(stored.get("refresh_token"))
    # Prefer the stored account id; otherwise derive it from the ID token.
    account_id = _string(stored.get("account_id")) or derive_account_id(id_token)
    last_refresh = _string(data.get("last_refresh")) if have_mapping else ""

    if ensure_fresh and refresh_token and should_refresh(access_token, last_refresh):
        renewed = refresh_tokens(refresh_token)
        # Empty values in the refresh result keep the previously stored ones.
        access_token = renewed.get("access_token") or access_token
        id_token = renewed.get("id_token") or id_token
        refresh_token = renewed.get("refresh_token") or refresh_token
        account_id = derive_account_id(id_token) or account_id
        last_refresh = utc_now_iso()
        data["tokens"] = {
            "id_token": id_token,
            "access_token": access_token,
            "refresh_token": refresh_token,
            "account_id": account_id,
        }
        data["last_refresh"] = last_refresh
        write_auth_file(path, data)

    if not access_token:
        raise RuntimeError("Codex/ChatGPT account access token not found. Connect the account first.")
    if not account_id:
        raise RuntimeError("Codex/ChatGPT account id not found. Connect the account again.")

    return EffectiveAuth(
        access_token=access_token,
        account_id=account_id,
        id_token=id_token,
        refresh_token=refresh_token,
        source_path=str(path),
        last_refresh=last_refresh,
    )
def status() -> dict[str, Any]:
    """Describe the current connection state without refreshing tokens.

    Returns:
        Dict that always has ``connected``, ``auth_file_path`` and
        ``discovered_auth_files``. When loading fails it also has ``message``;
        when it succeeds it carries account details read from the token claims.
    """
    discovered = [
        str(candidate)
        for candidate in resolve_auth_file_candidates()
        if candidate.is_file()
    ]
    report: dict[str, Any] = {
        "connected": False,
        "auth_file_path": str(resolve_auth_write_path()),
        "discovered_auth_files": discovered,
    }

    try:
        auth = load_auth(ensure_fresh=False)
    except Exception as exc:
        # Any load failure is surfaced as a message, not raised.
        report["message"] = str(exc)
        return report

    id_claims = parse_jwt_claims(auth.id_token)
    access_claims = parse_jwt_claims(auth.access_token)
    auth_claims = _auth_claims(id_claims)

    report["connected"] = True
    report["auth_file_path"] = auth.source_path
    report["account_id"] = auth.account_id
    report["email"] = (
        id_claims.get("email")
        or _record(id_claims.get("https://api.openai.com/profile")).get("email")
    )
    report["plan_type"] = auth_claims.get("chatgpt_plan_type")
    report["user_id"] = auth_claims.get("chatgpt_user_id") or auth_claims.get("user_id")
    report["access_expires_at"] = _jwt_expiration_iso(access_claims)
    report["last_refresh"] = auth.last_refresh
    return report
def refresh_tokens(refresh_token: str) -> dict[str, str]:
    """Request new tokens with a refresh token.

    Args:
        refresh_token: Refresh token currently on file.

    Returns:
        Dict with ``id_token``, ``access_token`` and ``refresh_token``. Values
        that are not strings in the response become ``""``, except the refresh
        token, which falls back to the one passed in.

    Raises:
        RuntimeError: On a non-2xx status or a non-object payload.
    """
    cfg = codex_config()
    request_body = {
        "client_id": cfg["client_id"],
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
    }
    response = requests.post(
        cfg["token_url"],
        headers={"Content-Type": "application/json"},
        json=request_body,
        timeout=30,
    )
    if not response.ok:
        raise RuntimeError(_token_error_message(response))

    payload = response.json()
    if not isinstance(payload, dict):
        raise RuntimeError("OAuth refresh endpoint returned a malformed response.")

    new_id_token = _string(payload.get("id_token"))
    new_access_token = _string(payload.get("access_token"))
    new_refresh_token = _string(payload.get("refresh_token")) or refresh_token
    return {
        "id_token": new_id_token,
        "access_token": new_access_token,
        "refresh_token": new_refresh_token,
    }
def should_refresh(access_token: str, last_refresh: str) -> bool:
    """Decide whether the stored tokens are due for a refresh.

    Returns true when there is no access token, when its ``exp`` claim falls
    within ``ACCESS_EXPIRY_MARGIN`` of now, or when ``last_refresh`` parses and
    is at least ``REFRESH_INTERVAL`` old. Otherwise returns false.
    """
    if not access_token:
        return True

    expiry_claim = parse_jwt_claims(access_token).get("exp")
    if isinstance(expiry_claim, (int, float)):
        expiry = datetime.fromtimestamp(float(expiry_claim), tz=timezone.utc)
        if expiry <= datetime.now(timezone.utc) + ACCESS_EXPIRY_MARGIN:
            return True

    previous_refresh = parse_iso(last_refresh)
    if previous_refresh is None:
        return False
    return previous_refresh <= datetime.now(timezone.utc) - REFRESH_INTERVAL
def request_codex(
    path: str,
    *,
    method: str = "GET",
    headers: dict[str, str] | None = None,
    body: bytes | str | None = None,
    stream: bool = False,
    params: dict[str, str] | None = None,
) -> requests.Response:
    """Send an authenticated request to the configured upstream base URL.

    Args:
        path: Request path (or absolute URL) mapped by ``build_upstream_url``.
        method: HTTP method.
        headers: Caller headers; filtered by ``sanitize_forward_headers``.
        body: Raw request body.
        stream: Passed through to ``requests`` to enable streamed reading.
        params: Query parameters.

    Returns:
        The raw ``requests.Response``; the status is not checked here.
    """
    cfg = codex_config()
    auth = load_auth()
    url = build_upstream_url(path, cfg["upstream_base_url"])

    outgoing = sanitize_forward_headers(headers or {})
    # Credentials are always set here, after caller-supplied ones are filtered.
    outgoing["Authorization"] = f"Bearer {auth.access_token}"
    outgoing["chatgpt-account-id"] = auth.account_id
    outgoing["OpenAI-Beta"] = "responses=experimental"

    # The configured timeout is never allowed below five seconds.
    timeout_seconds = max(5, cfg["request_timeout_seconds"])
    return requests.request(
        method,
        url,
        headers=outgoing,
        data=body,
        params=params,
        timeout=timeout_seconds,
        stream=stream,
    )
def fetch_models() -> list[str]:
    """Return the model slugs to offer.

    A non-empty configured list is returned as is. Otherwise ``/models`` is
    queried upstream and the unique string ``slug`` values are returned in
    the order they appear.

    Raises:
        RuntimeError: On an upstream error, a payload without a ``models``
            list, or when no usable slug is found.
    """
    preset = codex_config()["models"]
    if preset:
        return preset

    response = request_codex(
        "/models",
        params={"client_version": resolve_codex_version()},
    )
    if not response.ok:
        raise RuntimeError(upstream_error_message(response, "Failed to load Codex models."))

    payload = response.json()
    entries = payload.get("models") if isinstance(payload, dict) else None
    if not isinstance(entries, list):
        raise RuntimeError("Codex returned a malformed models response.")

    slugs = (entry.get("slug") for entry in entries if isinstance(entry, dict))
    # dict.fromkeys drops duplicates while keeping first-seen order.
    unique = list(dict.fromkeys(slug for slug in slugs if isinstance(slug, str) and slug))
    if not unique:
        raise RuntimeError("Codex returned an empty models list.")
    return unique
def prepare_responses_body(body: dict[str, Any], *, force_stream: bool) -> dict[str, Any]:
    """Return a shallow copy of *body* adjusted for the upstream endpoint.

    ``instructions`` and ``store`` get defaults when absent, ``stream`` is set
    to true when *force_stream* is set, and ``max_output_tokens`` is removed.
    The input dict is not modified.
    """
    prepared = dict(body)
    if "instructions" not in prepared:
        prepared["instructions"] = ""
    if "store" not in prepared:
        prepared["store"] = False
    if force_stream:
        prepared["stream"] = True
    if "max_output_tokens" in prepared:
        del prepared["max_output_tokens"]
    return prepared
def collect_completed_response(response: requests.Response) -> dict[str, Any]:
    """Consume an SSE stream and return a single response object.

    The last ``response`` object seen in any event wins. If none was seen but
    text deltas were collected, a minimal ``{"output_text": ...}`` dict is
    built, with ``usage`` added when a non-empty usage dict was seen.

    Raises:
        RuntimeError: When the stream yields neither a response object nor
            text; the last ``error`` event payload is appended if present.
    """
    final_response: dict[str, Any] | None = None
    last_error: Any = None
    fragments: list[str] = []
    usage_seen: dict[str, Any] | None = None

    for sse in iter_sse_events(response):
        raw = sse.get("data")
        if not raw:
            continue
        try:
            message = json.loads(raw)
        except json.JSONDecodeError:
            # Payloads that are not JSON are skipped.
            continue
        if not isinstance(message, dict):
            continue
        if sse.get("event") == "error":
            last_error = message
            continue

        fragments.extend(extract_sse_text_deltas(message, sse.get("event", "")))
        reported_usage = message.get("usage")
        if isinstance(reported_usage, dict):
            usage_seen = reported_usage
        embedded = message.get("response")
        if isinstance(embedded, dict):
            final_response = embedded

    if final_response is not None:
        return final_response

    if fragments:
        rebuilt: dict[str, Any] = {"output_text": "".join(fragments)}
        if usage_seen:
            rebuilt["usage"] = usage_seen
        return rebuilt

    suffix = f" Last error: {json.dumps(last_error)}" if last_error else ""
    raise RuntimeError(f"No completed response found in Codex SSE stream.{suffix}")
def iter_sse_events(response: requests.Response) -> Iterable[dict[str, str]]:
    """Yield parsed server-sent events from a streamed response.

    The body is read as raw bytes and decoded incrementally as UTF-8, so the
    result does not depend on the encoding ``requests`` guesses from the
    headers, and multi-byte characters split across chunks decode correctly.
    Chunks that already arrive as ``str`` are used as they are.

    Args:
        response: Streamed response (or any object with a compatible
            ``iter_content`` method).

    Yields:
        One dict per non-empty event, as produced by ``parse_sse_block``.
    """
    import codecs

    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    buffer = ""
    for chunk in response.iter_content(chunk_size=8192, decode_unicode=False):
        if not chunk:
            continue
        if isinstance(chunk, (bytes, bytearray)):
            buffer += decoder.decode(bytes(chunk))
        else:
            buffer += chunk
        # Normalise CRLF so one separator works even when a stream mixes line
        # endings or a CRLF pair is split across two chunks.
        buffer = buffer.replace("\r\n", "\n")
        while "\n\n" in buffer:
            block, buffer = buffer.split("\n\n", 1)
            event = parse_sse_block(block)
            if event:
                yield event
    # Flush bytes still held by the decoder, then any unterminated event.
    buffer += decoder.decode(b"", final=True)
    event = parse_sse_block(buffer)
    if event:
        yield event


def parse_sse_block(block: str) -> dict[str, str]:
    """Parse one SSE block into a dict with optional ``event`` and ``data``.

    Only ``event:`` and ``data:`` lines are read; anything else is ignored.
    Multiple ``data:`` lines are joined with newlines. An empty dict means the
    block carried neither field.
    """
    event: dict[str, str] = {}
    data_lines: list[str] = []
    for line in block.splitlines():
        if line.startswith("event:"):
            event["event"] = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].lstrip())
    if data_lines:
        event["data"] = "\n".join(data_lines)
    return event
def extract_sse_text_deltas(payload: dict[str, Any], event_type: str = "") -> list[str]:
    """Collect the text fragments carried by one decoded SSE payload.

    Text is read from ``choices[].delta`` / ``choices[].message``, from a
    top-level ``delta``, and from a top-level ``text`` when the payload type
    (or, failing that, *event_type*) names a text-delta event.

    Returns:
        The non-empty fragments in the order they were found.
    """
    collected: list[str] = []

    choice_list = payload.get("choices")
    for entry in choice_list if isinstance(choice_list, list) else ():
        if not isinstance(entry, dict):
            continue
        choice_delta = entry.get("delta")
        if isinstance(choice_delta, str):
            collected.append(choice_delta)
        elif isinstance(choice_delta, dict):
            _append_text_value(collected, choice_delta.get("content"))
        choice_message = entry.get("message")
        if isinstance(choice_message, dict):
            _append_text_value(collected, choice_message.get("content"))

    top_delta = payload.get("delta")
    if isinstance(top_delta, dict):
        _append_text_value(collected, top_delta.get("content"))
        _append_text_value(collected, top_delta.get("text"))
    elif isinstance(top_delta, str):
        collected.append(top_delta)

    # The payload's own type takes precedence over the SSE event name.
    kind = payload.get("type") or event_type
    if kind in {"response.output_text.delta", "response.text.delta"}:
        _append_text_value(collected, payload.get("text"))

    return [fragment for fragment in collected if fragment]
def _append_text_value(pieces: list[str], value: Any) -> None:
if isinstance(value, str):
pieces.append(value)
return
if isinstance(value, list):
for item in value:
if isinstance(item, str):
pieces.append(item)
elif isinstance(item, dict):
_append_text_value(pieces, item.get("text"))
_append_text_value(pieces, item.get("content"))
def chat_messages_to_response_body(body: dict[str, Any]) -> dict[str, Any]:
    """Convert a chat-completions request body into a Responses request body.

    ``system`` and ``developer`` messages are merged into ``instructions``;
    every other message becomes an ``input`` item with flattened text.
    ``temperature``, ``top_p`` and ``reasoning_effort`` are carried over when
    they are not ``None``.

    Raises:
        RuntimeError: If ``messages`` is not a list or ``tools`` is truthy.
    """
    messages = body.get("messages")
    if not isinstance(messages, list):
        raise RuntimeError("`messages` must be an array.")
    if body.get("tools"):
        raise RuntimeError("Codex/ChatGPT account wrapper does not yet support tool calls.")

    instruction_parts: list[str] = []
    input_items: list[dict[str, Any]] = []
    for entry in messages:
        if not isinstance(entry, dict):
            continue
        role = str(entry.get("role") or "user")
        text = normalize_message_content(entry.get("content", ""))
        if role not in {"system", "developer"}:
            input_items.append({"role": role, "content": text})
        elif text:
            # Empty system/developer messages are dropped.
            instruction_parts.append(text)

    converted: dict[str, Any] = {
        "model": body.get("model") or "gpt-5.2",
        "input": input_items,
        "instructions": "\n\n".join(instruction_parts),
        "store": False,
    }
    for option in ("temperature", "top_p"):
        if body.get(option) is not None:
            converted[option] = body[option]
    if body.get("reasoning_effort") is not None:
        converted["reasoning"] = {"effort": body["reasoning_effort"]}
    return converted
def normalize_message_content(content: Any) -> str:
    """Flatten chat message content into plain text.

    Strings pass through, ``None`` becomes ``""``, and a list is reduced to
    its string items plus the string ``text`` entries of its dict items,
    joined with newlines. Any other value is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    lines: list[str] = []
    for part in content:
        if isinstance(part, str):
            lines.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            lines.append(part["text"])
    return "\n".join(lines)
def response_text(response: dict[str, Any]) -> str:
    """Return the text of a response object.

    A string ``output_text`` is returned directly. Otherwise the string
    ``text`` entries of every content block under ``output`` are concatenated.
    """
    direct = response.get("output_text")
    if isinstance(direct, str):
        return direct

    items = response.get("output")
    if not isinstance(items, list):
        return ""

    fragments: list[str] = []
    for item in items:
        blocks = item.get("content") if isinstance(item, dict) else None
        if not isinstance(blocks, list):
            continue
        fragments.extend(
            block["text"]
            for block in blocks
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        )
    return "".join(fragments)
def build_upstream_url(path: str, base_url: str) -> str:
    """Map a request path onto the upstream base URL.

    For an absolute ``http``/``https`` URL only the path and query are kept.
    A leading ``/v1`` segment is removed before the remainder is joined onto
    *base_url*.
    """
    if path.startswith(("http://", "https://")):
        # Drop scheme and host so the request always targets base_url.
        pieces = urlparse(path)
        path = f"{pieces.path}?{pieces.query}" if pieces.query else pieces.path

    if path == "/v1":
        path = "/"
    elif path.startswith("/v1/"):
        path = path.removeprefix("/v1")

    root = base_url.rstrip("/") + "/"
    return urljoin(root, path.lstrip("/"))
def sanitize_forward_headers(headers: dict[str, str]) -> dict[str, str]:
    """Return the caller headers that may be forwarded.

    Blocked header names (matched case-insensitively) and entries whose value
    is ``None`` are dropped.
    """
    blocked = {
        "authorization",
        "chatgpt-account-id",
        "host",
        "openai-beta",
        "content-length",
        "connection",
    }
    forwarded: dict[str, str] = {}
    for name, value in headers.items():
        if name.lower() in blocked or value is None:
            continue
        forwarded[name] = value
    return forwarded
def response_headers(response: requests.Response) -> dict[str, str]:
    """Return the response headers minus a fixed set of blocked names.

    ``connection``, ``content-encoding``, ``content-length`` and
    ``transfer-encoding`` are removed, matched case-insensitively.
    """
    blocked = {
        "connection",
        "content-encoding",
        "content-length",
        "transfer-encoding",
    }
    relayed: dict[str, str] = {}
    for name, value in response.headers.items():
        if name.lower() not in blocked:
            relayed[name] = value
    return relayed
def upstream_error_message(response: requests.Response, fallback: str) -> str:
    """Return the most specific error text found in an upstream response.

    Lookup order for a JSON object body: string ``detail``, then
    ``error.message``, then a string ``error``. An empty body yields
    *fallback*; any other body is returned verbatim.
    """
    raw = response.text
    if not raw:
        return fallback
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    if not isinstance(decoded, dict):
        return raw

    detail = decoded.get("detail")
    if isinstance(detail, str):
        return detail

    problem = decoded.get("error")
    if isinstance(problem, str):
        return problem
    if isinstance(problem, dict):
        message = problem.get("message")
        if isinstance(message, str):
            return message
    return raw
def resolve_codex_version() -> str:
    """Return the Codex client version to report upstream.

    Order: the configured version, then the version printed by
    ``codex --version``, then ``FALLBACK_CODEX_VERSION``.
    """
    pinned = codex_config()["codex_version"]
    if pinned:
        return pinned

    try:
        completed = subprocess.run(
            ["codex", "--version"],
            check=False,
            capture_output=True,
            text=True,
            timeout=2,
        )
        detected = _extract_semver(completed.stdout) or _extract_semver(completed.stderr)
    except Exception:
        # Any failure to run the CLI just means the fallback is used.
        detected = ""
    return detected or FALLBACK_CODEX_VERSION
def resolve_auth_file_candidates() -> list[Path]:
    """List the auth-file locations to consider, highest priority first.

    A configured ``auth_file_path`` is the only candidate when set. Otherwise
    the list holds the ``CHATGPT_LOCAL_HOME`` and ``CODEX_HOME`` directories
    (when set), the two home-directory defaults, and the plugin's own storage
    path, with duplicates removed.
    """
    explicit = codex_config()["auth_file_path"]
    if explicit:
        return [Path(explicit).expanduser()]

    from_env = [
        Path(location).expanduser() / AUTH_FILENAME
        for location in map(os.getenv, ("CHATGPT_LOCAL_HOME", "CODEX_HOME"))
        if location
    ]
    home = Path.home()
    defaults = [
        home / ".codex" / AUTH_FILENAME,
        home / ".chatgpt-local" / AUTH_FILENAME,
        Path(files.get_abs_path("usr", "plugins", "_oauth", "codex", AUTH_FILENAME)),
    ]
    return _unique_paths(from_env + defaults)
def resolve_auth_write_path() -> Path:
    """Pick the auth file to write: the first existing candidate, else the last one."""
    existing = next(
        (candidate for candidate in resolve_auth_file_candidates() if candidate.is_file()),
        None,
    )
    if existing is not None:
        return existing
    return resolve_auth_file_candidates()[-1]
def read_auth_file() -> tuple[Path, dict[str, Any]]:
    """Read the first candidate auth file that holds a JSON object.

    Returns:
        ``(path, payload)`` for the first readable candidate, or the write
        path with an empty dict when none qualifies. Missing, unreadable and
        malformed files are skipped without raising.
    """
    for location in resolve_auth_file_candidates():
        try:
            with location.open("r", encoding="utf-8") as stream:
                loaded = json.load(stream)
        except Exception:
            # Covers missing files as well as read and decode failures.
            continue
        if isinstance(loaded, dict):
            return location, loaded
    return resolve_auth_write_path(), {}
def write_auth_file(path: Path, data: dict[str, Any]) -> None:
    """Write *data* as indented JSON to *path* with owner-only permissions.

    The file is created with mode ``0o600`` instead of being created with the
    default mode and restricted afterwards, so a new token file is never
    readable by other users. An existing file is restricted before the new
    content is written. Parent directories are created as needed.

    Args:
        path: Destination file.
        data: JSON-serialisable auth payload.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialise first so a serialisation error cannot truncate the file.
    serialized = json.dumps(data, indent=2) + "\n"

    # O_BINARY (Windows only) leaves newline handling to the text wrapper.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(path, flags, 0o600)
    try:
        handle = os.fdopen(fd, "w", encoding="utf-8")
    except Exception:
        os.close(fd)
        raise
    with handle:
        # The creation mode does not apply to a file that already existed, so
        # tighten it before any secret is written. Best effort: some
        # filesystems do not support chmod.
        try:
            path.chmod(0o600)
        except OSError:
            pass
        handle.write(serialized)
def parse_jwt_claims(token: str) -> dict[str, Any]:
    """Decode the payload segment of a JWT without verifying its signature.

    Returns:
        The claims dict, or ``{}`` when the token is empty, does not have
        three dot-separated segments, or its payload is not a JSON object.
    """
    if not token:
        return {}
    segments = token.split(".")
    if len(segments) != 3:
        return {}
    try:
        body = segments[1]
        # Restore the base64 padding that JWTs strip.
        padded = body + "=" * (-len(body) % 4)
        claims = json.loads(base64.urlsafe_b64decode(padded.encode("ascii")))
    except Exception:
        return {}
    if isinstance(claims, dict):
        return claims
    return {}
def derive_account_id(id_token: str) -> str:
    """Return the ``chatgpt_account_id`` auth claim of an ID token, or ``""``."""
    auth_claims = _auth_claims(parse_jwt_claims(id_token))
    return _string(auth_claims.get("chatgpt_account_id"))
def parse_iso(value: str) -> datetime | None:
if not value:
return None
normalized = value.replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(normalized)
except ValueError:
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string ending in ``Z``."""
    stamp = datetime.now(timezone.utc).isoformat()
    return stamp.replace("+00:00", "Z")
def _auth_claims(claims: dict[str, Any]) -> dict[str, Any]:
    """Return the ``https://api.openai.com/auth`` claim as a dict (``{}`` if absent)."""
    namespaced = claims.get("https://api.openai.com/auth")
    return _record(namespaced)
def _record(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
def _string(value: Any) -> str:
return value if isinstance(value, str) else ""
def _jwt_expiration_iso(claims: dict[str, Any]) -> str:
exp = claims.get("exp")
if not isinstance(exp, (int, float)):
return ""
return datetime.fromtimestamp(float(exp), tz=timezone.utc).isoformat()
def _base64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=")
def _token_error_message(response: requests.Response) -> str:
try:
payload = response.json()
except Exception:
payload = None
if isinstance(payload, dict):
for key in OAUTH_ERROR_KEYS:
value = payload.get(key)
if isinstance(value, str) and value:
return value
error = payload.get("error")
if isinstance(error, dict) and isinstance(error.get("message"), str):
return error["message"]
if isinstance(error, str):
return error
return f"OAuth token endpoint returned status {response.status_code}: {response.text}"
def _extract_semver(value: str) -> str:
import re
match = re.search(r"\b\d+\.\d+\.\d+\b", value or "")
return match.group(0) if match else ""
def _safe_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def _device_expires_at(value: Any) -> float:
    """Return the device-code expiry as epoch seconds.

    A parseable ISO string is converted directly; anything else yields now
    plus ``DEVICE_CODE_TIMEOUT_SECONDS``.
    """
    moment = parse_iso(value) if isinstance(value, str) else None
    if moment is None:
        return time.time() + DEVICE_CODE_TIMEOUT_SECONDS
    return moment.timestamp()
def _unique_paths(paths: list[Path]) -> list[Path]:
result: list[Path] = []
seen: set[str] = set()
for path in paths:
key = str(path)
if key in seen:
continue
seen.add(key)
result.append(path)
return result