fix(xai): decouple device code discovery

This commit is contained in:
Ayaan Zaidi 2026-05-19 06:31:44 +00:00
parent 896fd13b1c
commit b66e91ba77
2 changed files with 73 additions and 54 deletions

View file

@ -1,3 +1,6 @@
import type { ProviderAuthContext } from "openclaw/plugin-sdk/plugin-entry";
import type { OAuthCredential } from "openclaw/plugin-sdk/provider-auth";
import { createRuntimeEnv, createTestWizardPrompter } from "openclaw/plugin-sdk/testing";
import { afterEach, describe, expect, it, vi } from "vitest";
import {
buildXaiOAuthAuthorizationCodeTokenBody,
@ -97,34 +100,29 @@ describe("xAI OAuth", () => {
it("validates discovered endpoints before using them", async () => {
vi.stubEnv("OPENCLAW_VERSION", "2026.3.22");
const fetchImpl = vi.fn(async () =>
const fetchImpl = vi.fn<typeof fetch>(async () =>
jsonResponse({
authorization_endpoint: "https://auth.x.ai/oauth2/authorize",
device_authorization_endpoint: "https://auth.x.ai/oauth2/device/code",
token_endpoint: "https://auth.x.ai/oauth2/token",
}),
) as unknown as typeof fetch;
);
await expect(fetchXaiOAuthDiscovery({ fetchImpl })).resolves.toEqual({
authorizationEndpoint: "https://auth.x.ai/oauth2/authorize",
deviceAuthorizationEndpoint: "https://auth.x.ai/oauth2/device/code",
tokenEndpoint: "https://auth.x.ai/oauth2/token",
});
const discoveryInit = (fetchImpl as unknown as ReturnType<typeof vi.fn>).mock.calls.at(
0,
)?.[1] as RequestInit | undefined;
const discoveryInit = fetchImpl.mock.calls.at(0)?.[1];
const discoveryHeaders = new Headers(discoveryInit?.headers ?? {});
expect(discoveryHeaders.get("user-agent")).toBe("openclaw/2026.3.22");
vi.unstubAllEnvs();
const poisonedFetch = vi.fn(async () =>
const poisonedFetch = vi.fn<typeof fetch>(async () =>
jsonResponse({
authorization_endpoint: "https://auth.x.ai/oauth2/authorize",
device_authorization_endpoint: "https://auth.x.ai/oauth2/device/code",
token_endpoint: "https://evil.test/oauth2/token",
}),
) as unknown as typeof fetch;
);
await expect(fetchXaiOAuthDiscovery({ fetchImpl: poisonedFetch })).rejects.toThrow(
"untrusted token endpoint",
@ -133,10 +131,10 @@ describe("xAI OAuth", () => {
it("refreshes with the cached token endpoint and preserves refresh fallback", async () => {
vi.stubEnv("OPENCLAW_VERSION", "2026.3.22");
const fetchImpl = vi.fn(async (_url: string | URL | Request, init?: RequestInit) => {
const fetchImpl = vi.fn<typeof fetch>(async (_url, init) => {
expect(init?.method).toBe("POST");
expect(typeof init?.body).toBe("string");
const body = init?.body as string;
const body = requireStringBody(init);
expect(body).toContain("grant_type=refresh_token");
expect(body).toContain(`client_id=${encodeURIComponent(XAI_OAUTH_CLIENT_ID)}`);
expect(body).toContain("refresh_token=refresh-1");
@ -146,19 +144,17 @@ describe("xAI OAuth", () => {
access_token: "access-2",
expires_in: 120,
});
}) as unknown as typeof fetch;
});
const refreshed = await refreshXaiOAuthCredential(
{
type: "oauth",
provider: "xai",
access: "access-1",
refresh: "refresh-1",
expires: 100,
tokenEndpoint: "https://auth.x.ai/oauth2/token",
} as unknown as Parameters<typeof refreshXaiOAuthCredential>[0],
{ fetchImpl, now: () => 1_000 },
);
const credential = {
type: "oauth",
provider: "xai",
access: "access-1",
refresh: "refresh-1",
expires: 100,
tokenEndpoint: "https://auth.x.ai/oauth2/token",
} satisfies OAuthCredential & { tokenEndpoint: string };
const refreshed = await refreshXaiOAuthCredential(credential, { fetchImpl, now: () => 1_000 });
expect(fetchImpl).toHaveBeenCalledWith("https://auth.x.ai/oauth2/token", expect.any(Object));
expect(refreshed.access).toBe("access-2");
@ -205,28 +201,30 @@ describe("xAI OAuth", () => {
}),
);
vi.stubGlobal("fetch", fetchImpl);
const ctx = {
const note = vi.fn(async () => {});
const openUrl = vi.fn(async () => {});
const runtime = createRuntimeEnv();
const ctx: ProviderAuthContext = {
config: {},
isRemote: true,
openUrl: vi.fn(async () => {}),
prompter: {
openUrl,
prompter: createTestWizardPrompter({
progress: vi.fn(() => progress),
note: vi.fn(async () => {}),
note,
}),
runtime,
oauth: {
createVpsAwareHandlers: () => {
throw new Error("unexpected VPS OAuth handler request");
},
},
runtime: {
log: vi.fn(),
},
oauth: {},
};
const result = await loginXaiDeviceCode(ctx as never);
const result = await loginXaiDeviceCode(ctx);
expect(ctx.openUrl).not.toHaveBeenCalled();
expect(ctx.prompter.note).toHaveBeenCalledWith(
expect.stringContaining("ABCD-1234"),
"xAI device code",
);
const remoteLog = ctx.runtime.log.mock.calls[0]?.[0];
expect(openUrl).not.toHaveBeenCalled();
expect(note).toHaveBeenCalledWith(expect.stringContaining("ABCD-1234"), "xAI device code");
const remoteLog = runtime.log.mock.calls[0]?.[0];
expect(remoteLog).toContain("https://accounts.x.ai/oauth2/device");
expect(remoteLog).not.toContain("ABCD-1234");
const deviceRequest = fetchImpl.mock.calls[1]?.[1];
@ -243,8 +241,7 @@ describe("xAI OAuth", () => {
);
expect(tokenBody).toContain("device_code=device-code-1");
const credential = result.profiles[0]?.credential as Record<string, unknown> | undefined;
expect(credential).toMatchObject({
expect(result.profiles[0]?.credential).toMatchObject({
type: "oauth",
provider: "xai",
refresh: "refresh-1",
@ -255,8 +252,8 @@ describe("xAI OAuth", () => {
issuer: "https://auth.x.ai",
authFlow: "device-code",
accountId: "acct-1",
access: expect.any(String),
});
expect(credential?.access).toEqual(expect.any(String));
expect(progress.update).toHaveBeenCalledWith("Waiting for xAI device authorization...");
expect(progress.stop).toHaveBeenCalledWith("xAI device code complete");
});

View file

@ -38,6 +38,10 @@ const XAI_DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code
type XaiOAuthDiscovery = {
authorizationEndpoint: string;
tokenEndpoint: string;
};
type XaiDeviceCodeDiscovery = {
deviceAuthorizationEndpoint: string;
tokenEndpoint: string;
};
@ -119,9 +123,9 @@ async function readJsonResponse(response: Response, context: string): Promise<un
return body;
}
export async function fetchXaiOAuthDiscovery(
async function fetchXaiOAuthDiscoveryDocument(
options: XaiOAuthFetchOptions = {},
): Promise<XaiOAuthDiscovery> {
): Promise<Record<string, unknown>> {
const response = await getFetchImpl(options.fetchImpl)(XAI_OAUTH_DISCOVERY_URL, {
headers: {
Accept: "application/json",
@ -129,15 +133,16 @@ export async function fetchXaiOAuthDiscovery(
},
signal: AbortSignal.timeout(XAI_OAUTH_FETCH_TIMEOUT_MS),
});
const json = readStringRecord(await readJsonResponse(response, "xAI OAuth discovery"));
return readStringRecord(await readJsonResponse(response, "xAI OAuth discovery"));
}
export async function fetchXaiOAuthDiscovery(
options: XaiOAuthFetchOptions = {},
): Promise<XaiOAuthDiscovery> {
const json = await fetchXaiOAuthDiscoveryDocument(options);
const authorizationEndpoint = json.authorization_endpoint;
const deviceAuthorizationEndpoint = json.device_authorization_endpoint;
const tokenEndpoint = json.token_endpoint;
if (
typeof authorizationEndpoint !== "string" ||
typeof deviceAuthorizationEndpoint !== "string" ||
typeof tokenEndpoint !== "string"
) {
if (typeof authorizationEndpoint !== "string" || typeof tokenEndpoint !== "string") {
throw new Error("xAI OAuth discovery response is missing endpoints");
}
return {
@ -145,6 +150,20 @@ export async function fetchXaiOAuthDiscovery(
authorizationEndpoint,
"authorization endpoint",
),
tokenEndpoint: requireTrustedXaiOAuthEndpoint(tokenEndpoint, "token endpoint"),
};
}
async function fetchXaiDeviceCodeDiscovery(
options: XaiOAuthFetchOptions = {},
): Promise<XaiDeviceCodeDiscovery> {
const json = await fetchXaiOAuthDiscoveryDocument(options);
const deviceAuthorizationEndpoint = json.device_authorization_endpoint;
const tokenEndpoint = json.token_endpoint;
if (typeof deviceAuthorizationEndpoint !== "string" || typeof tokenEndpoint !== "string") {
throw new Error("xAI OAuth discovery response is missing device code endpoints");
}
return {
deviceAuthorizationEndpoint: requireTrustedXaiOAuthEndpoint(
deviceAuthorizationEndpoint,
"device authorization endpoint",
@ -483,8 +502,11 @@ function resolveXaiOAuthIdentity(tokens: XaiOAuthTokenResponse): XaiOAuthIdentit
};
}
function readCredentialString(credential: OAuthCredential, key: string): string | undefined {
const value = (credential as unknown as Record<string, unknown>)[key];
function readCredentialString<TKey extends string>(
credential: OAuthCredential & Partial<Record<TKey, unknown>>,
key: TKey,
): string | undefined {
const value = credential[key];
return typeof value === "string" && value.trim().length > 0 ? value : undefined;
}
@ -591,7 +613,7 @@ async function noteXaiDeviceCode(
export async function loginXaiDeviceCode(ctx: ProviderAuthContext): Promise<ProviderAuthResult> {
const progress = ctx.prompter.progress("Starting xAI device code flow...");
try {
const discovery = await fetchXaiOAuthDiscovery();
const discovery = await fetchXaiDeviceCodeDiscovery();
progress.update("Requesting xAI device code...");
const deviceCode = await requestXaiDeviceCode({
deviceAuthorizationEndpoint: discovery.deviceAuthorizationEndpoint,