diff --git a/packages/opencode/src/mcp/oauth-callback.ts b/packages/opencode/src/mcp/oauth-callback.ts index 3e6169517f..fbb43d3921 100644 --- a/packages/opencode/src/mcp/oauth-callback.ts +++ b/packages/opencode/src/mcp/oauth-callback.ts @@ -56,177 +56,177 @@ interface PendingAuth { timeout: ReturnType } -export namespace McpOAuthCallback { - let server: ReturnType | undefined - const pendingAuths = new Map() - // Reverse index: mcpName → oauthState, so cancelPending(mcpName) can - // find the right entry in pendingAuths (which is keyed by oauthState). - const mcpNameToState = new Map() +let server: ReturnType | undefined +const pendingAuths = new Map() +// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can +// find the right entry in pendingAuths (which is keyed by oauthState). +const mcpNameToState = new Map() - const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes +const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes - function cleanupStateIndex(oauthState: string) { - for (const [name, state] of mcpNameToState) { - if (state === oauthState) { - mcpNameToState.delete(name) - break - } +function cleanupStateIndex(oauthState: string) { + for (const [name, state] of mcpNameToState) { + if (state === oauthState) { + mcpNameToState.delete(name) + break } } - - function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) { - const url = new URL(req.url || "/", `http://localhost:${currentPort}`) - - if (url.pathname !== currentPath) { - res.writeHead(404) - res.end("Not found") - return - } - - const code = url.searchParams.get("code") - const state = url.searchParams.get("state") - const error = url.searchParams.get("error") - const errorDescription = url.searchParams.get("error_description") - - log.info("received oauth callback", { hasCode: !!code, state, error }) - - // Enforce state parameter presence - if (!state) { - const errorMsg = "Missing required state parameter - potential CSRF attack" - log.error("oauth callback missing state parameter", { url: url.toString() }) - res.writeHead(400, { "Content-Type": "text/html" }) - res.end(HTML_ERROR(errorMsg)) - return - } - - if (error) { - const errorMsg = errorDescription || error - if (pendingAuths.has(state)) { - const pending = pendingAuths.get(state)! - clearTimeout(pending.timeout) - pendingAuths.delete(state) - cleanupStateIndex(state) - pending.reject(new Error(errorMsg)) - } - res.writeHead(200, { "Content-Type": "text/html" }) - res.end(HTML_ERROR(errorMsg)) - return - } - - if (!code) { - res.writeHead(400, { "Content-Type": "text/html" }) - res.end(HTML_ERROR("No authorization code provided")) - return - } - - // Validate state parameter - if (!pendingAuths.has(state)) { - const errorMsg = "Invalid or expired state parameter - potential CSRF attack" - log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) - res.writeHead(400, { "Content-Type": "text/html" }) - res.end(HTML_ERROR(errorMsg)) - return - } - - const pending = pendingAuths.get(state)! - - clearTimeout(pending.timeout) - pendingAuths.delete(state) - cleanupStateIndex(state) - pending.resolve(code) - - res.writeHead(200, { "Content-Type": "text/html" }) - res.end(HTML_SUCCESS) - } - - export async function ensureRunning(redirectUri?: string): Promise { - // Parse the redirect URI to get port and path (uses defaults if not provided) - const { port, path } = parseRedirectUri(redirectUri) - - // If server is running on a different port/path, stop it first - if (server && (currentPort !== port || currentPath !== path)) { - log.info("stopping oauth callback server to reconfigure", { oldPort: currentPort, newPort: port }) - await stop() - } - - if (server) return - - const running = await isPortInUse(port) - if (running) { - log.info("oauth callback server already running on another instance", { port }) - return - } - - currentPort = port - currentPath = path - - server = createServer(handleRequest) - await new Promise((resolve, reject) => { - server!.listen(currentPort, () => { - log.info("oauth callback server started", { port: currentPort, path: currentPath }) - resolve() - }) - server!.on("error", reject) - }) - } - - export function waitForCallback(oauthState: string, mcpName?: string): Promise { - if (mcpName) mcpNameToState.set(mcpName, oauthState) - return new Promise((resolve, reject) => { - const timeout = setTimeout(() => { - if (pendingAuths.has(oauthState)) { - pendingAuths.delete(oauthState) - if (mcpName) mcpNameToState.delete(mcpName) - reject(new Error("OAuth callback timeout - authorization took too long")) - } - }, CALLBACK_TIMEOUT_MS) - - pendingAuths.set(oauthState, { resolve, reject, timeout }) - }) - } - - export function cancelPending(mcpName: string): void { - // Look up the oauthState for this mcpName via the reverse index - const oauthState = mcpNameToState.get(mcpName) - const key = oauthState ?? mcpName - const pending = pendingAuths.get(key) - if (pending) { - clearTimeout(pending.timeout) - pendingAuths.delete(key) - mcpNameToState.delete(mcpName) - pending.reject(new Error("Authorization cancelled")) - } - } - - export async function isPortInUse(port: number = OAUTH_CALLBACK_PORT): Promise { - return new Promise((resolve) => { - const socket = createConnection(port, "127.0.0.1") - socket.on("connect", () => { - socket.destroy() - resolve(true) - }) - socket.on("error", () => { - resolve(false) - }) - }) - } - - export async function stop(): Promise { - if (server) { - await new Promise((resolve) => server!.close(() => resolve())) - server = undefined - log.info("oauth callback server stopped") - } - - for (const [_name, pending] of pendingAuths) { - clearTimeout(pending.timeout) - pending.reject(new Error("OAuth callback server stopped")) - } - pendingAuths.clear() - mcpNameToState.clear() - } - - export function isRunning(): boolean { - return server !== undefined - } } + +function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) { + const url = new URL(req.url || "/", `http://localhost:${currentPort}`) + + if (url.pathname !== currentPath) { + res.writeHead(404) + res.end("Not found") + return + } + + const code = url.searchParams.get("code") + const state = url.searchParams.get("state") + const error = url.searchParams.get("error") + const errorDescription = url.searchParams.get("error_description") + + log.info("received oauth callback", { hasCode: !!code, state, error }) + + // Enforce state parameter presence + if (!state) { + const errorMsg = "Missing required state parameter - potential CSRF attack" + log.error("oauth callback missing state parameter", { url: url.toString() }) + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(HTML_ERROR(errorMsg)) + return + } + + if (error) { + const errorMsg = errorDescription || error + if (pendingAuths.has(state)) { + const pending = pendingAuths.get(state)! + clearTimeout(pending.timeout) + pendingAuths.delete(state) + cleanupStateIndex(state) + pending.reject(new Error(errorMsg)) + } + res.writeHead(200, { "Content-Type": "text/html" }) + res.end(HTML_ERROR(errorMsg)) + return + } + + if (!code) { + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(HTML_ERROR("No authorization code provided")) + return + } + + // Validate state parameter + if (!pendingAuths.has(state)) { + const errorMsg = "Invalid or expired state parameter - potential CSRF attack" + log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) }) + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(HTML_ERROR(errorMsg)) + return + } + + const pending = pendingAuths.get(state)! + + clearTimeout(pending.timeout) + pendingAuths.delete(state) + cleanupStateIndex(state) + pending.resolve(code) + + res.writeHead(200, { "Content-Type": "text/html" }) + res.end(HTML_SUCCESS) +} + +export async function ensureRunning(redirectUri?: string): Promise { + // Parse the redirect URI to get port and path (uses defaults if not provided) + const { port, path } = parseRedirectUri(redirectUri) + + // If server is running on a different port/path, stop it first + if (server && (currentPort !== port || currentPath !== path)) { + log.info("stopping oauth callback server to reconfigure", { oldPort: currentPort, newPort: port }) + await stop() + } + + if (server) return + + const running = await isPortInUse(port) + if (running) { + log.info("oauth callback server already running on another instance", { port }) + return + } + + currentPort = port + currentPath = path + + server = createServer(handleRequest) + await new Promise((resolve, reject) => { + server!.listen(currentPort, () => { + log.info("oauth callback server started", { port: currentPort, path: currentPath }) + resolve() + }) + server!.on("error", reject) + }) +} + +export function waitForCallback(oauthState: string, mcpName?: string): Promise { + if (mcpName) mcpNameToState.set(mcpName, oauthState) + return new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + if (pendingAuths.has(oauthState)) { + pendingAuths.delete(oauthState) + if (mcpName) mcpNameToState.delete(mcpName) + reject(new Error("OAuth callback timeout - authorization took too long")) + } + }, CALLBACK_TIMEOUT_MS) + + pendingAuths.set(oauthState, { resolve, reject, timeout }) + }) +} + +export function cancelPending(mcpName: string): void { + // Look up the oauthState for this mcpName via the reverse index + const oauthState = mcpNameToState.get(mcpName) + const key = oauthState ?? mcpName + const pending = pendingAuths.get(key) + if (pending) { + clearTimeout(pending.timeout) + pendingAuths.delete(key) + mcpNameToState.delete(mcpName) + pending.reject(new Error("Authorization cancelled")) + } +} + +export async function isPortInUse(port: number = OAUTH_CALLBACK_PORT): Promise { + return new Promise((resolve) => { + const socket = createConnection(port, "127.0.0.1") + socket.on("connect", () => { + socket.destroy() + resolve(true) + }) + socket.on("error", () => { + resolve(false) + }) + }) +} + +export async function stop(): Promise { + if (server) { + await new Promise((resolve) => server!.close(() => resolve())) + server = undefined + log.info("oauth callback server stopped") + } + + for (const [_name, pending] of pendingAuths) { + clearTimeout(pending.timeout) + pending.reject(new Error("OAuth callback server stopped")) + } + pendingAuths.clear() + mcpNameToState.clear() +} + +export function isRunning(): boolean { + return server !== undefined +} + +export * as McpOAuthCallback from "./oauth-callback"