diff --git a/extensions/xai/xai-oauth.test.ts b/extensions/xai/xai-oauth.test.ts index 22159b6db8f..65aeb7d3a36 100644 --- a/extensions/xai/xai-oauth.test.ts +++ b/extensions/xai/xai-oauth.test.ts @@ -1,3 +1,6 @@ +import type { ProviderAuthContext } from "openclaw/plugin-sdk/plugin-entry"; +import type { OAuthCredential } from "openclaw/plugin-sdk/provider-auth"; +import { createRuntimeEnv, createTestWizardPrompter } from "openclaw/plugin-sdk/testing"; import { afterEach, describe, expect, it, vi } from "vitest"; import { buildXaiOAuthAuthorizationCodeTokenBody, @@ -97,34 +100,29 @@ describe("xAI OAuth", () => { it("validates discovered endpoints before using them", async () => { vi.stubEnv("OPENCLAW_VERSION", "2026.3.22"); - const fetchImpl = vi.fn(async () => + const fetchImpl = vi.fn(async () => jsonResponse({ authorization_endpoint: "https://auth.x.ai/oauth2/authorize", - device_authorization_endpoint: "https://auth.x.ai/oauth2/device/code", token_endpoint: "https://auth.x.ai/oauth2/token", }), - ) as unknown as typeof fetch; + ); await expect(fetchXaiOAuthDiscovery({ fetchImpl })).resolves.toEqual({ authorizationEndpoint: "https://auth.x.ai/oauth2/authorize", - deviceAuthorizationEndpoint: "https://auth.x.ai/oauth2/device/code", tokenEndpoint: "https://auth.x.ai/oauth2/token", }); - const discoveryInit = (fetchImpl as unknown as ReturnType).mock.calls.at( - 0, - )?.[1] as RequestInit | undefined; + const discoveryInit = fetchImpl.mock.calls.at(0)?.[1]; const discoveryHeaders = new Headers(discoveryInit?.headers ?? {}); expect(discoveryHeaders.get("user-agent")).toBe("openclaw/2026.3.22"); vi.unstubAllEnvs(); - const poisonedFetch = vi.fn(async () => + const poisonedFetch = vi.fn(async () => jsonResponse({ authorization_endpoint: "https://auth.x.ai/oauth2/authorize", - device_authorization_endpoint: "https://auth.x.ai/oauth2/device/code", token_endpoint: "https://evil.test/oauth2/token", }), - ) as unknown as typeof fetch; + ); await expect(fetchXaiOAuthDiscovery({ fetchImpl: poisonedFetch })).rejects.toThrow( "untrusted token endpoint", @@ -133,10 +131,10 @@ describe("xAI OAuth", () => { it("refreshes with the cached token endpoint and preserves refresh fallback", async () => { vi.stubEnv("OPENCLAW_VERSION", "2026.3.22"); - const fetchImpl = vi.fn(async (_url: string | URL | Request, init?: RequestInit) => { + const fetchImpl = vi.fn(async (_url, init) => { expect(init?.method).toBe("POST"); expect(typeof init?.body).toBe("string"); - const body = init?.body as string; + const body = requireStringBody(init); expect(body).toContain("grant_type=refresh_token"); expect(body).toContain(`client_id=${encodeURIComponent(XAI_OAUTH_CLIENT_ID)}`); expect(body).toContain("refresh_token=refresh-1"); @@ -146,19 +144,17 @@ describe("xAI OAuth", () => { access_token: "access-2", expires_in: 120, }); - }) as unknown as typeof fetch; + }); - const refreshed = await refreshXaiOAuthCredential( - { - type: "oauth", - provider: "xai", - access: "access-1", - refresh: "refresh-1", - expires: 100, - tokenEndpoint: "https://auth.x.ai/oauth2/token", - } as unknown as Parameters[0], - { fetchImpl, now: () => 1_000 }, - ); + const credential = { + type: "oauth", + provider: "xai", + access: "access-1", + refresh: "refresh-1", + expires: 100, + tokenEndpoint: "https://auth.x.ai/oauth2/token", + } satisfies OAuthCredential & { tokenEndpoint: string }; + const refreshed = await refreshXaiOAuthCredential(credential, { fetchImpl, now: () => 1_000 }); expect(fetchImpl).toHaveBeenCalledWith("https://auth.x.ai/oauth2/token", expect.any(Object)); expect(refreshed.access).toBe("access-2"); @@ -205,28 +201,30 @@ describe("xAI OAuth", () => { }), ); vi.stubGlobal("fetch", fetchImpl); - const ctx = { + const note = vi.fn(async () => {}); + const openUrl = vi.fn(async () => {}); + const runtime = createRuntimeEnv(); + const ctx: ProviderAuthContext = { config: {}, isRemote: true, - openUrl: vi.fn(async () => {}), - prompter: { + openUrl, + prompter: createTestWizardPrompter({ progress: vi.fn(() => progress), - note: vi.fn(async () => {}), + note, + }), + runtime, + oauth: { + createVpsAwareHandlers: () => { + throw new Error("unexpected VPS OAuth handler request"); + }, }, - runtime: { - log: vi.fn(), - }, - oauth: {}, }; - const result = await loginXaiDeviceCode(ctx as never); + const result = await loginXaiDeviceCode(ctx); - expect(ctx.openUrl).not.toHaveBeenCalled(); - expect(ctx.prompter.note).toHaveBeenCalledWith( - expect.stringContaining("ABCD-1234"), - "xAI device code", - ); - const remoteLog = ctx.runtime.log.mock.calls[0]?.[0]; + expect(openUrl).not.toHaveBeenCalled(); + expect(note).toHaveBeenCalledWith(expect.stringContaining("ABCD-1234"), "xAI device code"); + const remoteLog = runtime.log.mock.calls[0]?.[0]; expect(remoteLog).toContain("https://accounts.x.ai/oauth2/device"); expect(remoteLog).not.toContain("ABCD-1234"); const deviceRequest = fetchImpl.mock.calls[1]?.[1]; @@ -243,8 +241,7 @@ describe("xAI OAuth", () => { ); expect(tokenBody).toContain("device_code=device-code-1"); - const credential = result.profiles[0]?.credential as Record | undefined; - expect(credential).toMatchObject({ + expect(result.profiles[0]?.credential).toMatchObject({ type: "oauth", provider: "xai", refresh: "refresh-1", @@ -255,8 +252,8 @@ describe("xAI OAuth", () => { issuer: "https://auth.x.ai", authFlow: "device-code", accountId: "acct-1", + access: expect.any(String), }); - expect(credential?.access).toEqual(expect.any(String)); expect(progress.update).toHaveBeenCalledWith("Waiting for xAI device authorization..."); expect(progress.stop).toHaveBeenCalledWith("xAI device code complete"); }); diff --git a/extensions/xai/xai-oauth.ts b/extensions/xai/xai-oauth.ts index b43d3ffce19..bea4fc6e862 100644 --- a/extensions/xai/xai-oauth.ts +++ b/extensions/xai/xai-oauth.ts @@ -38,6 +38,10 @@ const XAI_DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code type XaiOAuthDiscovery = { authorizationEndpoint: string; + tokenEndpoint: string; +}; + +type XaiDeviceCodeDiscovery = { deviceAuthorizationEndpoint: string; tokenEndpoint: string; }; @@ -119,9 +123,9 @@ async function readJsonResponse(response: Response, context: string): Promise { +): Promise> { const response = await getFetchImpl(options.fetchImpl)(XAI_OAUTH_DISCOVERY_URL, { headers: { Accept: "application/json", @@ -129,15 +133,16 @@ export async function fetchXaiOAuthDiscovery( }, signal: AbortSignal.timeout(XAI_OAUTH_FETCH_TIMEOUT_MS), }); - const json = readStringRecord(await readJsonResponse(response, "xAI OAuth discovery")); + return readStringRecord(await readJsonResponse(response, "xAI OAuth discovery")); +} + +export async function fetchXaiOAuthDiscovery( + options: XaiOAuthFetchOptions = {}, +): Promise { + const json = await fetchXaiOAuthDiscoveryDocument(options); const authorizationEndpoint = json.authorization_endpoint; - const deviceAuthorizationEndpoint = json.device_authorization_endpoint; const tokenEndpoint = json.token_endpoint; - if ( - typeof authorizationEndpoint !== "string" || - typeof deviceAuthorizationEndpoint !== "string" || - typeof tokenEndpoint !== "string" - ) { + if (typeof authorizationEndpoint !== "string" || typeof tokenEndpoint !== "string") { throw new Error("xAI OAuth discovery response is missing endpoints"); } return { @@ -145,6 +150,20 @@ export async function fetchXaiOAuthDiscovery( authorizationEndpoint, "authorization endpoint", ), + tokenEndpoint: requireTrustedXaiOAuthEndpoint(tokenEndpoint, "token endpoint"), + }; +} + +async function fetchXaiDeviceCodeDiscovery( + options: XaiOAuthFetchOptions = {}, +): Promise { + const json = await fetchXaiOAuthDiscoveryDocument(options); + const deviceAuthorizationEndpoint = json.device_authorization_endpoint; + const tokenEndpoint = json.token_endpoint; + if (typeof deviceAuthorizationEndpoint !== "string" || typeof tokenEndpoint !== "string") { + throw new Error("xAI OAuth discovery response is missing device code endpoints"); + } + return { deviceAuthorizationEndpoint: requireTrustedXaiOAuthEndpoint( deviceAuthorizationEndpoint, "device authorization endpoint", @@ -483,8 +502,11 @@ function resolveXaiOAuthIdentity(tokens: XaiOAuthTokenResponse): XaiOAuthIdentit }; } -function readCredentialString(credential: OAuthCredential, key: string): string | undefined { - const value = (credential as unknown as Record)[key]; +function readCredentialString( + credential: OAuthCredential & Partial>, + key: TKey, +): string | undefined { + const value = credential[key]; return typeof value === "string" && value.trim().length > 0 ? value : undefined; } @@ -591,7 +613,7 @@ async function noteXaiDeviceCode( export async function loginXaiDeviceCode(ctx: ProviderAuthContext): Promise { const progress = ctx.prompter.progress("Starting xAI device code flow..."); try { - const discovery = await fetchXaiOAuthDiscovery(); + const discovery = await fetchXaiDeviceCodeDiscovery(); progress.update("Requesting xAI device code..."); const deviceCode = await requestXaiDeviceCode({ deviceAuthorizationEndpoint: discovery.deviceAuthorizationEndpoint,