diff --git a/packages/cli/src/config/auth.test.ts b/packages/cli/src/config/auth.test.ts index 6f6b584ef..c960e05a7 100644 --- a/packages/cli/src/config/auth.test.ts +++ b/packages/cli/src/config/auth.test.ts @@ -1,41 +1,112 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2025 Qwen Team * SPDX-License-Identifier: Apache-2.0 */ import { AuthType } from '@qwen-code/qwen-code-core'; import { vi } from 'vitest'; import { validateAuthMethod } from './auth.js'; +import * as settings from './settings.js'; vi.mock('./settings.js', () => ({ loadEnvironment: vi.fn(), loadSettings: vi.fn().mockReturnValue({ - merged: vi.fn().mockReturnValue({}), + merged: {}, }), })); describe('validateAuthMethod', () => { beforeEach(() => { vi.resetModules(); + // Reset mock to default + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: {}, + } as ReturnType); }); afterEach(() => { vi.unstubAllEnvs(); + delete process.env['OPENAI_API_KEY']; + delete process.env['CUSTOM_API_KEY']; + delete process.env['GEMINI_API_KEY']; + delete process.env['GEMINI_API_KEY_ALTERED']; + delete process.env['ANTHROPIC_API_KEY']; + delete process.env['ANTHROPIC_BASE_URL']; + delete process.env['GOOGLE_API_KEY']; }); - it('should return null for USE_OPENAI', () => { + it('should return null for USE_OPENAI with default env key', () => { process.env['OPENAI_API_KEY'] = 'fake-key'; expect(validateAuthMethod(AuthType.USE_OPENAI)).toBeNull(); }); - it('should return an error message for USE_OPENAI if OPENAI_API_KEY is not set', () => { - delete process.env['OPENAI_API_KEY']; + it('should return an error message for USE_OPENAI if no API key is available', () => { expect(validateAuthMethod(AuthType.USE_OPENAI)).toBe( - "Missing API key for OpenAI-compatible auth. Set settings.security.auth.apiKey, or set the 'OPENAI_API_KEY' environment variable. If you configured a model in settings.modelProviders with an envKey, set that env var as well.", + "Missing API key for OpenAI-compatible auth. 
Set settings.security.auth.apiKey, or set the 'OPENAI_API_KEY' environment variable.", ); }); + it('should return null for USE_OPENAI with custom envKey from modelProviders', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'custom-model' }, + modelProviders: { + openai: [{ id: 'custom-model', envKey: 'CUSTOM_API_KEY' }], + }, + }, + } as unknown as ReturnType); + process.env['CUSTOM_API_KEY'] = 'custom-key'; + + expect(validateAuthMethod(AuthType.USE_OPENAI)).toBeNull(); + }); + + it('should return error with custom envKey hint when modelProviders envKey is set but env var is missing', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'custom-model' }, + modelProviders: { + openai: [{ id: 'custom-model', envKey: 'CUSTOM_API_KEY' }], + }, + }, + } as unknown as ReturnType); + + const result = validateAuthMethod(AuthType.USE_OPENAI); + expect(result).toContain('CUSTOM_API_KEY'); + }); + + it('should return null for USE_GEMINI with custom envKey', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'gemini-1.5-flash' }, + modelProviders: { + gemini: [ + { id: 'gemini-1.5-flash', envKey: 'GEMINI_API_KEY_ALTERED' }, + ], + }, + }, + } as unknown as ReturnType); + process.env['GEMINI_API_KEY_ALTERED'] = 'altered-key'; + + expect(validateAuthMethod(AuthType.USE_GEMINI)).toBeNull(); + }); + + it('should return error with custom envKey for USE_GEMINI when env var is missing', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'gemini-1.5-flash' }, + modelProviders: { + gemini: [ + { id: 'gemini-1.5-flash', envKey: 'GEMINI_API_KEY_ALTERED' }, + ], + }, + }, + } as unknown as ReturnType); + + const result = validateAuthMethod(AuthType.USE_GEMINI); + expect(result).toContain('GEMINI_API_KEY_ALTERED'); + }); + it('should return null for QWEN_OAUTH', () => { expect(validateAuthMethod(AuthType.QWEN_OAUTH)).toBeNull(); }); 
@@ -45,4 +116,55 @@ describe('validateAuthMethod', () => { 'Invalid auth method selected.', ); }); + + it('should return null for USE_ANTHROPIC with custom envKey and baseUrl', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'claude-3' }, + modelProviders: { + anthropic: [ + { + id: 'claude-3', + envKey: 'CUSTOM_ANTHROPIC_KEY', + baseUrl: 'https://api.anthropic.com', + }, + ], + }, + }, + } as unknown as ReturnType); + process.env['CUSTOM_ANTHROPIC_KEY'] = 'custom-anthropic-key'; + + expect(validateAuthMethod(AuthType.USE_ANTHROPIC)).toBeNull(); + }); + + it('should return error for USE_ANTHROPIC when baseUrl is missing', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'claude-3' }, + modelProviders: { + anthropic: [{ id: 'claude-3', envKey: 'CUSTOM_ANTHROPIC_KEY' }], + }, + }, + } as unknown as ReturnType); + process.env['CUSTOM_ANTHROPIC_KEY'] = 'custom-key'; + + const result = validateAuthMethod(AuthType.USE_ANTHROPIC); + expect(result).toContain('ANTHROPIC_BASE_URL'); + }); + + it('should return null for USE_VERTEX_AI with custom envKey', () => { + vi.mocked(settings.loadSettings).mockReturnValue({ + merged: { + model: { name: 'vertex-model' }, + modelProviders: { + 'vertex-ai': [ + { id: 'vertex-model', envKey: 'GOOGLE_API_KEY_VERTEX' }, + ], + }, + }, + } as unknown as ReturnType); + process.env['GOOGLE_API_KEY_VERTEX'] = 'vertex-key'; + + expect(validateAuthMethod(AuthType.USE_VERTEX_AI)).toBeNull(); + }); }); diff --git a/packages/cli/src/config/auth.ts b/packages/cli/src/config/auth.ts index 42fbf280f..e05b029d9 100644 --- a/packages/cli/src/config/auth.ts +++ b/packages/cli/src/config/auth.ts @@ -1,24 +1,97 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2025 Qwen Team * SPDX-License-Identifier: Apache-2.0 */ import { AuthType } from '@qwen-code/qwen-code-core'; -import { loadEnvironment, loadSettings } from './settings.js'; +import type { + 
ModelProvidersConfig, + ProviderModelConfig, +} from '@qwen-code/qwen-code-core'; +import { loadEnvironment, loadSettings, type Settings } from './settings.js'; + +/** + * Default environment variable names for each auth type + */ +const DEFAULT_ENV_KEYS: Record = { + [AuthType.USE_OPENAI]: 'OPENAI_API_KEY', + [AuthType.USE_ANTHROPIC]: 'ANTHROPIC_API_KEY', + [AuthType.USE_GEMINI]: 'GEMINI_API_KEY', + [AuthType.USE_VERTEX_AI]: 'GOOGLE_API_KEY', +}; + +/** + * Find model configuration from modelProviders by authType and modelId + */ +function findModelConfig( + modelProviders: ModelProvidersConfig | undefined, + authType: string, + modelId: string | undefined, +): ProviderModelConfig | undefined { + if (!modelProviders || !modelId) { + return undefined; + } + + const models = modelProviders[authType]; + if (!Array.isArray(models)) { + return undefined; + } + + return models.find((m) => m.id === modelId); +} + +/** + * Check if API key is available for the given auth type and model configuration. + * Prioritizes custom envKey from modelProviders over default environment variables. 
+ */ +function hasApiKeyForAuth( + authType: string, + settings: Settings, +): { hasKey: boolean; checkedEnvKey: string | undefined } { + const modelProviders = settings.modelProviders as + | ModelProvidersConfig + | undefined; + const modelId = settings.model?.name; + + // Try to find model-specific envKey from modelProviders + const modelConfig = findModelConfig(modelProviders, authType, modelId); + if (modelConfig?.envKey) { + const hasKey = !!process.env[modelConfig.envKey]; + return { hasKey, checkedEnvKey: modelConfig.envKey }; + } + + // Fallback to default environment variable + const defaultEnvKey = DEFAULT_ENV_KEYS[authType]; + if (defaultEnvKey) { + const hasKey = !!process.env[defaultEnvKey]; + return { hasKey, checkedEnvKey: defaultEnvKey }; + } + + // Also check settings.security.auth.apiKey as fallback + if (settings.security?.auth?.apiKey) { + return { hasKey: true, checkedEnvKey: undefined }; + } + + return { hasKey: false, checkedEnvKey: undefined }; +} export function validateAuthMethod(authMethod: string): string | null { const settings = loadSettings(); loadEnvironment(settings.merged); if (authMethod === AuthType.USE_OPENAI) { - const hasApiKey = - process.env['OPENAI_API_KEY'] || settings.merged.security?.auth?.apiKey; - if (!hasApiKey) { + const { hasKey, checkedEnvKey } = hasApiKeyForAuth( + authMethod, + settings.merged, + ); + if (!hasKey) { + const envKeyHint = checkedEnvKey + ? `'${checkedEnvKey}'` + : "'OPENAI_API_KEY' (or configure modelProviders[].envKey)"; return ( 'Missing API key for OpenAI-compatible auth. ' + - "Set settings.security.auth.apiKey, or set the 'OPENAI_API_KEY' environment variable. " + - 'If you configured a model in settings.modelProviders with an envKey, set that env var as well.' 
+ `Set settings.security.auth.apiKey, or set the ${envKeyHint} environment variable.` ); } return null; @@ -31,31 +104,50 @@ export function validateAuthMethod(authMethod: string): string | null { } if (authMethod === AuthType.USE_ANTHROPIC) { - const hasApiKey = process.env['ANTHROPIC_API_KEY']; - if (!hasApiKey) { - return 'ANTHROPIC_API_KEY environment variable not found.'; + const { hasKey, checkedEnvKey } = hasApiKeyForAuth( + authMethod, + settings.merged, + ); + if (!hasKey) { + const envKeyHint = checkedEnvKey || 'ANTHROPIC_API_KEY'; + return `${envKeyHint} environment variable not found.`; } - const hasBaseUrl = process.env['ANTHROPIC_BASE_URL']; + // Check baseUrl - can come from modelProviders or environment + const modelProviders = settings.merged.modelProviders as + | ModelProvidersConfig + | undefined; + const modelId = settings.merged.model?.name; + const modelConfig = findModelConfig(modelProviders, authMethod, modelId); + const hasBaseUrl = + modelConfig?.baseUrl || process.env['ANTHROPIC_BASE_URL']; if (!hasBaseUrl) { - return 'ANTHROPIC_BASE_URL environment variable not found.'; + return 'ANTHROPIC_BASE_URL environment variable not found (or configure modelProviders[].baseUrl).'; } return null; } if (authMethod === AuthType.USE_GEMINI) { - const hasApiKey = process.env['GEMINI_API_KEY']; - if (!hasApiKey) { - return 'GEMINI_API_KEY environment variable not found. Please set it in your .env file or environment variables.'; + const { hasKey, checkedEnvKey } = hasApiKeyForAuth( + authMethod, + settings.merged, + ); + if (!hasKey) { + const envKeyHint = checkedEnvKey || 'GEMINI_API_KEY'; + return `${envKeyHint} environment variable not found. Please set it in your .env file or environment variables.`; } return null; } if (authMethod === AuthType.USE_VERTEX_AI) { - const hasApiKey = process.env['GOOGLE_API_KEY']; - if (!hasApiKey) { - return 'GOOGLE_API_KEY environment variable not found. 
Please set it in your .env file or environment variables.'; + const { hasKey, checkedEnvKey } = hasApiKeyForAuth( + authMethod, + settings.merged, + ); + if (!hasKey) { + const envKeyHint = checkedEnvKey || 'GOOGLE_API_KEY'; + return `${envKeyHint} environment variable not found. Please set it in your .env file or environment variables.`; } process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true'; diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 6f2019e75..850d4a822 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -77,10 +77,8 @@ vi.mock('read-package-up', () => ({ ), })); -vi.mock('@qwen-code/qwen-code-core', async () => { - const actualServer = await vi.importActual( - '@qwen-code/qwen-code-core', - ); +vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => { + const actualServer = await importOriginal(); return { ...actualServer, IdeClient: { diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index bc5da7bfc..9fffe8fae 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -31,10 +31,7 @@ import { } from '@qwen-code/qwen-code-core'; import { extensionsCommand } from '../commands/extensions.js'; import type { Settings } from './settings.js'; -import { - buildGenerationConfigSources, - getModelProvidersConfigFromSettings, -} from '../utils/modelProviderUtils.js'; +import { resolveCliGenerationConfig } from '../utils/modelConfigUtils.js'; import yargs, { type Argv } from 'yargs'; import { hideBin } from 'yargs/helpers'; import * as fs from 'node:fs'; @@ -930,26 +927,21 @@ export async function loadCliConfig( (argv.authType as AuthType | undefined) || settings.security?.auth?.selectedType; - const apiKey = - (selectedAuthType === AuthType.USE_OPENAI - ? 
argv.openaiApiKey || - process.env['OPENAI_API_KEY'] || - settings.security?.auth?.apiKey - : '') || ''; - const baseUrl = - (selectedAuthType === AuthType.USE_OPENAI - ? argv.openaiBaseUrl || - process.env['OPENAI_BASE_URL'] || - settings.security?.auth?.baseUrl - : '') || ''; - const resolvedModel = - argv.model || - (selectedAuthType === AuthType.USE_OPENAI - ? process.env['OPENAI_MODEL'] || - process.env['QWEN_MODEL'] || - settings.model?.name - : '') || - ''; + // Unified resolution of generation config with source attribution + const resolvedCliConfig = resolveCliGenerationConfig({ + argv: { + model: argv.model, + openaiApiKey: argv.openaiApiKey, + openaiBaseUrl: argv.openaiBaseUrl, + openaiLogging: argv.openaiLogging, + openaiLoggingDir: argv.openaiLoggingDir, + }, + settings, + selectedAuthType, + env: process.env as Record, + }); + + const { model: resolvedModel } = resolvedCliConfig; const sandboxConfig = await loadSandboxConfig(settings, argv); const screenReader = @@ -983,17 +975,7 @@ export async function loadCliConfig( } } - const modelProvidersConfig = getModelProvidersConfigFromSettings(settings); - const generationConfigSources = buildGenerationConfigSources({ - argv: { - model: argv.model, - openaiApiKey: argv.openaiApiKey, - openaiBaseUrl: argv.openaiBaseUrl, - }, - settings, - selectedAuthType, - env: process.env as Record, - }); + const modelProvidersConfig = settings.modelProviders; return new Config({ sessionId, @@ -1053,25 +1035,10 @@ export async function loadCliConfig( outputFormat, includePartialMessages, modelProvidersConfig, - generationConfigSources, - generationConfig: { - ...(settings.model?.generationConfig || {}), - model: resolvedModel, - apiKey, - baseUrl, - enableOpenAILogging: - (typeof argv.openaiLogging === 'undefined' - ? settings.model?.enableOpenAILogging - : argv.openaiLogging) ?? 
false, - openAILoggingDir: - argv.openaiLoggingDir || settings.model?.openAILoggingDir, - }, + generationConfigSources: resolvedCliConfig.sources, + generationConfig: resolvedCliConfig.generationConfig, cliVersion: await getCliVersion(), - webSearch: buildWebSearchConfig( - argv, - settings, - settings.security?.auth?.selectedType, - ), + webSearch: buildWebSearchConfig(argv, settings, selectedAuthType), summarizeToolOutput: settings.model?.summarizeToolOutput, ideMode, chatCompression: settings.model?.chatCompression, diff --git a/packages/cli/src/config/modelProvidersScope.test.ts b/packages/cli/src/config/modelProvidersScope.test.ts new file mode 100644 index 000000000..2b270d6be --- /dev/null +++ b/packages/cli/src/config/modelProvidersScope.test.ts @@ -0,0 +1,89 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, expect, it } from 'vitest'; +import { SettingScope } from './settings.js'; +import { getPersistScopeForModelSelection } from './modelProvidersScope.js'; + +function makeSettings({ + isTrusted, + userModelProviders, + workspaceModelProviders, +}: { + isTrusted: boolean; + userModelProviders?: unknown; + workspaceModelProviders?: unknown; +}) { + const userSettings: Record = {}; + const workspaceSettings: Record = {}; + + // When undefined, treat as "not present in this scope" (the key is omitted), + // matching how LoadedSettings is shaped when a settings file doesn't define it. 
+ if (userModelProviders !== undefined) { + userSettings['modelProviders'] = userModelProviders; + } + if (workspaceModelProviders !== undefined) { + workspaceSettings['modelProviders'] = workspaceModelProviders; + } + + return { + isTrusted, + user: { settings: userSettings }, + workspace: { settings: workspaceSettings }, + } as unknown as import('./settings.js').LoadedSettings; +} + +describe('getPersistScopeForModelSelection', () => { + it('prefers workspace when trusted and workspace defines modelProviders', () => { + const settings = makeSettings({ + isTrusted: true, + workspaceModelProviders: {}, + userModelProviders: { anything: true }, + }); + + expect(getPersistScopeForModelSelection(settings)).toBe( + SettingScope.Workspace, + ); + }); + + it('falls back to user when workspace does not define modelProviders', () => { + const settings = makeSettings({ + isTrusted: true, + workspaceModelProviders: undefined, + userModelProviders: {}, + }); + + expect(getPersistScopeForModelSelection(settings)).toBe(SettingScope.User); + }); + + it('ignores workspace modelProviders when workspace is untrusted', () => { + const settings = makeSettings({ + isTrusted: false, + workspaceModelProviders: {}, + userModelProviders: undefined, + }); + + expect(getPersistScopeForModelSelection(settings)).toBe(SettingScope.User); + }); + + it('falls back to legacy trust heuristic when neither scope defines modelProviders', () => { + const trusted = makeSettings({ + isTrusted: true, + userModelProviders: undefined, + workspaceModelProviders: undefined, + }); + expect(getPersistScopeForModelSelection(trusted)).toBe( + SettingScope.Workspace, + ); + + const untrusted = makeSettings({ + isTrusted: false, + userModelProviders: undefined, + workspaceModelProviders: undefined, + }); + expect(getPersistScopeForModelSelection(untrusted)).toBe(SettingScope.User); + }); +}); diff --git a/packages/cli/src/config/modelProvidersScope.ts b/packages/cli/src/config/modelProvidersScope.ts new file mode 
100644 index 000000000..136141103 --- /dev/null +++ b/packages/cli/src/config/modelProvidersScope.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { SettingScope, type LoadedSettings } from './settings.js'; + +function hasOwnModelProviders(settingsObj: unknown): boolean { + if (!settingsObj || typeof settingsObj !== 'object') { + return false; + } + const obj = settingsObj as Record; + // Treat an explicitly configured empty object (modelProviders: {}) as "owned" + // by this scope, which is important when mergeStrategy is REPLACE. + return Object.prototype.hasOwnProperty.call(obj, 'modelProviders'); +} + +/** + * Returns which writable scope (Workspace/User) owns the effective modelProviders + * configuration. + * + * Note: Workspace scope is only considered when the workspace is trusted. + */ +export function getModelProvidersOwnerScope( + settings: LoadedSettings, +): SettingScope | undefined { + if (settings.isTrusted && hasOwnModelProviders(settings.workspace.settings)) { + return SettingScope.Workspace; + } + + if (hasOwnModelProviders(settings.user.settings)) { + return SettingScope.User; + } + + return undefined; +} + +/** + * Choose the settings scope to persist a model selection. + * Prefer persisting back to the scope that contains the effective modelProviders + * config, otherwise fall back to the legacy trust-based heuristic. + */ +export function getPersistScopeForModelSelection( + settings: LoadedSettings, +): SettingScope { + return getModelProvidersOwnerScope(settings) ?? SettingScope.User; +} diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 4562546ff..74b63a7b9 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -113,7 +113,7 @@ const SETTINGS_SCHEMA = { description: 'Model providers configuration grouped by authType. 
Each authType contains an array of model configurations.', showInDialog: false, - mergeStrategy: MergeStrategy.SHALLOW_MERGE, + mergeStrategy: MergeStrategy.REPLACE, }, general: { diff --git a/packages/cli/src/core/initializer.ts b/packages/cli/src/core/initializer.ts index 5aa3d9e3b..062c0b516 100644 --- a/packages/cli/src/core/initializer.ts +++ b/packages/cli/src/core/initializer.ts @@ -45,7 +45,9 @@ export async function initializeApp( // Auto-detect and set LLM output language on first use initializeLlmOutputLanguage(); - const authType = settings.merged.security?.auth?.selectedType; + // Use authType from modelsConfig which respects CLI --auth-type argument + // over settings.security.auth.selectedType + const authType = config.modelsConfig.getCurrentAuthType(); const authError = await performInitialAuth(config, authType); // Fallback to user select when initial authentication fails @@ -58,8 +60,13 @@ export async function initializeApp( } const themeError = validateTheme(settings); + // Open auth dialog if: + // 1. No authType was explicitly selected (neither from CLI --auth-type nor settings), OR + // 2. 
Authentication failed + // wasAuthTypeExplicitlyProvided() returns true if CLI or settings specified authType, + // false if using the default QWEN_OAUTH const shouldOpenAuthDialog = - settings.merged.security?.auth?.selectedType === undefined || !!authError; + !config.modelsConfig.wasAuthTypeExplicitlyProvided() || !!authError; if (config.getIdeMode()) { const ideClient = await IdeClient.getInstance(); diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 38dad449c..1449a7f4b 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -32,7 +32,6 @@ import { type Config, type IdeInfo, type IdeContext, - DEFAULT_GEMINI_FLASH_MODEL, IdeClient, ideContextStore, getErrorMessage, @@ -180,15 +179,10 @@ export const AppContainer = (props: AppContainerProps) => { [], ); - // Helper to determine the effective model, considering the fallback state. - const getEffectiveModel = useCallback(() => { - if (config.isInFallbackMode()) { - return DEFAULT_GEMINI_FLASH_MODEL; - } - return config.getModel(); - }, [config]); + // Helper to determine the current model (polled, since Config has no model-change event). 
+ const getCurrentModel = useCallback(() => config.getModel(), [config]); - const [currentModel, setCurrentModel] = useState(getEffectiveModel()); + const [currentModel, setCurrentModel] = useState(getCurrentModel()); const [isConfigInitialized, setConfigInitialized] = useState(false); @@ -241,12 +235,12 @@ export const AppContainer = (props: AppContainerProps) => { [historyManager.addItem], ); - // Watch for model changes (e.g., from Flash fallback) + // Watch for model changes (e.g., user switches model via /model) useEffect(() => { const checkModelChange = () => { - const effectiveModel = getEffectiveModel(); - if (effectiveModel !== currentModel) { - setCurrentModel(effectiveModel); + const model = getCurrentModel(); + if (model !== currentModel) { + setCurrentModel(model); } }; @@ -254,7 +248,7 @@ export const AppContainer = (props: AppContainerProps) => { const interval = setInterval(checkModelChange, 1000); // Check every second return () => clearInterval(interval); - }, [config, currentModel, getEffectiveModel]); + }, [config, currentModel, getCurrentModel]); const { consoleMessages, diff --git a/packages/cli/src/ui/auth/useAuth.ts b/packages/cli/src/ui/auth/useAuth.ts index c13f33c95..6125ebdf2 100644 --- a/packages/cli/src/ui/auth/useAuth.ts +++ b/packages/cli/src/ui/auth/useAuth.ts @@ -8,7 +8,6 @@ import type { Config } from '@qwen-code/qwen-code-core'; import { AuthEvent, AuthType, - clearCachedCredentialFile, getErrorMessage, logAuth, } from '@qwen-code/qwen-code-core'; @@ -109,7 +108,6 @@ export const useAuthCommand = ( if (credentials?.model != null) { settings.setValue(scope, 'model.name', credentials.model); } - await clearCachedCredentialFile(); } } catch (error) { handleAuthFailure(error); diff --git a/packages/cli/src/ui/commands/modelCommand.test.ts b/packages/cli/src/ui/commands/modelCommand.test.ts index af5c2ce63..41f95f199 100644 --- a/packages/cli/src/ui/commands/modelCommand.test.ts +++ b/packages/cli/src/ui/commands/modelCommand.test.ts 
@@ -13,12 +13,6 @@ import { type ContentGeneratorConfig, type Config, } from '@qwen-code/qwen-code-core'; -import * as availableModelsModule from '../models/availableModels.js'; - -// Mock the availableModels module -vi.mock('../models/availableModels.js', () => ({ - getAvailableModelsForAuthType: vi.fn(), -})); // Helper function to create a mock config function createMockConfig( @@ -31,9 +25,6 @@ function createMockConfig( describe('modelCommand', () => { let mockContext: CommandContext; - const mockGetAvailableModelsForAuthType = vi.mocked( - availableModelsModule.getAvailableModelsForAuthType, - ); beforeEach(() => { mockContext = createMockCommandContext(); @@ -87,10 +78,6 @@ describe('modelCommand', () => { }); it('should return dialog action for QWEN_OAUTH auth type', async () => { - mockGetAvailableModelsForAuthType.mockReturnValue([ - { id: 'qwen3-coder-plus', label: 'qwen3-coder-plus' }, - ]); - const mockConfig = createMockConfig({ model: 'test-model', authType: AuthType.QWEN_OAUTH, @@ -105,11 +92,7 @@ describe('modelCommand', () => { }); }); - it('should return dialog action for USE_OPENAI auth type when model is available', async () => { - mockGetAvailableModelsForAuthType.mockReturnValue([ - { id: 'gpt-4', label: 'gpt-4' }, - ]); - + it('should return dialog action for USE_OPENAI auth type', async () => { const mockConfig = createMockConfig({ model: 'test-model', authType: AuthType.USE_OPENAI, @@ -124,28 +107,7 @@ describe('modelCommand', () => { }); }); - it('should return error for USE_OPENAI auth type when no model is available', async () => { - mockGetAvailableModelsForAuthType.mockReturnValue([]); - - const mockConfig = createMockConfig({ - model: 'test-model', - authType: AuthType.USE_OPENAI, - }); - mockContext.services.config = mockConfig as Config; - - const result = await modelCommand.action!(mockContext, ''); - - expect(result).toEqual({ - type: 'message', - messageType: 'error', - content: - 'No models available for the current 
authentication type (openai).', - }); - }); - - it('should return error for unsupported auth types', async () => { - mockGetAvailableModelsForAuthType.mockReturnValue([]); - + it('should return dialog action for unsupported auth types', async () => { const mockConfig = createMockConfig({ model: 'test-model', authType: 'UNSUPPORTED_AUTH_TYPE' as AuthType, @@ -155,10 +117,8 @@ describe('modelCommand', () => { const result = await modelCommand.action!(mockContext, ''); expect(result).toEqual({ - type: 'message', - messageType: 'error', - content: - 'No models available for the current authentication type (UNSUPPORTED_AUTH_TYPE).', + type: 'dialog', + dialog: 'model', }); }); diff --git a/packages/cli/src/ui/commands/modelCommand.ts b/packages/cli/src/ui/commands/modelCommand.ts index a25e96a19..e0971bdde 100644 --- a/packages/cli/src/ui/commands/modelCommand.ts +++ b/packages/cli/src/ui/commands/modelCommand.ts @@ -11,7 +11,6 @@ import type { MessageActionReturn, } from './types.js'; import { CommandKind } from './types.js'; -import { getAvailableModelsForAuthType } from '../models/availableModels.js'; import { t } from '../../i18n/index.js'; export const modelCommand: SlashCommand = { @@ -52,22 +51,6 @@ export const modelCommand: SlashCommand = { }; } - const availableModels = getAvailableModelsForAuthType(authType); - - if (availableModels.length === 0) { - return { - type: 'message', - messageType: 'error', - content: t( - 'No models available for the current authentication type ({{authType}}).', - { - authType, - }, - ), - }; - } - - // Trigger model selection dialog return { type: 'dialog', dialog: 'model', diff --git a/packages/cli/src/ui/components/ModelDialog.test.tsx b/packages/cli/src/ui/components/ModelDialog.test.tsx index fe484e260..ac47ba46a 100644 --- a/packages/cli/src/ui/components/ModelDialog.test.tsx +++ b/packages/cli/src/ui/components/ModelDialog.test.tsx @@ -10,7 +10,11 @@ import { ModelDialog } from './ModelDialog.js'; import { useKeypress } 
from '../hooks/useKeypress.js'; import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js'; import { ConfigContext } from '../contexts/ConfigContext.js'; +import { SettingsContext } from '../contexts/SettingsContext.js'; import type { Config } from '@qwen-code/qwen-code-core'; +import { AuthType } from '@qwen-code/qwen-code-core'; +import type { LoadedSettings } from '../../config/settings.js'; +import { SettingScope } from '../../config/settings.js'; import { AVAILABLE_MODELS_QWEN, MAINLINE_CODER, @@ -36,18 +40,29 @@ const renderComponent = ( }; const combinedProps = { ...defaultProps, ...props }; + const mockSettings = { + isTrusted: true, + user: { settings: {} }, + workspace: { settings: {} }, + setValue: vi.fn(), + } as unknown as LoadedSettings; + const mockConfig = contextValue ? ({ // --- Functions used by ModelDialog --- getModel: vi.fn(() => MAINLINE_CODER), - setModel: vi.fn(), + setModel: vi.fn().mockResolvedValue(undefined), + switchModel: vi.fn().mockResolvedValue(undefined), getAuthType: vi.fn(() => 'qwen-oauth'), // --- Functions used by ClearcutLogger --- getUsageStatisticsEnabled: vi.fn(() => true), getSessionId: vi.fn(() => 'mock-session-id'), getDebugMode: vi.fn(() => false), - getContentGeneratorConfig: vi.fn(() => ({ authType: 'mock' })), + getContentGeneratorConfig: vi.fn(() => ({ + authType: AuthType.QWEN_OAUTH, + model: MAINLINE_CODER, + })), getUseSmartEdit: vi.fn(() => false), getUseModelRouter: vi.fn(() => false), getProxy: vi.fn(() => undefined), @@ -58,21 +73,27 @@ const renderComponent = ( : undefined; const renderResult = render( - - - , + + + + + , ); return { ...renderResult, props: combinedProps, mockConfig, + mockSettings, }; }; describe('', () => { beforeEach(() => { vi.clearAllMocks(); + // Ensure env-based fallback models don't leak into this suite from the developer environment. 
+ delete process.env['OPENAI_MODEL']; + delete process.env['ANTHROPIC_MODEL']; }); afterEach(() => { @@ -91,8 +112,12 @@ describe('', () => { const props = mockedSelect.mock.calls[0][0]; expect(props.items).toHaveLength(AVAILABLE_MODELS_QWEN.length); - expect(props.items[0].value).toBe(MAINLINE_CODER); - expect(props.items[1].value).toBe(MAINLINE_VLM); + expect(props.items[0].value).toBe( + `${AuthType.QWEN_OAUTH}::${MAINLINE_CODER}`, + ); + expect(props.items[1].value).toBe( + `${AuthType.QWEN_OAUTH}::${MAINLINE_VLM}`, + ); expect(props.showNumbers).toBe(true); }); @@ -139,16 +164,93 @@ describe('', () => { expect(mockedSelect).toHaveBeenCalledTimes(1); }); - it('calls config.setModel and onClose when DescriptiveRadioButtonSelect.onSelect is triggered', () => { - const { props, mockConfig } = renderComponent({}, {}); // Pass empty object for contextValue + it('calls config.switchModel and onClose when DescriptiveRadioButtonSelect.onSelect is triggered', async () => { + const { props, mockConfig, mockSettings } = renderComponent({}, {}); // Pass empty object for contextValue const childOnSelect = mockedSelect.mock.calls[0][0].onSelect; expect(childOnSelect).toBeDefined(); - childOnSelect(MAINLINE_CODER); + await childOnSelect(`${AuthType.QWEN_OAUTH}::${MAINLINE_CODER}`); - // Assert against the default mock provided by renderComponent - expect(mockConfig?.setModel).toHaveBeenCalledWith(MAINLINE_CODER); + expect(mockConfig?.switchModel).toHaveBeenCalledWith( + AuthType.QWEN_OAUTH, + MAINLINE_CODER, + undefined, + { + reason: 'user_manual', + context: 'Model switched via /model dialog', + }, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.Workspace, + 'model.name', + MAINLINE_CODER, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.Workspace, + 'security.auth.selectedType', + AuthType.QWEN_OAUTH, + ); + expect(props.onClose).toHaveBeenCalledTimes(1); + }); + + it('calls config.switchModel and persists authType+model 
when selecting a different authType', async () => { + const switchModel = vi.fn().mockResolvedValue(undefined); + const getAuthType = vi.fn(() => AuthType.USE_OPENAI); + const getAvailableModelsForAuthType = vi.fn((t: AuthType) => { + if (t === AuthType.USE_OPENAI) { + return [{ id: 'gpt-4', label: 'GPT-4', authType: t }]; + } + if (t === AuthType.QWEN_OAUTH) { + return AVAILABLE_MODELS_QWEN.map((m) => ({ + id: m.id, + label: m.label, + authType: AuthType.QWEN_OAUTH, + })); + } + return []; + }); + + const mockConfigWithSwitchAuthType = { + getAuthType, + getModel: vi.fn(() => 'gpt-4'), + getContentGeneratorConfig: vi.fn(() => ({ + authType: AuthType.QWEN_OAUTH, + model: MAINLINE_CODER, + })), + // Add switchModel to the mock object (not the type) + switchModel, + getAvailableModelsForAuthType, + }; + + const { props, mockSettings } = renderComponent( + {}, + // Cast to Config to bypass type checking, matching the runtime behavior + mockConfigWithSwitchAuthType as unknown as Partial, + ); + + const childOnSelect = mockedSelect.mock.calls[0][0].onSelect; + await childOnSelect(`${AuthType.QWEN_OAUTH}::${MAINLINE_CODER}`); + + expect(switchModel).toHaveBeenCalledWith( + AuthType.QWEN_OAUTH, + MAINLINE_CODER, + { requireCachedCredentials: true }, + { + reason: 'user_manual', + context: 'AuthType+model switched via /model dialog', + }, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.Workspace, + 'model.name', + MAINLINE_CODER, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.Workspace, + 'security.auth.selectedType', + AuthType.QWEN_OAUTH, + ); expect(props.onClose).toHaveBeenCalledTimes(1); }); @@ -193,17 +295,25 @@ describe('', () => { it('updates initialIndex when config context changes', () => { const mockGetModel = vi.fn(() => MAINLINE_CODER); const mockGetAuthType = vi.fn(() => 'qwen-oauth'); + const mockSettings = { + isTrusted: true, + user: { settings: {} }, + workspace: { settings: {} }, + setValue: vi.fn(), + 
} as unknown as LoadedSettings; const { rerender } = render( - - - , + + + + + , ); expect(mockedSelect.mock.calls[0][0].initialIndex).toBe(0); @@ -215,9 +325,11 @@ describe('', () => { } as unknown as Config; rerender( - - - , + + + + + , ); // Should be called at least twice: initial render + re-render after context change diff --git a/packages/cli/src/ui/components/ModelDialog.tsx b/packages/cli/src/ui/components/ModelDialog.tsx index 55b3300bf..b5d39cc46 100644 --- a/packages/cli/src/ui/components/ModelDialog.tsx +++ b/packages/cli/src/ui/components/ModelDialog.tsx @@ -5,52 +5,210 @@ */ import type React from 'react'; -import { useCallback, useContext, useMemo } from 'react'; +import { useCallback, useContext, useMemo, useState } from 'react'; import { Box, Text } from 'ink'; import { AuthType, ModelSlashCommandEvent, logModelSlashCommand, + type ContentGeneratorConfig, + type ContentGeneratorConfigSource, + type ContentGeneratorConfigSources, } from '@qwen-code/qwen-code-core'; import { useKeypress } from '../hooks/useKeypress.js'; import { theme } from '../semantic-colors.js'; import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js'; import { ConfigContext } from '../contexts/ConfigContext.js'; +import { UIStateContext } from '../contexts/UIStateContext.js'; +import { useSettings } from '../contexts/SettingsContext.js'; import { getAvailableModelsForAuthType, MAINLINE_CODER, } from '../models/availableModels.js'; +import { getPersistScopeForModelSelection } from '../../config/modelProvidersScope.js'; import { t } from '../../i18n/index.js'; interface ModelDialogProps { onClose: () => void; } +function formatSourceBadge( + source: ContentGeneratorConfigSource | undefined, +): string | undefined { + if (!source) return undefined; + + switch (source.kind) { + case 'cli': + return source.detail ? `CLI ${source.detail}` : 'CLI'; + case 'env': + return source.envKey ? 
`ENV ${source.envKey}` : 'ENV'; + case 'settings': + return source.settingsPath + ? `Settings ${source.settingsPath}` + : 'Settings'; + case 'modelProviders': { + const suffix = + source.authType && source.modelId + ? `${source.authType}:${source.modelId}` + : source.authType + ? `${source.authType}` + : source.modelId + ? `${source.modelId}` + : ''; + return suffix ? `ModelProviders ${suffix}` : 'ModelProviders'; + } + case 'default': + return source.detail ? `Default ${source.detail}` : 'Default'; + case 'computed': + return source.detail ? `Computed ${source.detail}` : 'Computed'; + case 'programmatic': + return source.detail ? `Programmatic ${source.detail}` : 'Programmatic'; + case 'unknown': + default: + return undefined; + } +} + +function readSourcesFromConfig(config: unknown): ContentGeneratorConfigSources { + if (!config) { + return {}; + } + const maybe = config as { + getContentGeneratorConfigSources?: () => ContentGeneratorConfigSources; + }; + return maybe.getContentGeneratorConfigSources?.() ?? {}; +} + +function maskApiKey(apiKey: string | undefined): string { + if (!apiKey) return '(not set)'; + const trimmed = apiKey.trim(); + if (trimmed.length === 0) return '(not set)'; + if (trimmed.length <= 6) return '***'; + const head = trimmed.slice(0, 3); + const tail = trimmed.slice(-4); + return `${head}…${tail}`; +} + +function persistModelSelection( + settings: ReturnType, + modelId: string, +): void { + const scope = getPersistScopeForModelSelection(settings); + settings.setValue(scope, 'model.name', modelId); +} + +function persistAuthTypeSelection( + settings: ReturnType, + authType: AuthType, +): void { + const scope = getPersistScopeForModelSelection(settings); + settings.setValue(scope, 'security.auth.selectedType', authType); +} + +function ConfigRow({ + label, + value, + badge, +}: { + label: string; + value: React.ReactNode; + badge?: string; +}): React.JSX.Element { + return ( + + + + {label}: + + + {value} + + + {badge ? 
( + + + + + + {badge} + + + ) : null} + + ); +} + export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { const config = useContext(ConfigContext); + const uiState = useContext(UIStateContext); + const settings = useSettings(); + + // Local error state for displaying errors within the dialog + const [errorMessage, setErrorMessage] = useState(null); - // Get auth type from config, default to QWEN_OAUTH if not available const authType = config?.getAuthType() ?? AuthType.QWEN_OAUTH; + const effectiveConfig = + (config?.getContentGeneratorConfig?.() as + | ContentGeneratorConfig + | undefined) ?? undefined; + const sources = readSourcesFromConfig(config); - // Get available models based on auth type - const availableModels = useMemo( - () => getAvailableModelsForAuthType(authType), - [authType], - ); + const availableModelEntries = useMemo(() => { + const allAuthTypes = Object.values(AuthType) as AuthType[]; + const modelsByAuthType = allAuthTypes + .map((t) => ({ + authType: t, + models: getAvailableModelsForAuthType(t, config ?? undefined), + })) + .filter((x) => x.models.length > 0); + + // Fixed order: qwen-oauth first, then others in a stable order + const authTypeOrder: AuthType[] = [ + AuthType.QWEN_OAUTH, + AuthType.USE_OPENAI, + AuthType.USE_ANTHROPIC, + AuthType.USE_GEMINI, + AuthType.USE_VERTEX_AI, + ]; + + // Filter to only include authTypes that have models + const availableAuthTypes = new Set(modelsByAuthType.map((x) => x.authType)); + const orderedAuthTypes = authTypeOrder.filter((t) => + availableAuthTypes.has(t), + ); + + return orderedAuthTypes.flatMap((t) => { + const models = + modelsByAuthType.find((x) => x.authType === t)?.models ?? 
[]; + return models.map((m) => ({ authType: t, model: m })); + }); + }, [config]); const MODEL_OPTIONS = useMemo( () => - availableModels.map((model) => ({ - value: model.id, - title: model.label, - description: model.description || '', - key: model.id, - })), - [availableModels], + availableModelEntries.map(({ authType: t2, model }) => { + const value = `${t2}::${model.id}`; + const title = ( + + + [{t2}] + + {` ${model.label}`} + + ); + const description = model.description || ''; + return { + value, + title, + description, + key: value, + }; + }), + [availableModelEntries], ); - // Determine the Preferred Model (read once when the dialog opens). - const preferredModel = config?.getModel() || MAINLINE_CODER; + const preferredModelId = config?.getModel() || MAINLINE_CODER; + const preferredKey = `${authType}::${preferredModelId}`; useKeypress( (key) => { @@ -61,25 +219,97 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { { isActive: true }, ); - // Calculate the initial index based on the preferred model. const initialIndex = useMemo( - () => MODEL_OPTIONS.findIndex((option) => option.value === preferredModel), - [MODEL_OPTIONS, preferredModel], + () => MODEL_OPTIONS.findIndex((option) => option.value === preferredKey), + [MODEL_OPTIONS, preferredKey], ); - // Handle selection internally (Autonomous Dialog). const handleSelect = useCallback( - (model: string) => { + async (selected: string) => { + // Clear any previous error + setErrorMessage(null); + + const sep = '::'; + const idx = selected.indexOf(sep); + const selectedAuthType = ( + idx >= 0 ? selected.slice(0, idx) : authType + ) as AuthType; + const modelId = idx >= 0 ? selected.slice(idx + sep.length) : selected; + if (config) { - config.setModel(model); - const event = new ModelSlashCommandEvent(model); + try { + await config.switchModel( + selectedAuthType, + modelId, + selectedAuthType !== authType && + selectedAuthType === AuthType.QWEN_OAUTH + ? 
{ requireCachedCredentials: true } + : undefined, + { + reason: 'user_manual', + context: + selectedAuthType === authType + ? 'Model switched via /model dialog' + : 'AuthType+model switched via /model dialog', + }, + ); + } catch (e) { + const baseErrorMessage = e instanceof Error ? e.message : String(e); + + // Some auth types (notably openai without modelProviders configured) can present + // env-based "raw" model IDs in the list. These are not registry-backed and will + // fail switchModel(). Fall back to setModel() to keep UX functional. + const isNotFound = + baseErrorMessage.includes('not found for authType') || + (baseErrorMessage.includes('Model') && + baseErrorMessage.includes('not found')); + if (!isNotFound) { + setErrorMessage( + `Failed to switch model to '${modelId}'.\n\n${baseErrorMessage}`, + ); + + // Keep the dialog open so the user can choose another model. + return; + } + await config.setModel(modelId, { + reason: 'user_manual', + context: 'Model set via /model dialog (raw)', + }); + } + const event = new ModelSlashCommandEvent(modelId); logModelSlashCommand(config, event); + + const after = config.getContentGeneratorConfig?.() as + | ContentGeneratorConfig + | undefined; + const effectiveAuthType = + after?.authType ?? selectedAuthType ?? authType; + const effectiveModelId = after?.model ?? modelId; + + persistModelSelection(settings, effectiveModelId); + persistAuthTypeSelection(settings, effectiveAuthType); + + const baseUrl = after?.baseUrl ?? 
'(default)'; + const maskedKey = maskApiKey(after?.apiKey); + uiState?.historyManager.addItem( + { + type: 'info', + text: + `authType: ${effectiveAuthType}\n` + + `Using model: ${effectiveModelId}\n` + + `Base URL: ${baseUrl}\n` + + `API key: ${maskedKey}`, + }, + Date.now(), + ); } onClose(); }, - [config, onClose], + [authType, config, onClose, settings, uiState, setErrorMessage], ); + const hasModels = MODEL_OPTIONS.length > 0; + return ( {t('Select Model')} - - + + + + {t('Current (effective) configuration')} + + + + + + {authType !== AuthType.QWEN_OAUTH && ( + <> + + + + )} + + {effectiveConfig?.samplingParams ? ( + + ) : null} + + {effectiveConfig?.timeout !== undefined ? ( + + ) : null} + + + {!hasModels ? ( + + + {t( + 'No models available for the current authentication type ({{authType}}).', + { + authType, + }, + )} + + + + {t( + 'Please configure models in settings.modelProviders or use environment variables.', + )} + + + + ) : ( + + + + )} + + {errorMessage && ( + + + ✕ {errorMessage} + + + )} + {t('(Press Esc to close)')} diff --git a/packages/cli/src/ui/components/shared/DescriptiveRadioButtonSelect.tsx b/packages/cli/src/ui/components/shared/DescriptiveRadioButtonSelect.tsx index 3cc563283..89bf4c03b 100644 --- a/packages/cli/src/ui/components/shared/DescriptiveRadioButtonSelect.tsx +++ b/packages/cli/src/ui/components/shared/DescriptiveRadioButtonSelect.tsx @@ -11,7 +11,7 @@ import { BaseSelectionList } from './BaseSelectionList.js'; import type { SelectionListItem } from '../../hooks/useSelectionList.js'; export interface DescriptiveRadioSelectItem extends SelectionListItem { - title: string; + title: React.ReactNode; description: string; } diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index e70ea0538..561c98ed6 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -912,7 +912,7 @@ export const useGeminiStream = ( // Reset quota 
error flag when starting a new query (not a continuation) if (!options?.isContinuation) { setModelSwitchedFromQuotaError(false); - config.setQuotaErrorOccurred(false); + // No quota-error / fallback routing mechanism currently; keep state minimal. } abortControllerRef.current = new AbortController(); diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index d7b2b8109..961b52b24 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -62,7 +62,7 @@ const mockConfig = { getAllowedTools: vi.fn(() => []), getContentGeneratorConfig: () => ({ model: 'test-model', - authType: 'gemini-api-key', + authType: 'gemini', }), getUseSmartEdit: () => false, getUseModelRouter: () => false, diff --git a/packages/cli/src/ui/models/availableModels.test.ts b/packages/cli/src/ui/models/availableModels.test.ts new file mode 100644 index 000000000..feac835c6 --- /dev/null +++ b/packages/cli/src/ui/models/availableModels.test.ts @@ -0,0 +1,205 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + getAvailableModelsForAuthType, + getFilteredQwenModels, + getOpenAIAvailableModelFromEnv, + isVisionModel, + getDefaultVisionModel, + AVAILABLE_MODELS_QWEN, + MAINLINE_VLM, + MAINLINE_CODER, +} from './availableModels.js'; +import { AuthType, type Config } from '@qwen-code/qwen-code-core'; + +describe('availableModels', () => { + describe('AVAILABLE_MODELS_QWEN', () => { + it('should include coder model', () => { + const coderModel = AVAILABLE_MODELS_QWEN.find( + (m) => m.id === MAINLINE_CODER, + ); + expect(coderModel).toBeDefined(); + expect(coderModel?.isVision).toBeFalsy(); + }); + + it('should include vision model', () => { + const visionModel = AVAILABLE_MODELS_QWEN.find( + (m) => m.id === MAINLINE_VLM, + ); + 
expect(visionModel).toBeDefined(); + expect(visionModel?.isVision).toBe(true); + }); + }); + + describe('getFilteredQwenModels', () => { + it('should return all models when vision preview is enabled', () => { + const models = getFilteredQwenModels(true); + expect(models.length).toBe(AVAILABLE_MODELS_QWEN.length); + }); + + it('should filter out vision models when preview is disabled', () => { + const models = getFilteredQwenModels(false); + expect(models.every((m) => !m.isVision)).toBe(true); + }); + }); + + describe('getOpenAIAvailableModelFromEnv', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it('should return null when OPENAI_MODEL is not set', () => { + delete process.env['OPENAI_MODEL']; + expect(getOpenAIAvailableModelFromEnv()).toBeNull(); + }); + + it('should return model from OPENAI_MODEL env var', () => { + process.env['OPENAI_MODEL'] = 'gpt-4-turbo'; + const model = getOpenAIAvailableModelFromEnv(); + expect(model?.id).toBe('gpt-4-turbo'); + expect(model?.label).toBe('gpt-4-turbo'); + }); + + it('should trim whitespace from env var', () => { + process.env['OPENAI_MODEL'] = ' gpt-4 '; + const model = getOpenAIAvailableModelFromEnv(); + expect(model?.id).toBe('gpt-4'); + }); + }); + + describe('getAvailableModelsForAuthType', () => { + const originalEnv = process.env; + + beforeEach(() => { + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it('should return hard-coded qwen models for qwen-oauth', () => { + const models = getAvailableModelsForAuthType(AuthType.QWEN_OAUTH); + expect(models).toEqual(AVAILABLE_MODELS_QWEN); + }); + + it('should return hard-coded qwen models even when config is provided', () => { + const mockConfig = { + getAvailableModels: vi + .fn() + .mockReturnValue([ + { id: 'custom', label: 'Custom', authType: AuthType.QWEN_OAUTH }, + ]), + } as unknown as Config; 
+ + const models = getAvailableModelsForAuthType( + AuthType.QWEN_OAUTH, + mockConfig, + ); + expect(models).toEqual(AVAILABLE_MODELS_QWEN); + }); + + it('should use config.getAvailableModels for openai authType when available', () => { + const mockModels = [ + { + id: 'gpt-4', + label: 'GPT-4', + description: 'Test', + authType: AuthType.USE_OPENAI, + isVision: false, + }, + ]; + const getAvailableModelsForAuthType = vi.fn().mockReturnValue(mockModels); + const mockConfigWithMethod = { + // Prefer the newer API when available. + getAvailableModelsForAuthType, + }; + + const models = getAvailableModelsForAuthType( + AuthType.USE_OPENAI, + mockConfigWithMethod as unknown as Config, + ); + + expect(getAvailableModelsForAuthType).toHaveBeenCalled(); + expect(models[0].id).toBe('gpt-4'); + }); + + it('should fallback to env var for openai when config returns empty', () => { + process.env['OPENAI_MODEL'] = 'fallback-model'; + const mockConfig = { + getAvailableModelsForAuthType: vi.fn().mockReturnValue([]), + } as unknown as Config; + + const models = getAvailableModelsForAuthType( + AuthType.USE_OPENAI, + mockConfig, + ); + + expect(models).toEqual([]); + }); + + it('should fallback to env var for openai when config throws', () => { + process.env['OPENAI_MODEL'] = 'fallback-model'; + const mockConfig = { + getAvailableModelsForAuthType: vi.fn().mockImplementation(() => { + throw new Error('Registry not initialized'); + }), + } as unknown as Config; + + const models = getAvailableModelsForAuthType( + AuthType.USE_OPENAI, + mockConfig, + ); + + expect(models).toEqual([]); + }); + + it('should return env model for openai without config', () => { + process.env['OPENAI_MODEL'] = 'gpt-4-turbo'; + const models = getAvailableModelsForAuthType(AuthType.USE_OPENAI); + expect(models[0].id).toBe('gpt-4-turbo'); + }); + + it('should return empty array for openai without config or env', () => { + delete process.env['OPENAI_MODEL']; + const models = 
getAvailableModelsForAuthType(AuthType.USE_OPENAI); + expect(models).toEqual([]); + }); + + it('should return empty array for other auth types', () => { + const models = getAvailableModelsForAuthType(AuthType.USE_GEMINI); + expect(models).toEqual([]); + }); + }); + + describe('isVisionModel', () => { + it('should return true for vision model', () => { + expect(isVisionModel(MAINLINE_VLM)).toBe(true); + }); + + it('should return false for non-vision model', () => { + expect(isVisionModel(MAINLINE_CODER)).toBe(false); + }); + + it('should return false for unknown model', () => { + expect(isVisionModel('unknown-model')).toBe(false); + }); + }); + + describe('getDefaultVisionModel', () => { + it('should return the vision model ID', () => { + expect(getDefaultVisionModel()).toBe(MAINLINE_VLM); + }); + }); +}); diff --git a/packages/cli/src/ui/models/availableModels.ts b/packages/cli/src/ui/models/availableModels.ts index d9c9eb725..1cff984c8 100644 --- a/packages/cli/src/ui/models/availableModels.ts +++ b/packages/cli/src/ui/models/availableModels.ts @@ -4,7 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AuthType, DEFAULT_QWEN_MODEL } from '@qwen-code/qwen-code-core'; +import { + AuthType, + DEFAULT_QWEN_MODEL, + type Config, + type AvailableModel as CoreAvailableModel, +} from '@qwen-code/qwen-code-core'; import { t } from '../../i18n/index.js'; export type AvailableModel = { @@ -57,20 +62,78 @@ export function getFilteredQwenModels( */ export function getOpenAIAvailableModelFromEnv(): AvailableModel | null { const id = process.env['OPENAI_MODEL']?.trim(); - return id ? { id, label: id } : null; + return id + ? { + id, + label: id, + get description() { + return t('Configured via OPENAI_MODEL environment variable'); + }, + } + : null; } export function getAnthropicAvailableModelFromEnv(): AvailableModel | null { const id = process.env['ANTHROPIC_MODEL']?.trim(); - return id ? { id, label: id } : null; + return id + ? 
{ + id, + label: id, + get description() { + return t('Configured via ANTHROPIC_MODEL environment variable'); + }, + } + : null; } +/** + * Convert core AvailableModel to CLI AvailableModel format + */ +function convertCoreModelToCliModel( + coreModel: CoreAvailableModel, +): AvailableModel { + return { + id: coreModel.id, + label: coreModel.label, + description: coreModel.description, + isVision: coreModel.isVision ?? coreModel.capabilities?.vision ?? false, + }; +} + +/** + * Get available models for the given authType. + * + * If a Config object is provided, uses config.getAvailableModelsForAuthType(). + * For qwen-oauth, always returns the hard-coded models. + * Falls back to environment variables only when no config is provided. + */ export function getAvailableModelsForAuthType( authType: AuthType, + config?: Config, ): AvailableModel[] { + // For qwen-oauth, always use hard-coded models, this aligns with the API gateway. + if (authType === AuthType.QWEN_OAUTH) { + return AVAILABLE_MODELS_QWEN; + } + + // Use config's model registry when available + if (config) { + try { + const models = config.getAvailableModelsForAuthType(authType); + if (models.length > 0) { + return models.map(convertCoreModelToCliModel); + } + } catch { + // If config throws (e.g., not initialized), return empty array + } + // When a Config object is provided, we intentionally do NOT fall back to env-based + // "raw" models. These may reflect the currently effective config but should not be + // presented as selectable options in /model. + return []; + } + + // Fall back to environment variables for specific auth types (no config provided) switch (authType) { - case AuthType.QWEN_OAUTH: - return AVAILABLE_MODELS_QWEN; case AuthType.USE_OPENAI: { const openAIModel = getOpenAIAvailableModelFromEnv(); return openAIModel ? [openAIModel] : []; @@ -80,13 +143,10 @@ export function getAvailableModelsForAuthType( return anthropicModel ? 
[anthropicModel] : []; } default: - // For other auth types, return empty array for now - // This can be expanded later according to the design doc return []; } } -/** /** * Hard code the default vision model as a string literal, * until our coding model supports multimodal. diff --git a/packages/cli/src/utils/modelConfigUtils.ts b/packages/cli/src/utils/modelConfigUtils.ts new file mode 100644 index 000000000..cb710692a --- /dev/null +++ b/packages/cli/src/utils/modelConfigUtils.ts @@ -0,0 +1,112 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AuthType, + type ContentGeneratorConfig, + type ContentGeneratorConfigSources, + resolveModelConfig, + type ModelConfigSourcesInput, +} from '@qwen-code/qwen-code-core'; +import type { Settings } from '../config/settings.js'; + +export interface CliGenerationConfigInputs { + argv: { + model?: string | undefined; + openaiApiKey?: string | undefined; + openaiBaseUrl?: string | undefined; + openaiLogging?: boolean | undefined; + openaiLoggingDir?: string | undefined; + }; + settings: Settings; + selectedAuthType: AuthType | undefined; + /** + * Injectable env for testability. Defaults to process.env at callsites. + */ + env?: Record; +} + +export interface ResolvedCliGenerationConfig { + /** The resolved model id (may be empty string if not resolvable at CLI layer) */ + model: string; + /** API key for OpenAI-compatible auth */ + apiKey: string; + /** Base URL for OpenAI-compatible auth */ + baseUrl: string; + /** The full generation config to pass to core Config */ + generationConfig: Partial; + /** Source attribution for each resolved field */ + sources: ContentGeneratorConfigSources; +} + +/** + * Unified resolver for CLI generation config. 
+ * + * Precedence (for OpenAI auth): + * - model: argv.model > OPENAI_MODEL > QWEN_MODEL > settings.model.name + * - apiKey: argv.openaiApiKey > OPENAI_API_KEY > settings.security.auth.apiKey + * - baseUrl: argv.openaiBaseUrl > OPENAI_BASE_URL > settings.security.auth.baseUrl + * + * For non-OpenAI auth, only argv.model override is respected at CLI layer. + */ +export function resolveCliGenerationConfig( + inputs: CliGenerationConfigInputs, +): ResolvedCliGenerationConfig { + const { argv, settings, selectedAuthType } = inputs; + const env = inputs.env ?? (process.env as Record); + + const authType = selectedAuthType ?? AuthType.QWEN_OAUTH; + + const configSources: ModelConfigSourcesInput = { + authType, + cli: { + model: argv.model, + apiKey: argv.openaiApiKey, + baseUrl: argv.openaiBaseUrl, + }, + settings: { + model: settings.model?.name, + apiKey: settings.security?.auth?.apiKey, + baseUrl: settings.security?.auth?.baseUrl, + generationConfig: settings.model?.generationConfig as + | Partial + | undefined, + }, + env, + }; + + const resolved = resolveModelConfig(configSources); + + // Log warnings if any + for (const warning of resolved.warnings) { + console.warn(`[modelProviderUtils] ${warning}`); + } + + // Resolve OpenAI logging config (CLI-specific, not part of core resolver) + const enableOpenAILogging = + (typeof argv.openaiLogging === 'undefined' + ? settings.model?.enableOpenAILogging + : argv.openaiLogging) ?? 
false; + + const openAILoggingDir = + argv.openaiLoggingDir || settings.model?.openAILoggingDir; + + // Build the full generation config + // Note: we merge the resolved config with logging settings + const generationConfig: Partial = { + ...resolved.config, + enableOpenAILogging, + openAILoggingDir, + }; + + return { + model: resolved.config.model || '', + apiKey: resolved.config.apiKey || '', + baseUrl: resolved.config.baseUrl || '', + generationConfig, + sources: resolved.sources, + }; +} diff --git a/packages/cli/src/utils/modelProviderUtils.ts b/packages/cli/src/utils/modelProviderUtils.ts deleted file mode 100644 index 473499771..000000000 --- a/packages/cli/src/utils/modelProviderUtils.ts +++ /dev/null @@ -1,142 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - AuthType, - type ContentGeneratorConfig, - type ContentGeneratorConfigSource, - type ContentGeneratorConfigSources, - type ModelProvidersConfig, - type ProviderModelConfig as ModelConfig, -} from '@qwen-code/qwen-code-core'; -import type { Settings } from '../config/settings.js'; - -export interface GenerationConfigSourceInputs { - argv: { - model?: string | undefined; - openaiApiKey?: string | undefined; - openaiBaseUrl?: string | undefined; - }; - settings: Settings; - selectedAuthType: AuthType | undefined; - /** - * Injectable env for testability. Defaults to process.env at callsites. - */ - env?: Record; -} - -/** - * Get models configuration from settings, grouped by authType. - * Returns the models config from the merged settings without mutating files. - */ -export function getModelProvidersConfigFromSettings( - settings: Settings, -): ModelProvidersConfig { - return (settings.modelProviders as ModelProvidersConfig) || {}; -} - -/** - * Get models for a specific authType from settings. 
- */ -export function getModelsForAuthType( - settings: Settings, - authType: AuthType, -): ModelConfig[] { - const modelProvidersConfig = getModelProvidersConfigFromSettings(settings); - return modelProvidersConfig[authType] || []; -} - -/** - * Best-effort attribution for the seed generationConfig fields. - * - * NOTE: - * - This does not attempt to distinguish user vs workspace settings; it reflects merged settings. - * - This should stay consistent with the actual precedence used to compute the corresponding values. - */ -export function buildGenerationConfigSources( - inputs: GenerationConfigSourceInputs, -): ContentGeneratorConfigSources { - const { argv, settings, selectedAuthType } = inputs; - const env = inputs.env ?? (process.env as Record); - - const sources: ContentGeneratorConfigSources = {}; - - const setSource = (path: string, source: ContentGeneratorConfigSource) => { - sources[path] = source; - }; - - // Model/apiKey/baseUrl attribution mirrors current CLI precedence: - // - model: argv.model > (OPENAI_MODEL|QWEN_MODEL|settings.model.name) only for OpenAI auth - // - apiKey/baseUrl: only meaningful for OpenAI auth in current CLI wiring - if (selectedAuthType === AuthType.USE_OPENAI) { - if (argv.model) { - setSource('model', { kind: 'cli', detail: '--model' }); - } else if (env['OPENAI_MODEL']) { - setSource('model', { kind: 'env', envKey: 'OPENAI_MODEL' }); - } else if (env['QWEN_MODEL']) { - setSource('model', { kind: 'env', envKey: 'QWEN_MODEL' }); - } else if (settings.model?.name) { - setSource('model', { kind: 'settings', settingsPath: 'model.name' }); - } - - if (argv.openaiApiKey) { - setSource('apiKey', { kind: 'cli', detail: '--openaiApiKey' }); - } else if (env['OPENAI_API_KEY']) { - setSource('apiKey', { kind: 'env', envKey: 'OPENAI_API_KEY' }); - } else if (settings.security?.auth?.apiKey) { - setSource('apiKey', { - kind: 'settings', - settingsPath: 'security.auth.apiKey', - }); - } - - if (argv.openaiBaseUrl) { - setSource('baseUrl', 
{ kind: 'cli', detail: '--openaiBaseUrl' }); - } else if (env['OPENAI_BASE_URL']) { - setSource('baseUrl', { kind: 'env', envKey: 'OPENAI_BASE_URL' }); - } else if (settings.security?.auth?.baseUrl) { - setSource('baseUrl', { - kind: 'settings', - settingsPath: 'security.auth.baseUrl', - }); - } - } else if (argv.model) { - // For non-openai auth types, the CLI only wires through an explicit raw model override. - setSource('model', { kind: 'cli', detail: '--model' }); - } - - const mergedGenerationConfig = settings.model?.generationConfig as - | Partial - | undefined; - if (mergedGenerationConfig) { - setSource('generationConfig', { - kind: 'settings', - settingsPath: 'model.generationConfig', - }); - // We also map the known top-level fields used by core. - if (mergedGenerationConfig.samplingParams) { - setSource('samplingParams', { - kind: 'settings', - settingsPath: 'model.generationConfig.samplingParams', - }); - } - for (const k of [ - 'timeout', - 'maxRetries', - 'disableCacheControl', - 'schemaCompliance', - ] as const) { - if (mergedGenerationConfig[k] !== undefined) { - setSource(k, { - kind: 'settings', - settingsPath: `model.generationConfig.${k}`, - }); - } - } - } - - return sources; -} diff --git a/packages/core/index.ts b/packages/core/index.ts index 3227199e4..aab675a18 100644 --- a/packages/core/index.ts +++ b/packages/core/index.ts @@ -8,12 +8,8 @@ export * from './src/index.js'; export { Storage } from './src/config/storage.js'; export { DEFAULT_QWEN_MODEL, + DEFAULT_QWEN_FLASH_MODEL, DEFAULT_QWEN_EMBEDDING_MODEL, - DEFAULT_GEMINI_MODEL, - DEFAULT_GEMINI_MODEL_AUTO, - DEFAULT_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_FLASH_LITE_MODEL, - DEFAULT_GEMINI_EMBEDDING_MODEL, } from './src/config/models.js'; export { serializeTerminalToObject, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 1b163b9a6..449add116 100644 --- a/packages/core/src/config/config.test.ts +++ 
b/packages/core/src/config/config.test.ts @@ -15,10 +15,16 @@ import { DEFAULT_OTLP_ENDPOINT, QwenLogger, } from '../telemetry/index.js'; -import type { ContentGeneratorConfig } from '../core/contentGenerator.js'; +import type { + ContentGenerator, + ContentGeneratorConfig, +} from '../core/contentGenerator.js'; +import { DEFAULT_DASHSCOPE_BASE_URL } from '../core/openaiContentGenerator/constants.js'; import { AuthType, + createContentGenerator, createContentGeneratorConfig, + resolveContentGeneratorConfigWithSources, } from '../core/contentGenerator.js'; import { GeminiClient } from '../core/client.js'; import { GitService } from '../services/gitService.js'; @@ -208,6 +214,19 @@ describe('Server Config (config.ts)', () => { vi.spyOn(QwenLogger.prototype, 'logStartSessionEvent').mockImplementation( async () => undefined, ); + + // Setup default mock for resolveContentGeneratorConfigWithSources + vi.mocked(resolveContentGeneratorConfigWithSources).mockImplementation( + (_config, authType, generationConfig) => ({ + config: { + ...generationConfig, + authType, + model: generationConfig?.model || MODEL, + apiKey: 'test-key', + } as ContentGeneratorConfig, + sources: {}, + }), + ); }); describe('initialize', () => { @@ -255,31 +274,28 @@ describe('Server Config (config.ts)', () => { const mockContentConfig = { apiKey: 'test-key', model: 'qwen3-coder-plus', + authType, }; - vi.mocked(createContentGeneratorConfig).mockReturnValue( - mockContentConfig, - ); - - // Set fallback mode to true to ensure it gets reset - config.setFallbackMode(true); - expect(config.isInFallbackMode()).toBe(true); + vi.mocked(resolveContentGeneratorConfigWithSources).mockReturnValue({ + config: mockContentConfig as ContentGeneratorConfig, + sources: {}, + }); await config.refreshAuth(authType); - expect(createContentGeneratorConfig).toHaveBeenCalledWith( + expect(resolveContentGeneratorConfigWithSources).toHaveBeenCalledWith( config, authType, - { + expect.objectContaining({ model: MODEL, - 
baseUrl: undefined, - }, + }), + expect.anything(), + expect.anything(), ); // Verify that contentGeneratorConfig is updated expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); expect(GeminiClient).toHaveBeenCalledWith(config); - // Verify that fallback mode is reset - expect(config.isInFallbackMode()).toBe(false); }); it('should not strip thoughts when switching from Vertex to GenAI', async () => { @@ -300,6 +316,129 @@ describe('Server Config (config.ts)', () => { }); }); + describe('model switching optimization (QWEN_OAUTH)', () => { + it('should switch qwen-oauth model in-place without refreshing auth when safe', async () => { + const config = new Config(baseParams); + + const mockContentConfig: ContentGeneratorConfig = { + authType: AuthType.QWEN_OAUTH, + model: 'coder-model', + apiKey: 'QWEN_OAUTH_DYNAMIC_TOKEN', + baseUrl: DEFAULT_DASHSCOPE_BASE_URL, + timeout: 60000, + maxRetries: 3, + } as ContentGeneratorConfig; + + vi.mocked(resolveContentGeneratorConfigWithSources).mockImplementation( + (_config, authType, generationConfig) => ({ + config: { + ...mockContentConfig, + authType, + model: generationConfig?.model ?? mockContentConfig.model, + } as ContentGeneratorConfig, + sources: {}, + }), + ); + vi.mocked(createContentGenerator).mockResolvedValue({ + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + } as unknown as ContentGenerator); + + // Establish initial qwen-oauth content generator config/content generator. + await config.refreshAuth(AuthType.QWEN_OAUTH); + + // Spy after initial refresh to ensure model switch does not re-trigger refreshAuth. + const refreshSpy = vi.spyOn(config, 'refreshAuth'); + + await config.switchModel(AuthType.QWEN_OAUTH, 'vision-model'); + + expect(config.getModel()).toBe('vision-model'); + expect(refreshSpy).not.toHaveBeenCalled(); + // Called once during initial refreshAuth + once during handleModelChange diffing. 
+ expect( + vi.mocked(resolveContentGeneratorConfigWithSources), + ).toHaveBeenCalledTimes(2); + expect(vi.mocked(createContentGenerator)).toHaveBeenCalledTimes(1); + }); + }); + + describe('model switching with different credentials (OpenAI)', () => { + it('should refresh auth when switching to model with different envKey', async () => { + // This test verifies the fix for switching between modelProvider models + // with different envKeys (e.g., deepseek-chat with DEEPSEEK_API_KEY) + const configWithModelProviders = new Config({ + ...baseParams, + authType: AuthType.USE_OPENAI, + modelProvidersConfig: { + openai: [ + { + id: 'model-a', + name: 'Model A', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_A', + }, + { + id: 'model-b', + name: 'Model B', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_B', + }, + ], + }, + }); + + const mockContentConfigA: ContentGeneratorConfig = { + authType: AuthType.USE_OPENAI, + model: 'model-a', + apiKey: 'key-a', + baseUrl: 'https://api.example.com/v1', + } as ContentGeneratorConfig; + + const mockContentConfigB: ContentGeneratorConfig = { + authType: AuthType.USE_OPENAI, + model: 'model-b', + apiKey: 'key-b', + baseUrl: 'https://api.example.com/v1', + } as ContentGeneratorConfig; + + vi.mocked(resolveContentGeneratorConfigWithSources).mockImplementation( + (_config, _authType, generationConfig) => { + const model = generationConfig?.model; + return { + config: + model === 'model-b' ? 
mockContentConfigB : mockContentConfigA, + sources: {}, + }; + }, + ); + + vi.mocked(createContentGenerator).mockResolvedValue({ + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + } as unknown as ContentGenerator); + + // Initialize with model-a + await configWithModelProviders.refreshAuth(AuthType.USE_OPENAI); + + // Spy on refreshAuth to verify it's called when switching to model-b + const refreshSpy = vi.spyOn(configWithModelProviders, 'refreshAuth'); + + // Switch to model-b (different envKey) + await configWithModelProviders.switchModel( + AuthType.USE_OPENAI, + 'model-b', + ); + + // Should trigger full refresh because envKey changed + expect(refreshSpy).toHaveBeenCalledWith(AuthType.USE_OPENAI); + expect(configWithModelProviders.getModel()).toBe('model-b'); + }); + }); + it('Config constructor should store userMemory correctly', () => { const config = new Config(baseParams); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 34dbb4649..eae1dd44b 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -16,9 +16,8 @@ import { ProxyAgent, setGlobalDispatcher } from 'undici'; import type { ContentGenerator, ContentGeneratorConfig, - AuthType, } from '../core/contentGenerator.js'; -import type { FallbackModelHandler } from '../fallback/types.js'; +import type { ContentGeneratorConfigSources } from '../core/contentGenerator.js'; import type { MCPOAuthConfig } from '../mcp/oauth-provider.js'; import type { ShellExecutionConfig } from '../services/shellExecutionService.js'; import type { AnyToolInvocation } from '../tools/tools.js'; @@ -27,8 +26,9 @@ import type { AnyToolInvocation } from '../tools/tools.js'; import { BaseLlmClient } from '../core/baseLlmClient.js'; import { GeminiClient } from '../core/client.js'; import { + AuthType, createContentGenerator, - createContentGeneratorConfig, + 
resolveContentGeneratorConfigWithSources, } from '../core/contentGenerator.js'; import { tokenLimit } from '../core/tokenLimits.js'; @@ -94,7 +94,7 @@ import { DEFAULT_FILE_FILTERING_OPTIONS, DEFAULT_MEMORY_FILE_FILTERING_OPTIONS, } from './constants.js'; -import { DEFAULT_QWEN_EMBEDDING_MODEL, DEFAULT_QWEN_MODEL } from './models.js'; +import { DEFAULT_QWEN_EMBEDDING_MODEL } from './models.js'; import { Storage } from './storage.js'; import { ChatRecordingService } from '../services/chatRecordingService.js'; import { @@ -103,6 +103,12 @@ import { } from '../services/sessionService.js'; import { randomUUID } from 'node:crypto'; +import { + ModelsConfig, + type ModelProvidersConfig, + type AvailableModel, +} from '../models/index.js'; + // Re-export types export type { AnyToolInvocation, FileFilteringOptions, MCPOAuthConfig }; export { @@ -318,6 +324,11 @@ export interface ConfigParameters { ideMode?: boolean; authType?: AuthType; generationConfig?: Partial; + /** + * Optional source map for generationConfig fields (e.g. CLI/env/settings attribution). + * This is used to produce per-field source badges in the UI. 
+ */ + generationConfigSources?: ContentGeneratorConfigSources; cliVersion?: string; loadMemoryFromIncludeDirectories?: boolean; chatRecording?: boolean; @@ -353,6 +364,8 @@ export interface ConfigParameters { sdkMode?: boolean; sessionSubagents?: SubagentConfig[]; channel?: string; + /** Model providers configuration grouped by authType */ + modelProvidersConfig?: ModelProvidersConfig; } function normalizeConfigOutputFormat( @@ -394,9 +407,12 @@ export class Config { private skillManager!: SkillManager; private fileSystemService: FileSystemService; private contentGeneratorConfig!: ContentGeneratorConfig; + private contentGeneratorConfigSources: ContentGeneratorConfigSources = {}; private contentGenerator!: ContentGenerator; - private _generationConfig: Partial; private readonly embeddingModel: string; + + private _modelsConfig!: ModelsConfig; + private readonly modelProvidersConfig?: ModelProvidersConfig; private readonly sandbox: SandboxConfig | undefined; private readonly targetDir: string; private workspaceContext: WorkspaceContext; @@ -445,7 +461,6 @@ export class Config { private readonly folderTrust: boolean; private ideMode: boolean; - private inFallbackMode = false; private readonly maxSessionTurns: number; private readonly sessionTokenLimit: number; private readonly listExtensions: boolean; @@ -454,8 +469,6 @@ export class Config { name: string; extensionName: string; }>; - fallbackModelHandler?: FallbackModelHandler; - private quotaErrorOccurred: boolean = false; private readonly summarizeToolOutput: | Record | undefined; @@ -570,13 +583,7 @@ export class Config { this.folderTrustFeature = params.folderTrustFeature ?? false; this.folderTrust = params.folderTrust ?? false; this.ideMode = params.ideMode ?? 
false; - this._generationConfig = { - model: params.model, - ...(params.generationConfig || {}), - baseUrl: params.generationConfig?.baseUrl, - }; - this.contentGeneratorConfig = this - ._generationConfig as ContentGeneratorConfig; + this.modelProvidersConfig = params.modelProvidersConfig; this.cliVersion = params.cliVersion; this.chatRecordingEnabled = params.chatRecording ?? true; @@ -619,6 +626,23 @@ export class Config { setGeminiMdFilename(params.contextFileName); } + // Create ModelsConfig for centralized model management + // Prefer params.authType over generationConfig.authType because: + // - params.authType preserves undefined (user hasn't selected yet) + // - generationConfig.authType may have a default value from resolvers + this._modelsConfig = new ModelsConfig({ + initialAuthType: params.authType ?? params.generationConfig?.authType, + initialModelId: params.model, + modelProvidersConfig: this.modelProvidersConfig, + generationConfig: { + model: params.model, + ...(params.generationConfig || {}), + baseUrl: params.generationConfig?.baseUrl, + }, + generationConfigSources: params.generationConfigSources, + onModelChange: this.handleModelChange.bind(this), + }); + if (this.telemetrySettings.enabled) { initializeTelemetry(this); } @@ -669,45 +693,61 @@ export class Config { return this.contentGenerator; } + /** + * Get the ModelsConfig instance for model-related operations. + * External code (e.g., CLI) can use this to access model configuration. + */ + get modelsConfig(): ModelsConfig { + return this._modelsConfig; + } + /** * Updates the credentials in the generation config. - * This is needed when credentials are set after Config construction. + * Exclusive for `OpenAIKeyPrompt` to update credentials via `/auth` + * Delegates to ModelsConfig. 
*/ updateCredentials(credentials: { apiKey?: string; baseUrl?: string; model?: string; }): void { - if (credentials.apiKey) { - this._generationConfig.apiKey = credentials.apiKey; - } - if (credentials.baseUrl) { - this._generationConfig.baseUrl = credentials.baseUrl; - } - if (credentials.model) { - this._generationConfig.model = credentials.model; - } + this._modelsConfig.updateCredentials(credentials); } + /** + * Refresh authentication and rebuild ContentGenerator. + */ async refreshAuth(authMethod: AuthType, isInitialAuth?: boolean) { - const newContentGeneratorConfig = createContentGeneratorConfig( + // Sync modelsConfig state for this auth refresh + const modelId = this._modelsConfig.getModel(); + this._modelsConfig.syncAfterAuthRefresh(authMethod, modelId); + + // Check and consume cached credentials flag + const requireCached = + this._modelsConfig.consumeRequireCachedCredentialsFlag(); + + const { config, sources } = resolveContentGeneratorConfigWithSources( this, authMethod, - this._generationConfig, + this._modelsConfig.getGenerationConfig(), + this._modelsConfig.getGenerationConfigSources(), + { + strictModelProvider: + this._modelsConfig.isStrictModelProviderSelection(), + }, ); + const newContentGeneratorConfig = config; this.contentGenerator = await createContentGenerator( newContentGeneratorConfig, this, - isInitialAuth, + requireCached ? 
true : isInitialAuth, ); // Only assign to instance properties after successful initialization this.contentGeneratorConfig = newContentGeneratorConfig; + this.contentGeneratorConfigSources = sources; // Initialize BaseLlmClient now that the ContentGenerator is available this.baseLlmClient = new BaseLlmClient(this.contentGenerator, this); - - // Reset the session flag since we're explicitly changing auth and using default model - this.inFallbackMode = false; } /** @@ -767,31 +807,125 @@ export class Config { return this.contentGeneratorConfig; } - getModel(): string { - return this.contentGeneratorConfig?.model || DEFAULT_QWEN_MODEL; + getContentGeneratorConfigSources(): ContentGeneratorConfigSources { + // If contentGeneratorConfigSources is empty (before initializeAuth), + // get sources from ModelsConfig + if ( + Object.keys(this.contentGeneratorConfigSources).length === 0 && + this._modelsConfig + ) { + return this._modelsConfig.getGenerationConfigSources(); + } + return this.contentGeneratorConfigSources; } + getModel(): string { + return this.contentGeneratorConfig?.model || this._modelsConfig.getModel(); + } + + /** + * Set model programmatically (e.g., VLM auto-switch, fallback). + * Delegates to ModelsConfig. + */ async setModel( newModel: string, - _metadata?: { reason?: string; context?: string }, + metadata?: { reason?: string; context?: string }, ): Promise { + await this._modelsConfig.setModel(newModel, metadata); + // Also update contentGeneratorConfig for hot-update compatibility if (this.contentGeneratorConfig) { this.contentGeneratorConfig.model = newModel; } - // TODO: Log _metadata for telemetry if needed - // This _metadata can be used for tracking model switches (reason, context) } - isInFallbackMode(): boolean { - return this.inFallbackMode; + /** + * Handle model change from ModelsConfig. + * This updates the content generator config with the new model settings. 
+ */ + private async handleModelChange( + authType: AuthType, + requiresRefresh: boolean, + ): Promise { + if (!this.contentGeneratorConfig) { + return; + } + + // Hot update path: only supported for qwen-oauth. + // For other auth types we always refresh to recreate the ContentGenerator. + // + // Rationale: + // - Non-qwen providers may need to re-validate credentials / baseUrl / envKey. + // - ModelsConfig.applyResolvedModelDefaults can clear or change credentials sources. + // - Refresh keeps runtime behavior consistent and centralized. + if (authType === AuthType.QWEN_OAUTH && !requiresRefresh) { + const { config, sources } = resolveContentGeneratorConfigWithSources( + this, + authType, + this._modelsConfig.getGenerationConfig(), + this._modelsConfig.getGenerationConfigSources(), + { + strictModelProvider: + this._modelsConfig.isStrictModelProviderSelection(), + }, + ); + + // Hot-update fields (qwen-oauth models share the same auth + client). + this.contentGeneratorConfig.model = config.model; + this.contentGeneratorConfig.samplingParams = config.samplingParams; + this.contentGeneratorConfig.disableCacheControl = + config.disableCacheControl; + + if ('model' in sources) { + this.contentGeneratorConfigSources['model'] = sources['model']; + } + if ('samplingParams' in sources) { + this.contentGeneratorConfigSources['samplingParams'] = + sources['samplingParams']; + } + if ('disableCacheControl' in sources) { + this.contentGeneratorConfigSources['disableCacheControl'] = + sources['disableCacheControl']; + } + return; + } + + // Full refresh path + await this.refreshAuth(authType); } - setFallbackMode(active: boolean): void { - this.inFallbackMode = active; + /** + * Get available models for the current authType. + * Delegates to ModelsConfig. 
+ */ + getAvailableModels(): AvailableModel[] { + return this._modelsConfig.getAvailableModels(); } - setFallbackModelHandler(handler: FallbackModelHandler): void { - this.fallbackModelHandler = handler; + /** + * Get available models for a specific authType. + * Delegates to ModelsConfig. + */ + getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] { + return this._modelsConfig.getAvailableModelsForAuthType(authType); + } + + /** + * Switch authType+model via registry-backed selection. + * This triggers a refresh of the ContentGenerator when required (always on authType changes). + * For qwen-oauth model switches that are hot-update safe, this may update in place. + * + * @param authType - Target authentication type + * @param modelId - Target model ID + * @param options - Additional options like requireCachedCredentials + * @param metadata - Metadata for logging/tracking + */ + async switchModel( + authType: AuthType, + modelId: string, + options?: { requireCachedCredentials?: boolean }, + metadata?: { reason?: string; context?: string }, + ): Promise { + await this._modelsConfig.switchModel(authType, modelId, options, metadata); } getMaxSessionTurns(): number { @@ -802,14 +936,6 @@ export class Config { return this.sessionTokenLimit; } - setQuotaErrorOccurred(value: boolean): void { - this.quotaErrorOccurred = value; - } - - getQuotaErrorOccurred(): boolean { - return this.quotaErrorOccurred; - } - getEmbeddingModel(): string { return this.embeddingModel; } diff --git a/packages/core/src/config/flashFallback.test.ts b/packages/core/src/config/flashFallback.test.ts deleted file mode 100644 index 1fff42392..000000000 --- a/packages/core/src/config/flashFallback.test.ts +++ /dev/null @@ -1,100 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, it, expect, beforeEach, vi } from 'vitest'; -import { Config } from './config.js'; -import { DEFAULT_GEMINI_MODEL, 
DEFAULT_GEMINI_FLASH_MODEL } from './models.js'; -import fs from 'node:fs'; - -vi.mock('node:fs'); - -// Skip this test because we do not have fall back mechanism. -describe.skip('Flash Model Fallback Configuration', () => { - let config: Config; - - beforeEach(() => { - vi.mocked(fs.existsSync).mockReturnValue(true); - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); - config = new Config({ - targetDir: '/test', - debugMode: false, - cwd: '/test', - model: DEFAULT_GEMINI_MODEL, - }); - - // Initialize contentGeneratorConfig for testing - ( - config as unknown as { contentGeneratorConfig: unknown } - ).contentGeneratorConfig = { - model: DEFAULT_GEMINI_MODEL, - authType: 'gemini', - }; - }); - - // These tests do not actually test fallback. isInFallbackMode() only returns true, - // when setFallbackMode is marked as true. This is to decouple setting a model - // with the fallback mechanism. This will be necessary we introduce more - // intelligent model routing. 
- describe('setModel', () => { - it('should only mark as switched if contentGeneratorConfig exists', async () => { - // Create config without initializing contentGeneratorConfig - const newConfig = new Config({ - targetDir: '/test', - debugMode: false, - cwd: '/test', - model: DEFAULT_GEMINI_MODEL, - }); - - // Should not crash when contentGeneratorConfig is undefined - await newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL); - expect(newConfig.isInFallbackMode()).toBe(false); - }); - }); - - describe('getModel', () => { - it('should return contentGeneratorConfig model if available', async () => { - // Simulate initialized content generator config - await config.setModel(DEFAULT_GEMINI_FLASH_MODEL); - expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL); - }); - - it('should fall back to initial model if contentGeneratorConfig is not available', () => { - // Test with fresh config where contentGeneratorConfig might not be set - const newConfig = new Config({ - targetDir: '/test', - debugMode: false, - cwd: '/test', - model: 'custom-model', - }); - - expect(newConfig.getModel()).toBe('custom-model'); - }); - }); - - describe('isInFallbackMode', () => { - it('should start as false for new session', () => { - expect(config.isInFallbackMode()).toBe(false); - }); - - it('should remain false if no model switch occurs', () => { - // Perform other operations that don't involve model switching - expect(config.isInFallbackMode()).toBe(false); - }); - - it('should persist switched state throughout session', async () => { - await config.setModel(DEFAULT_GEMINI_FLASH_MODEL); - // Setting state for fallback mode as is expected of clients - config.setFallbackMode(true); - expect(config.isInFallbackMode()).toBe(true); - - // Should remain true even after getting model - config.getModel(); - expect(config.isInFallbackMode()).toBe(true); - }); - }); -}); diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts deleted file mode 100644 index 
8c790dd1a..000000000 --- a/packages/core/src/config/models.test.ts +++ /dev/null @@ -1,83 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, it, expect } from 'vitest'; -import { - getEffectiveModel, - DEFAULT_GEMINI_MODEL, - DEFAULT_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_FLASH_LITE_MODEL, -} from './models.js'; - -describe('getEffectiveModel', () => { - describe('When NOT in fallback mode', () => { - const isInFallbackMode = false; - - it('should return the Pro model when Pro is requested', () => { - const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL); - expect(model).toBe(DEFAULT_GEMINI_MODEL); - }); - - it('should return the Flash model when Flash is requested', () => { - const model = getEffectiveModel( - isInFallbackMode, - DEFAULT_GEMINI_FLASH_MODEL, - ); - expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); - }); - - it('should return the Lite model when Lite is requested', () => { - const model = getEffectiveModel( - isInFallbackMode, - DEFAULT_GEMINI_FLASH_LITE_MODEL, - ); - expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); - }); - - it('should return a custom model name when requested', () => { - const customModel = 'custom-model-v1'; - const model = getEffectiveModel(isInFallbackMode, customModel); - expect(model).toBe(customModel); - }); - }); - - describe('When IN fallback mode', () => { - const isInFallbackMode = true; - - it('should downgrade the Pro model to the Flash model', () => { - const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL); - expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); - }); - - it('should return the Flash model when Flash is requested', () => { - const model = getEffectiveModel( - isInFallbackMode, - DEFAULT_GEMINI_FLASH_MODEL, - ); - expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); - }); - - it('should HONOR the Lite model when Lite is requested', () => { - const model = getEffectiveModel( - isInFallbackMode, - 
DEFAULT_GEMINI_FLASH_LITE_MODEL, - ); - expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); - }); - - it('should HONOR any model with "lite" in its name', () => { - const customLiteModel = 'gemini-2.5-custom-lite-vNext'; - const model = getEffectiveModel(isInFallbackMode, customLiteModel); - expect(model).toBe(customLiteModel); - }); - - it('should downgrade any other custom model to the Flash model', () => { - const customModel = 'custom-model-v1-unlisted'; - const model = getEffectiveModel(isInFallbackMode, customModel); - expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); - }); - }); -}); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index ea7ef2024..a07dec7ce 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -7,46 +7,3 @@ export const DEFAULT_QWEN_MODEL = 'coder-model'; export const DEFAULT_QWEN_FLASH_MODEL = 'coder-model'; export const DEFAULT_QWEN_EMBEDDING_MODEL = 'text-embedding-v4'; - -export const DEFAULT_GEMINI_MODEL = 'coder-model'; -export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash'; -export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite'; - -export const DEFAULT_GEMINI_MODEL_AUTO = 'auto'; - -export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001'; - -// Some thinking models do not default to dynamic thinking which is done by a value of -1 -export const DEFAULT_THINKING_MODE = -1; - -/** - * Determines the effective model to use, applying fallback logic if necessary. - * - * When fallback mode is active, this function enforces the use of the standard - * fallback model. However, it makes an exception for "lite" models (any model - * with "lite" in its name), allowing them to be used to preserve cost savings. - * This ensures that "pro" models are always downgraded, while "lite" model - * requests are honored. - * - * @param isInFallbackMode Whether the application is in fallback mode. 
- * @param requestedModel The model that was originally requested. - * @returns The effective model name. - */ -export function getEffectiveModel( - isInFallbackMode: boolean, - requestedModel: string, -): string { - // If we are not in fallback mode, simply use the requested model. - if (!isInFallbackMode) { - return requestedModel; - } - - // If a "lite" model is requested, honor it. This allows for variations of - // lite models without needing to list them all as constants. - if (requestedModel.includes('lite')) { - return requestedModel; - } - - // Default fallback for Gemini CLI. - return DEFAULT_GEMINI_FLASH_MODEL; -} diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index f069ce4d5..86de132ba 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -32,7 +32,7 @@ import { type ChatCompressionInfo, } from './turn.js'; import { getCoreSystemPrompt } from './prompts.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { setSimulate429 } from '../utils/testUtils.js'; import { tokenLimit } from './tokenLimits.js'; @@ -302,8 +302,6 @@ describe('Gemini Client (client.ts)', () => { getFileService: vi.fn().mockReturnValue(fileService), getMaxSessionTurns: vi.fn().mockReturnValue(0), getSessionTokenLimit: vi.fn().mockReturnValue(32000), - getQuotaErrorOccurred: vi.fn().mockReturnValue(false), - setQuotaErrorOccurred: vi.fn(), getNoBrowser: vi.fn().mockReturnValue(false), getUsageStatisticsEnabled: vi.fn().mockReturnValue(true), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), @@ -317,8 +315,6 @@ describe('Gemini Client (client.ts)', () => { getModelRouterService: vi.fn().mockReturnValue({ route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }), }), - isInFallbackMode: 
vi.fn().mockReturnValue(false), - setFallbackMode: vi.fn(), getCliVersion: vi.fn().mockReturnValue('1.0.0'), getChatCompression: vi.fn().mockReturnValue(undefined), getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false), @@ -2262,12 +2258,12 @@ ${JSON.stringify( contents, generationConfig, abortSignal, - DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_QWEN_FLASH_MODEL, ); expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( expect.objectContaining({ - model: DEFAULT_GEMINI_FLASH_MODEL, + model: DEFAULT_QWEN_FLASH_MODEL, config: expect.objectContaining({ abortSignal, systemInstruction: getCoreSystemPrompt(''), @@ -2290,7 +2286,7 @@ ${JSON.stringify( contents, {}, new AbortController().signal, - DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_QWEN_FLASH_MODEL, ); expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({ @@ -2300,7 +2296,7 @@ ${JSON.stringify( }); expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { - model: DEFAULT_GEMINI_FLASH_MODEL, + model: DEFAULT_QWEN_FLASH_MODEL, config: expect.any(Object), contents, }, @@ -2308,28 +2304,7 @@ ${JSON.stringify( ); }); - it('should use the Flash model when fallback mode is active', async () => { - const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; - const generationConfig = { temperature: 0.5 }; - const abortSignal = new AbortController().signal; - const requestedModel = 'gemini-2.5-pro'; // A non-flash model - - // Mock config to be in fallback mode - vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true); - - await client.generateContent( - contents, - generationConfig, - abortSignal, - requestedModel, - ); - - expect(mockGenerateContentFn).toHaveBeenCalledWith( - expect.objectContaining({ - model: DEFAULT_GEMINI_FLASH_MODEL, - }), - 'test-session-id', - ); - }); + // Note: there is currently no "fallback mode" model routing; the model used + // is always the one explicitly requested by the caller. 
}); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 6c62478d0..aaaa98114 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -15,7 +15,6 @@ import type { // Config import { ApprovalMode, type Config } from '../config/config.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; // Core modules import type { ContentGenerator } from './contentGenerator.js'; @@ -542,11 +541,6 @@ export class GeminiClient { } } if (!turn.pendingToolCalls.length && signal && !signal.aborted) { - // Check if next speaker check is needed - if (this.config.getQuotaErrorOccurred()) { - return turn; - } - if (this.config.getSkipNextSpeakerCheck()) { return turn; } @@ -602,14 +596,11 @@ export class GeminiClient { }; const apiCall = () => { - const modelToUse = this.config.isInFallbackMode() - ? DEFAULT_GEMINI_FLASH_MODEL - : model; - currentAttemptModel = modelToUse; + currentAttemptModel = model; return this.getContentGeneratorOrFail().generateContent( { - model: modelToUse, + model, config: requestConfig, contents, }, diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 4b176c989..eef7f5ac8 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -5,7 +5,11 @@ */ import { describe, it, expect, vi } from 'vitest'; -import { createContentGenerator, AuthType } from './contentGenerator.js'; +import { + createContentGenerator, + createContentGeneratorConfig, + AuthType, +} from './contentGenerator.js'; import { GoogleGenAI } from '@google/genai'; import type { Config } from '../config/config.js'; import { LoggingContentGenerator } from './loggingContentGenerator/index.js'; @@ -78,3 +82,32 @@ describe('createContentGenerator', () => { expect(generator).toBeInstanceOf(LoggingContentGenerator); }); }); + +describe('createContentGeneratorConfig', () => { + const mockConfig = { + 
getProxy: () => undefined, + } as unknown as Config; + + it('should preserve provided fields and set authType for QWEN_OAUTH', () => { + const cfg = createContentGeneratorConfig(mockConfig, AuthType.QWEN_OAUTH, { + model: 'vision-model', + apiKey: 'QWEN_OAUTH_DYNAMIC_TOKEN', + }); + expect(cfg.authType).toBe(AuthType.QWEN_OAUTH); + expect(cfg.model).toBe('vision-model'); + expect(cfg.apiKey).toBe('QWEN_OAUTH_DYNAMIC_TOKEN'); + }); + + it('should not warn or fallback for QWEN_OAUTH (resolution handled by ModelConfigResolver)', () => { + const warnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + const cfg = createContentGeneratorConfig(mockConfig, AuthType.QWEN_OAUTH, { + model: 'some-random-model', + }); + expect(cfg.model).toBe('some-random-model'); + expect(cfg.apiKey).toBeUndefined(); + expect(warnSpy).not.toHaveBeenCalled(); + warnSpy.mockRestore(); + }); +}); diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index f6f83761f..fc36fda3c 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -12,9 +12,24 @@ import type { GenerateContentParameters, GenerateContentResponse, } from '@google/genai'; -import { DEFAULT_QWEN_MODEL } from '../config/models.js'; import type { Config } from '../config/config.js'; import { LoggingContentGenerator } from './loggingContentGenerator/index.js'; +import type { + ConfigSource, + ConfigSourceKind, + ConfigSources, +} from '../utils/configResolver.js'; +import { + getDefaultApiKeyEnvVar, + getDefaultModelEnvVar, + MissingAnthropicBaseUrlEnvError, + MissingApiKeyError, + MissingBaseUrlError, + MissingModelError, + StrictMissingCredentialsError, + StrictMissingModelIdError, +} from '../models/modelConfigErrors.js'; +import { PROVIDER_SOURCED_FIELDS } from '../models/modelsConfig.js'; /** * Interface abstracting the core functionalities for generating content and counting tokens. 
@@ -48,6 +63,7 @@ export enum AuthType { export type ContentGeneratorConfig = { model: string; apiKey?: string; + apiKeyEnvKey?: string; baseUrl?: string; vertexai?: boolean; authType?: AuthType | undefined; @@ -77,102 +93,178 @@ export type ContentGeneratorConfig = { schemaCompliance?: 'auto' | 'openapi_30'; }; -export function createContentGeneratorConfig( +// Keep the public ContentGeneratorConfigSources API, but reuse the generic +// source-tracking types from utils/configResolver to avoid duplication. +export type ContentGeneratorConfigSourceKind = ConfigSourceKind; +export type ContentGeneratorConfigSource = ConfigSource; +export type ContentGeneratorConfigSources = ConfigSources; + +export type ResolvedContentGeneratorConfig = { + config: ContentGeneratorConfig; + sources: ContentGeneratorConfigSources; +}; + +function setSource( + sources: ContentGeneratorConfigSources, + path: string, + source: ContentGeneratorConfigSource, +): void { + sources[path] = source; +} + +function getSeedSource( + seed: ContentGeneratorConfigSources | undefined, + path: string, +): ContentGeneratorConfigSource | undefined { + return seed?.[path]; +} + +/** + * Resolve ContentGeneratorConfig while tracking the source of each effective field. + * + * This function now primarily validates and finalizes the configuration that has + * already been resolved by ModelConfigResolver. The env fallback logic has been + * moved to the unified resolver to eliminate duplication. + * + * Note: The generationConfig passed here should already be fully resolved with + * proper source tracking from the caller (CLI/SDK layer). 
+ */ +export function resolveContentGeneratorConfigWithSources( config: Config, authType: AuthType | undefined, generationConfig?: Partial, -): ContentGeneratorConfig { - let newContentGeneratorConfig: Partial = { + seedSources?: ContentGeneratorConfigSources, + options?: { strictModelProvider?: boolean }, +): ResolvedContentGeneratorConfig { + const sources: ContentGeneratorConfigSources = { ...(seedSources || {}) }; + const strictModelProvider = options?.strictModelProvider === true; + + // Build config with computed fields + const newContentGeneratorConfig: Partial = { ...(generationConfig || {}), authType, proxy: config?.getProxy(), }; - if (authType === AuthType.QWEN_OAUTH) { - // For Qwen OAuth, we'll handle the API key dynamically in createContentGenerator - // Set a special marker to indicate this is Qwen OAuth - return { - ...newContentGeneratorConfig, - model: DEFAULT_QWEN_MODEL, - apiKey: 'QWEN_OAUTH_DYNAMIC_TOKEN', - } as ContentGeneratorConfig; + // Set sources for computed fields + setSource(sources, 'authType', { + kind: 'computed', + detail: 'provided by caller', + }); + if (config?.getProxy()) { + setSource(sources, 'proxy', { + kind: 'computed', + detail: 'Config.getProxy()', + }); } - if (authType === AuthType.USE_OPENAI) { - newContentGeneratorConfig = { - ...newContentGeneratorConfig, - apiKey: newContentGeneratorConfig.apiKey || process.env['OPENAI_API_KEY'], - baseUrl: - newContentGeneratorConfig.baseUrl || process.env['OPENAI_BASE_URL'], - model: newContentGeneratorConfig.model || process.env['OPENAI_MODEL'], - }; + // Preserve seed sources for fields that were passed in + const seedOrUnknown = (path: string): ContentGeneratorConfigSource => + getSeedSource(seedSources, path) ?? 
{ kind: 'unknown' }; - if (!newContentGeneratorConfig.apiKey) { - throw new Error('OPENAI_API_KEY environment variable not found.'); - } - - return { - ...newContentGeneratorConfig, - model: newContentGeneratorConfig?.model || 'qwen3-coder-plus', - } as ContentGeneratorConfig; - } - - if (authType === AuthType.USE_ANTHROPIC) { - newContentGeneratorConfig = { - ...newContentGeneratorConfig, - apiKey: - newContentGeneratorConfig.apiKey || process.env['ANTHROPIC_API_KEY'], - baseUrl: - newContentGeneratorConfig.baseUrl || process.env['ANTHROPIC_BASE_URL'], - model: newContentGeneratorConfig.model || process.env['ANTHROPIC_MODEL'], - }; - - if (!newContentGeneratorConfig.apiKey) { - throw new Error('ANTHROPIC_API_KEY environment variable not found.'); - } - - if (!newContentGeneratorConfig.baseUrl) { - throw new Error('ANTHROPIC_BASE_URL environment variable not found.'); - } - - if (!newContentGeneratorConfig.model) { - throw new Error('ANTHROPIC_MODEL environment variable not found.'); + for (const field of PROVIDER_SOURCED_FIELDS) { + if (generationConfig && field in generationConfig && !sources[field]) { + setSource(sources, field, seedOrUnknown(field)); } } - if (authType === AuthType.USE_GEMINI) { - newContentGeneratorConfig = { - ...newContentGeneratorConfig, - apiKey: newContentGeneratorConfig.apiKey || process.env['GEMINI_API_KEY'], - model: newContentGeneratorConfig.model || process.env['GEMINI_MODEL'], - }; + // Validate required fields based on authType. This does not perform any + // fallback resolution (resolution is handled by ModelConfigResolver). 
+ const validation = validateModelConfig( + newContentGeneratorConfig as ContentGeneratorConfig, + strictModelProvider, + ); + if (!validation.valid) { + throw new Error(validation.errors.map((e) => e.message).join('\n')); + } - if (!newContentGeneratorConfig.apiKey) { - throw new Error('GEMINI_API_KEY environment variable not found.'); - } + return { + config: newContentGeneratorConfig as ContentGeneratorConfig, + sources, + }; +} - if (!newContentGeneratorConfig.model) { - throw new Error('GEMINI_MODEL environment variable not found.'); +export interface ModelConfigValidationResult { + valid: boolean; + errors: Error[]; +} + +/** + * Validate a resolved model configuration. + * This is the single validation entry point used across Core. + */ +export function validateModelConfig( + config: ContentGeneratorConfig, + isStrictModelProvider: boolean = false, +): ModelConfigValidationResult { + const errors: Error[] = []; + + // Qwen OAuth doesn't need validation - it uses dynamic tokens + if (config.authType === AuthType.QWEN_OAUTH) { + return { valid: true, errors: [] }; + } + + // API key is required for all other auth types + if (!config.apiKey) { + if (isStrictModelProvider) { + errors.push( + new StrictMissingCredentialsError( + config.authType, + config.model, + config.apiKeyEnvKey, + ), + ); + } else { + const envKey = + config.apiKeyEnvKey || getDefaultApiKeyEnvVar(config.authType); + errors.push( + new MissingApiKeyError({ + authType: config.authType, + model: config.model, + baseUrl: config.baseUrl, + envKey, + }), + ); } } - if (authType === AuthType.USE_VERTEX_AI) { - newContentGeneratorConfig = { - ...newContentGeneratorConfig, - apiKey: newContentGeneratorConfig.apiKey || process.env['GOOGLE_API_KEY'], - model: newContentGeneratorConfig.model || process.env['GOOGLE_MODEL'], - }; - - if (!newContentGeneratorConfig.apiKey) { - throw new Error('GOOGLE_API_KEY environment variable not found.'); - } - - if (!newContentGeneratorConfig.model) { - throw new 
Error('GOOGLE_MODEL environment variable not found.'); + // Model is required + if (!config.model) { + if (isStrictModelProvider) { + errors.push(new StrictMissingModelIdError(config.authType)); + } else { + const envKey = getDefaultModelEnvVar(config.authType); + errors.push(new MissingModelError({ authType: config.authType, envKey })); } } - return newContentGeneratorConfig as ContentGeneratorConfig; + // Explicit baseUrl is required for Anthropic; Migrated from existing code. + if (config.authType === AuthType.USE_ANTHROPIC && !config.baseUrl) { + if (isStrictModelProvider) { + errors.push( + new MissingBaseUrlError({ + authType: config.authType, + model: config.model, + }), + ); + } else if (config.authType === AuthType.USE_ANTHROPIC) { + errors.push(new MissingAnthropicBaseUrlEnvError()); + } + } + + return { valid: errors.length === 0, errors }; +} + +export function createContentGeneratorConfig( + config: Config, + authType: AuthType | undefined, + generationConfig?: Partial, +): ContentGeneratorConfig { + return resolveContentGeneratorConfigWithSources( + config, + authType, + generationConfig, + ).config; } export async function createContentGenerator( @@ -180,11 +272,12 @@ export async function createContentGenerator( gcConfig: Config, isInitialAuth?: boolean, ): Promise { - if (config.authType === AuthType.USE_OPENAI) { - if (!config.apiKey) { - throw new Error('OPENAI_API_KEY environment variable not found.'); - } + const validation = validateModelConfig(config, false); + if (!validation.valid) { + throw new Error(validation.errors.map((e) => e.message).join('\n')); + } + if (config.authType === AuthType.USE_OPENAI) { // Import OpenAIContentGenerator dynamically to avoid circular dependencies const { createOpenAIContentGenerator } = await import( './openaiContentGenerator/index.js' @@ -223,10 +316,6 @@ export async function createContentGenerator( } if (config.authType === AuthType.USE_ANTHROPIC) { - if (!config.apiKey) { - throw new 
Error('ANTHROPIC_API_KEY environment variable not found.'); - } - const { createAnthropicContentGenerator } = await import( './anthropicContentGenerator/index.js' ); diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index a77fc6707..20e884548 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -20,7 +20,6 @@ import { } from './geminiChat.js'; import type { Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { AuthType } from './contentGenerator.js'; import { type RetryOptions } from '../utils/retry.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; @@ -117,10 +116,6 @@ describe('GeminiChat', () => { }), getModel: vi.fn().mockReturnValue('gemini-pro'), setModel: vi.fn(), - isInFallbackMode: vi.fn().mockReturnValue(false), - getQuotaErrorOccurred: vi.fn().mockReturnValue(false), - setQuotaErrorOccurred: vi.fn(), - flashFallbackHandler: undefined, getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), getCliVersion: vi.fn().mockReturnValue('1.0.0'), storage: { @@ -1349,9 +1344,8 @@ describe('GeminiChat', () => { ], } as unknown as GenerateContentResponse; - it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => { + it('should pass the requested model through to generateContentStream', async () => { vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro'); - vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); vi.mocked(mockContentGenerator.generateContentStream).mockImplementation( async () => (async function* () { @@ -1370,7 +1364,7 @@ describe('GeminiChat', () => { expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith( expect.objectContaining({ - model: DEFAULT_GEMINI_FLASH_MODEL, + model: 'test-model', }), 'prompt-id-res3', ); @@ -1422,9 +1416,6 @@ 
describe('GeminiChat', () => { authType, }); - const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode'); - isInFallbackModeSpy.mockReturnValue(false); - vi.mocked(mockContentGenerator.generateContentStream) .mockRejectedValueOnce(error429) // Attempt 1 fails .mockResolvedValueOnce( @@ -1441,10 +1432,7 @@ describe('GeminiChat', () => { })(), ); - mockHandleFallback.mockImplementation(async () => { - isInFallbackModeSpy.mockReturnValue(true); - return true; // Signal retry - }); + mockHandleFallback.mockImplementation(async () => true); const stream = await chat.sendMessageStream( 'test-model', diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 04add3419..d4aaee25a 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -19,10 +19,6 @@ import type { import { ApiError, createUserContent } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; import type { Config } from '../config/config.js'; -import { - DEFAULT_GEMINI_FLASH_MODEL, - getEffectiveModel, -} from '../config/models.js'; import { hasCycleInSchema } from '../tools/tools.js'; import type { StructuredError } from './turn.js'; import { @@ -352,31 +348,15 @@ export class GeminiChat { params: SendMessageParameters, prompt_id: string, ): Promise> { - const apiCall = () => { - const modelToUse = getEffectiveModel( - this.config.isInFallbackMode(), - model, - ); - - if ( - this.config.getQuotaErrorOccurred() && - modelToUse === DEFAULT_GEMINI_FLASH_MODEL - ) { - throw new Error( - 'Please submit a new query to continue with the Flash model.', - ); - } - - return this.config.getContentGenerator().generateContentStream( + const apiCall = () => + this.config.getContentGenerator().generateContentStream( { - model: modelToUse, + model, contents: requestContents, config: { ...this.generationConfig, ...params.config }, }, prompt_id, ); - }; - const onPersistent429Callback = async ( authType?: string, 
error?: unknown, diff --git a/packages/core/src/core/openaiContentGenerator/pipeline.test.ts b/packages/core/src/core/openaiContentGenerator/pipeline.test.ts index 93adcb090..d5220b080 100644 --- a/packages/core/src/core/openaiContentGenerator/pipeline.test.ts +++ b/packages/core/src/core/openaiContentGenerator/pipeline.test.ts @@ -46,6 +46,7 @@ describe('ContentGenerationPipeline', () => { // Mock converter mockConverter = { + setModel: vi.fn(), convertGeminiRequestToOpenAI: vi.fn(), convertOpenAIResponseToGemini: vi.fn(), convertOpenAIChunkToGemini: vi.fn(), @@ -99,6 +100,7 @@ describe('ContentGenerationPipeline', () => { describe('constructor', () => { it('should initialize with correct configuration', () => { expect(mockProvider.buildClient).toHaveBeenCalled(); + // Converter is constructed once and the model is updated per-request via setModel(). expect(OpenAIContentConverter).toHaveBeenCalledWith( 'test-model', undefined, @@ -144,6 +146,9 @@ describe('ContentGenerationPipeline', () => { // Assert expect(result).toBe(mockGeminiResponse); + expect( + (mockConverter as unknown as { setModel: Mock }).setModel, + ).toHaveBeenCalledWith('test-model'); expect(mockConverter.convertGeminiRequestToOpenAI).toHaveBeenCalledWith( request, ); @@ -164,6 +169,53 @@ describe('ContentGenerationPipeline', () => { ); }); + it('should ignore request.model override and always use configured model', async () => { + // Arrange + const request: GenerateContentParameters = { + model: 'override-model', + contents: [{ parts: [{ text: 'Hello' }], role: 'user' }], + }; + const userPromptId = 'test-prompt-id'; + + const mockMessages = [ + { role: 'user', content: 'Hello' }, + ] as OpenAI.Chat.ChatCompletionMessageParam[]; + const mockOpenAIResponse = { + id: 'response-id', + choices: [ + { message: { content: 'Hello response' }, finish_reason: 'stop' }, + ], + created: Date.now(), + model: 'override-model', + } as OpenAI.Chat.ChatCompletion; + const mockGeminiResponse = new 
GenerateContentResponse(); + + (mockConverter.convertGeminiRequestToOpenAI as Mock).mockReturnValue( + mockMessages, + ); + (mockConverter.convertOpenAIResponseToGemini as Mock).mockReturnValue( + mockGeminiResponse, + ); + (mockClient.chat.completions.create as Mock).mockResolvedValue( + mockOpenAIResponse, + ); + + // Act + const result = await pipeline.execute(request, userPromptId); + + // Assert + expect(result).toBe(mockGeminiResponse); + expect( + (mockConverter as unknown as { setModel: Mock }).setModel, + ).toHaveBeenCalledWith('test-model'); + expect(mockClient.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'test-model', + }), + expect.any(Object), + ); + }); + it('should handle tools in request', async () => { // Arrange const request: GenerateContentParameters = { @@ -217,6 +269,9 @@ describe('ContentGenerationPipeline', () => { // Assert expect(result).toBe(mockGeminiResponse); + expect( + (mockConverter as unknown as { setModel: Mock }).setModel, + ).toHaveBeenCalledWith('test-model'); expect(mockConverter.convertGeminiToolsToOpenAI).toHaveBeenCalledWith( request.config!.tools, ); diff --git a/packages/core/src/core/openaiContentGenerator/pipeline.ts b/packages/core/src/core/openaiContentGenerator/pipeline.ts index ef27a7798..0f00ecb30 100644 --- a/packages/core/src/core/openaiContentGenerator/pipeline.ts +++ b/packages/core/src/core/openaiContentGenerator/pipeline.ts @@ -40,10 +40,16 @@ export class ContentGenerationPipeline { request: GenerateContentParameters, userPromptId: string, ): Promise { + // For OpenAI-compatible providers, the configured model is the single source of truth. + // We intentionally ignore request.model because upstream callers may pass a model string + // that is not valid/available for the OpenAI-compatible backend. 
+ const effectiveModel = this.contentGeneratorConfig.model; + this.converter.setModel(effectiveModel); return this.executeWithErrorHandling( request, userPromptId, false, + effectiveModel, async (openaiRequest) => { const openaiResponse = (await this.client.chat.completions.create( openaiRequest, @@ -64,10 +70,13 @@ export class ContentGenerationPipeline { request: GenerateContentParameters, userPromptId: string, ): Promise> { + const effectiveModel = this.contentGeneratorConfig.model; + this.converter.setModel(effectiveModel); return this.executeWithErrorHandling( request, userPromptId, true, + effectiveModel, async (openaiRequest, context) => { // Stage 1: Create OpenAI stream const stream = (await this.client.chat.completions.create( @@ -224,12 +233,13 @@ export class ContentGenerationPipeline { request: GenerateContentParameters, userPromptId: string, streaming: boolean = false, + effectiveModel: string, ): Promise { const messages = this.converter.convertGeminiRequestToOpenAI(request); // Apply provider-specific enhancements const baseRequest: OpenAI.Chat.ChatCompletionCreateParams = { - model: this.contentGeneratorConfig.model, + model: effectiveModel, messages, ...this.buildGenerateContentConfig(request), }; @@ -342,18 +352,24 @@ export class ContentGenerationPipeline { request: GenerateContentParameters, userPromptId: string, isStreaming: boolean, + effectiveModel: string, executor: ( openaiRequest: OpenAI.Chat.ChatCompletionCreateParams, context: RequestContext, ) => Promise, ): Promise { - const context = this.createRequestContext(userPromptId, isStreaming); + const context = this.createRequestContext( + userPromptId, + isStreaming, + effectiveModel, + ); try { const openaiRequest = await this.buildRequest( request, userPromptId, isStreaming, + effectiveModel, ); const result = await executor(openaiRequest, context); @@ -385,10 +401,11 @@ export class ContentGenerationPipeline { private createRequestContext( userPromptId: string, isStreaming: boolean, + 
effectiveModel: string, ): RequestContext { return { userPromptId, - model: this.contentGeneratorConfig.model, + model: effectiveModel, authType: this.contentGeneratorConfig.authType || 'unknown', startTime: Date.now(), duration: 0, diff --git a/packages/core/src/fallback/types.ts b/packages/core/src/fallback/types.ts deleted file mode 100644 index 654312337..000000000 --- a/packages/core/src/fallback/types.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -/** - * Defines the intent returned by the UI layer during a fallback scenario. - */ -export type FallbackIntent = - | 'retry' // Immediately retry the current request with the fallback model. - | 'stop' // Switch to fallback for future requests, but stop the current request. - | 'auth'; // Stop the current request; user intends to change authentication. - -/** - * The interface for the handler provided by the UI layer (e.g., the CLI) - * to interact with the user during a fallback scenario. 
- */ -export type FallbackModelHandler = ( - failedModel: string, - fallbackModel: string, - error?: unknown, -) => Promise; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 7f7bd115b..60e66b19e 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -9,6 +9,30 @@ export * from './config/config.js'; export * from './output/types.js'; export * from './output/json-formatter.js'; +// Export models +export { + type ModelCapabilities, + type ModelGenerationConfig, + type ModelConfig as ProviderModelConfig, + type ModelProvidersConfig, + type ResolvedModelConfig, + type AvailableModel, + type ModelSwitchMetadata, + QWEN_OAUTH_MODELS, + ModelRegistry, + ModelsConfig, + type ModelsConfigOptions, + type OnModelChangeCallback, + // Model configuration resolver + resolveModelConfig, + validateModelConfig, + type ModelConfigSourcesInput, + type ModelConfigCliInput, + type ModelConfigSettingsInput, + type ModelConfigResolutionResult, + type ModelConfigValidationResult, +} from './models/index.js'; + // Export Core Logic export * from './core/client.js'; export * from './core/contentGenerator.js'; @@ -21,8 +45,6 @@ export * from './core/geminiRequest.js'; export * from './core/coreToolScheduler.js'; export * from './core/nonInteractiveToolExecutor.js'; -export * from './fallback/types.js'; - export * from './qwen/qwenOAuth2.js'; // Export utilities @@ -55,6 +77,9 @@ export * from './utils/projectSummary.js'; export * from './utils/promptIdContext.js'; export * from './utils/thoughtUtils.js'; +// Config resolution utilities +export * from './utils/configResolver.js'; + // Export services export * from './services/fileDiscoveryService.js'; export * from './services/gitService.js'; diff --git a/packages/core/src/models/constants.ts b/packages/core/src/models/constants.ts new file mode 100644 index 000000000..9dd69620c --- /dev/null +++ b/packages/core/src/models/constants.ts @@ -0,0 +1,134 @@ +/** + * @license + * Copyright 2025 Qwen 
Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DEFAULT_QWEN_MODEL } from '../config/models.js'; + +import type { ModelConfig } from './types.js'; + +type AuthType = import('../core/contentGenerator.js').AuthType; +type ContentGeneratorConfig = + import('../core/contentGenerator.js').ContentGeneratorConfig; + +/** + * Field keys for model-scoped generation config. + * + * Kept in a small standalone module to avoid circular deps. The `import('...')` + * usage is type-only and does not emit runtime imports. + */ +export const MODEL_GENERATION_CONFIG_FIELDS = [ + 'samplingParams', + 'timeout', + 'maxRetries', + 'disableCacheControl', + 'schemaCompliance', + 'reasoning', +] as const satisfies ReadonlyArray; + +/** + * Credential-related fields that are part of ContentGeneratorConfig + * but not ModelGenerationConfig. + */ +export const CREDENTIAL_FIELDS = [ + 'model', + 'apiKey', + 'apiKeyEnvKey', + 'baseUrl', +] as const satisfies ReadonlyArray; + +/** + * All provider-sourced fields that need to be tracked for source attribution + * and cleared when switching from provider to manual credentials. + */ +export const PROVIDER_SOURCED_FIELDS = [ + ...CREDENTIAL_FIELDS, + ...MODEL_GENERATION_CONFIG_FIELDS, +] as const; + +/** + * Environment variable mappings per authType. 
+ */ +export interface AuthEnvMapping { + apiKey: string[]; + baseUrl: string[]; + model: string[]; +} + +export const AUTH_ENV_MAPPINGS = { + openai: { + apiKey: ['OPENAI_API_KEY'], + baseUrl: ['OPENAI_BASE_URL'], + model: ['OPENAI_MODEL', 'QWEN_MODEL'], + }, + anthropic: { + apiKey: ['ANTHROPIC_API_KEY'], + baseUrl: ['ANTHROPIC_BASE_URL'], + model: ['ANTHROPIC_MODEL'], + }, + gemini: { + apiKey: ['GEMINI_API_KEY'], + baseUrl: [], + model: ['GEMINI_MODEL'], + }, + 'vertex-ai': { + apiKey: ['GOOGLE_API_KEY'], + baseUrl: [], + model: ['GOOGLE_MODEL'], + }, + 'qwen-oauth': { + apiKey: [], + baseUrl: [], + model: [], + }, +} as const satisfies Record; + +export const DEFAULT_MODELS = { + openai: 'qwen3-coder-plus', + 'qwen-oauth': DEFAULT_QWEN_MODEL, +} as Partial>; + +export const QWEN_OAUTH_ALLOWED_MODELS = [ + DEFAULT_QWEN_MODEL, + 'vision-model', +] as const; + +/** + * Hard-coded Qwen OAuth models that are always available. + * These cannot be overridden by user configuration. + */ +export const QWEN_OAUTH_MODELS: ModelConfig[] = [ + { + id: 'coder-model', + name: 'Qwen Coder', + description: + 'The latest Qwen Coder model from Alibaba Cloud ModelStudio (version: qwen3-coder-plus-2025-09-23)', + capabilities: { vision: false }, + generationConfig: { + samplingParams: { + temperature: 0.7, + top_p: 0.9, + max_tokens: 8192, + }, + timeout: 60000, + maxRetries: 3, + }, + }, + { + id: 'vision-model', + name: 'Qwen Vision', + description: + 'The latest Qwen Vision model from Alibaba Cloud ModelStudio (version: qwen3-vl-plus-2025-09-23)', + capabilities: { vision: true }, + generationConfig: { + samplingParams: { + temperature: 0.7, + top_p: 0.9, + max_tokens: 8192, + }, + timeout: 60000, + maxRetries: 3, + }, + }, +]; diff --git a/packages/core/src/models/index.ts b/packages/core/src/models/index.ts new file mode 100644 index 000000000..7525074a5 --- /dev/null +++ b/packages/core/src/models/index.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2025 Qwen Team + * 
SPDX-License-Identifier: Apache-2.0 + */ + +export { + type ModelCapabilities, + type ModelGenerationConfig, + type ModelConfig, + type ModelProvidersConfig, + type ResolvedModelConfig, + type AvailableModel, + type ModelSwitchMetadata, +} from './types.js'; + +export { ModelRegistry } from './modelRegistry.js'; + +export { + ModelsConfig, + type ModelsConfigOptions, + type OnModelChangeCallback, +} from './modelsConfig.js'; + +export { + AUTH_ENV_MAPPINGS, + CREDENTIAL_FIELDS, + DEFAULT_MODELS, + MODEL_GENERATION_CONFIG_FIELDS, + PROVIDER_SOURCED_FIELDS, + QWEN_OAUTH_ALLOWED_MODELS, + QWEN_OAUTH_MODELS, +} from './constants.js'; + +// Model configuration resolver +export { + resolveModelConfig, + validateModelConfig, + type ModelConfigSourcesInput, + type ModelConfigCliInput, + type ModelConfigSettingsInput, + type ModelConfigResolutionResult, + type ModelConfigValidationResult, +} from './modelConfigResolver.js'; diff --git a/packages/core/src/models/modelConfigErrors.ts b/packages/core/src/models/modelConfigErrors.ts new file mode 100644 index 000000000..3504793bd --- /dev/null +++ b/packages/core/src/models/modelConfigErrors.ts @@ -0,0 +1,125 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +export function getDefaultApiKeyEnvVar(authType: string | undefined): string { + switch (authType) { + case 'openai': + return 'OPENAI_API_KEY'; + case 'anthropic': + return 'ANTHROPIC_API_KEY'; + case 'gemini': + return 'GEMINI_API_KEY'; + case 'vertex-ai': + return 'GOOGLE_API_KEY'; + default: + return 'API_KEY'; + } +} + +export function getDefaultModelEnvVar(authType: string | undefined): string { + switch (authType) { + case 'openai': + return 'OPENAI_MODEL'; + case 'anthropic': + return 'ANTHROPIC_MODEL'; + case 'gemini': + return 'GEMINI_MODEL'; + case 'vertex-ai': + return 'GOOGLE_MODEL'; + default: + return 'MODEL'; + } +} + +export abstract class ModelConfigError extends Error { + abstract readonly code: string; + + 
protected constructor(message: string) { + super(message); + this.name = new.target.name; + Object.setPrototypeOf(this, new.target.prototype); + } +} + +export class StrictMissingCredentialsError extends ModelConfigError { + readonly code = 'STRICT_MISSING_CREDENTIALS'; + + constructor( + authType: string | undefined, + model: string | undefined, + envKey?: string, + ) { + const providerKey = authType || '(unknown)'; + const modelName = model || '(unknown)'; + super( + `Missing credentials for modelProviders model '${modelName}'. ` + + (envKey + ? `Current configured envKey: '${envKey}'. Set that environment variable, or update modelProviders.${providerKey}[].envKey.` + : `Configure modelProviders.${providerKey}[].envKey and set that environment variable.`), + ); + } +} + +export class StrictMissingModelIdError extends ModelConfigError { + readonly code = 'STRICT_MISSING_MODEL_ID'; + + constructor(authType: string | undefined) { + super( + `Missing model id for strict modelProviders resolution (authType: ${authType}).`, + ); + } +} + +export class MissingApiKeyError extends ModelConfigError { + readonly code = 'MISSING_API_KEY'; + + constructor(params: { + authType: string | undefined; + model: string | undefined; + baseUrl: string | undefined; + envKey: string; + }) { + super( + `Missing API key for ${params.authType} auth. ` + + `Current model: '${params.model || '(unknown)'}', baseUrl: '${params.baseUrl || '(default)'}'. ` + + `Provide an API key via settings (security.auth.apiKey), ` + + `or set the environment variable '${params.envKey}'.`, + ); + } +} + +export class MissingModelError extends ModelConfigError { + readonly code = 'MISSING_MODEL'; + + constructor(params: { authType: string | undefined; envKey: string }) { + super( + `Missing model for ${params.authType} auth. 
` + + `Set the environment variable '${params.envKey}'.`, + ); + } +} + +export class MissingBaseUrlError extends ModelConfigError { + readonly code = 'MISSING_BASE_URL'; + + constructor(params: { + authType: string | undefined; + model: string | undefined; + }) { + super( + `Missing baseUrl for modelProviders model '${params.model || '(unknown)'}' (authType: ${params.authType}). ` + + `Configure modelProviders.${params.authType || '(unknown)'}[].baseUrl.`, + ); + } +} + +export class MissingAnthropicBaseUrlEnvError extends ModelConfigError { + readonly code = 'MISSING_ANTHROPIC_BASE_URL_ENV'; + + constructor() { + super('ANTHROPIC_BASE_URL environment variable not found.'); + } +} diff --git a/packages/core/src/models/modelConfigResolver.test.ts b/packages/core/src/models/modelConfigResolver.test.ts new file mode 100644 index 000000000..b7aa8c29b --- /dev/null +++ b/packages/core/src/models/modelConfigResolver.test.ts @@ -0,0 +1,355 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { + resolveModelConfig, + validateModelConfig, +} from './modelConfigResolver.js'; +import { AuthType } from '../core/contentGenerator.js'; +import { DEFAULT_QWEN_MODEL } from '../config/models.js'; + +describe('modelConfigResolver', () => { + describe('resolveModelConfig', () => { + describe('OpenAI auth type', () => { + it('resolves from CLI with highest priority', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: { + model: 'cli-model', + apiKey: 'cli-key', + baseUrl: 'https://cli.example.com', + }, + settings: { + model: 'settings-model', + apiKey: 'settings-key', + baseUrl: 'https://settings.example.com', + }, + env: { + OPENAI_MODEL: 'env-model', + OPENAI_API_KEY: 'env-key', + OPENAI_BASE_URL: 'https://env.example.com', + }, + }); + + expect(result.config.model).toBe('cli-model'); + expect(result.config.apiKey).toBe('cli-key'); + 
expect(result.config.baseUrl).toBe('https://cli.example.com'); + + expect(result.sources['model'].kind).toBe('cli'); + expect(result.sources['apiKey'].kind).toBe('cli'); + expect(result.sources['baseUrl'].kind).toBe('cli'); + }); + + it('falls back to env when CLI not provided', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: { + model: 'settings-model', + }, + env: { + OPENAI_MODEL: 'env-model', + OPENAI_API_KEY: 'env-key', + }, + }); + + expect(result.config.model).toBe('env-model'); + expect(result.config.apiKey).toBe('env-key'); + + expect(result.sources['model'].kind).toBe('env'); + expect(result.sources['apiKey'].kind).toBe('env'); + }); + + it('falls back to settings when env not provided', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: { + model: 'settings-model', + apiKey: 'settings-key', + baseUrl: 'https://settings.example.com', + }, + env: {}, + }); + + expect(result.config.model).toBe('settings-model'); + expect(result.config.apiKey).toBe('settings-key'); + expect(result.config.baseUrl).toBe('https://settings.example.com'); + + expect(result.sources['model'].kind).toBe('settings'); + expect(result.sources['apiKey'].kind).toBe('settings'); + expect(result.sources['baseUrl'].kind).toBe('settings'); + }); + + it('uses default model when nothing provided', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: {}, + env: { + OPENAI_API_KEY: 'some-key', // need key to be valid + }, + }); + + expect(result.config.model).toBe('qwen3-coder-plus'); + expect(result.sources['model'].kind).toBe('default'); + }); + + it('prioritizes modelProvider over CLI', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: { + model: 'cli-model', + }, + settings: {}, + env: { + MY_CUSTOM_KEY: 'provider-key', + }, + modelProvider: { + id: 'provider-model', + name: 'Provider Model', + authType: 
AuthType.USE_OPENAI, + envKey: 'MY_CUSTOM_KEY', + baseUrl: 'https://provider.example.com', + generationConfig: {}, + capabilities: {}, + }, + }); + + expect(result.config.model).toBe('provider-model'); + expect(result.config.apiKey).toBe('provider-key'); + expect(result.config.baseUrl).toBe('https://provider.example.com'); + + expect(result.sources['model'].kind).toBe('modelProviders'); + expect(result.sources['apiKey'].kind).toBe('env'); + expect(result.sources['apiKey'].via?.kind).toBe('modelProviders'); + }); + + it('reads QWEN_MODEL as fallback for OPENAI_MODEL', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: {}, + env: { + QWEN_MODEL: 'qwen-model', + OPENAI_API_KEY: 'key', + }, + }); + + expect(result.config.model).toBe('qwen-model'); + expect(result.sources['model'].envKey).toBe('QWEN_MODEL'); + }); + }); + + describe('Qwen OAuth auth type', () => { + it('uses default model for Qwen OAuth', () => { + const result = resolveModelConfig({ + authType: AuthType.QWEN_OAUTH, + cli: {}, + settings: {}, + env: {}, + }); + + expect(result.config.model).toBe(DEFAULT_QWEN_MODEL); + expect(result.config.apiKey).toBe('QWEN_OAUTH_DYNAMIC_TOKEN'); + expect(result.sources['apiKey'].kind).toBe('computed'); + }); + + it('allows vision-model for Qwen OAuth', () => { + const result = resolveModelConfig({ + authType: AuthType.QWEN_OAUTH, + cli: { + model: 'vision-model', + }, + settings: {}, + env: {}, + }); + + expect(result.config.model).toBe('vision-model'); + expect(result.sources['model'].kind).toBe('cli'); + }); + + it('warns and falls back for unsupported Qwen OAuth models', () => { + const result = resolveModelConfig({ + authType: AuthType.QWEN_OAUTH, + cli: { + model: 'unsupported-model', + }, + settings: {}, + env: {}, + }); + + expect(result.config.model).toBe(DEFAULT_QWEN_MODEL); + expect(result.warnings).toHaveLength(1); + expect(result.warnings[0]).toContain('unsupported-model'); + }); + }); + + 
describe('Anthropic auth type', () => { + it('resolves Anthropic config from env', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_ANTHROPIC, + cli: {}, + settings: {}, + env: { + ANTHROPIC_API_KEY: 'anthropic-key', + ANTHROPIC_BASE_URL: 'https://anthropic.example.com', + ANTHROPIC_MODEL: 'claude-3', + }, + }); + + expect(result.config.model).toBe('claude-3'); + expect(result.config.apiKey).toBe('anthropic-key'); + expect(result.config.baseUrl).toBe('https://anthropic.example.com'); + }); + }); + + describe('generation config resolution', () => { + it('merges generation config from settings', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: { + apiKey: 'key', + generationConfig: { + timeout: 60000, + maxRetries: 5, + samplingParams: { + temperature: 0.7, + }, + }, + }, + env: {}, + }); + + expect(result.config.timeout).toBe(60000); + expect(result.config.maxRetries).toBe(5); + expect(result.config.samplingParams?.temperature).toBe(0.7); + + expect(result.sources['timeout'].kind).toBe('settings'); + expect(result.sources['samplingParams'].kind).toBe('settings'); + }); + + it('modelProvider config overrides settings', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: { + generationConfig: { + timeout: 30000, + }, + }, + env: { + MY_KEY: 'key', + }, + modelProvider: { + id: 'model', + name: 'Model', + authType: AuthType.USE_OPENAI, + envKey: 'MY_KEY', + baseUrl: 'https://api.example.com', + generationConfig: { + timeout: 60000, + }, + capabilities: {}, + }, + }); + + expect(result.config.timeout).toBe(60000); + expect(result.sources['timeout'].kind).toBe('modelProviders'); + }); + }); + + describe('proxy handling', () => { + it('includes proxy in config when provided', () => { + const result = resolveModelConfig({ + authType: AuthType.USE_OPENAI, + cli: {}, + settings: { apiKey: 'key' }, + env: {}, + proxy: 'http://proxy.example.com:8080', + 
}); + + expect(result.config.proxy).toBe('http://proxy.example.com:8080'); + expect(result.sources['proxy'].kind).toBe('computed'); + }); + }); + }); + + describe('validateModelConfig', () => { + it('passes for valid OpenAI config', () => { + const result = validateModelConfig({ + authType: AuthType.USE_OPENAI, + model: 'gpt-4', + apiKey: 'sk-xxx', + }); + + expect(result.valid).toBe(true); + expect(result.errors).toHaveLength(0); + }); + + it('fails when API key missing', () => { + const result = validateModelConfig({ + authType: AuthType.USE_OPENAI, + model: 'gpt-4', + }); + + expect(result.valid).toBe(false); + expect(result.errors).toHaveLength(1); + expect(result.errors[0].message).toContain('Missing API key'); + }); + + it('fails when model missing', () => { + const result = validateModelConfig({ + authType: AuthType.USE_OPENAI, + model: '', + apiKey: 'sk-xxx', + }); + + expect(result.valid).toBe(false); + expect(result.errors).toHaveLength(1); + expect(result.errors[0].message).toContain('Missing model'); + }); + + it('always passes for Qwen OAuth', () => { + const result = validateModelConfig({ + authType: AuthType.QWEN_OAUTH, + model: DEFAULT_QWEN_MODEL, + apiKey: 'QWEN_OAUTH_DYNAMIC_TOKEN', + }); + + expect(result.valid).toBe(true); + }); + + it('requires baseUrl for Anthropic', () => { + const result = validateModelConfig({ + authType: AuthType.USE_ANTHROPIC, + model: 'claude-3', + apiKey: 'key', + // missing baseUrl + }); + + expect(result.valid).toBe(false); + expect(result.errors[0].message).toContain('ANTHROPIC_BASE_URL'); + }); + + it('uses strict error messages for modelProvider', () => { + const result = validateModelConfig( + { + authType: AuthType.USE_OPENAI, + model: 'my-model', + // missing apiKey + }, + true, // isStrictModelProvider + ); + + expect(result.valid).toBe(false); + expect(result.errors[0].message).toContain('modelProviders'); + expect(result.errors[0].message).toContain('envKey'); + }); + }); +}); diff --git 
a/packages/core/src/models/modelConfigResolver.ts b/packages/core/src/models/modelConfigResolver.ts new file mode 100644 index 000000000..dc10fa3e8 --- /dev/null +++ b/packages/core/src/models/modelConfigResolver.ts @@ -0,0 +1,362 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * ModelConfigResolver - Unified resolver for model-related configuration. + * + * This module consolidates all model configuration resolution logic, + * eliminating duplicate code between CLI and Core layers. + * + * Configuration priority (highest to lowest): + * 1. modelProvider - Explicit selection from ModelProviders config + * 2. CLI arguments - Command line flags (--model, --openaiApiKey, etc.) + * 3. Environment variables - OPENAI_API_KEY, OPENAI_MODEL, etc. + * 4. Settings - User/workspace settings file + * 5. Defaults - Built-in default values + */ + +import { AuthType } from '../core/contentGenerator.js'; +import type { ContentGeneratorConfig } from '../core/contentGenerator.js'; +import { DEFAULT_QWEN_MODEL } from '../config/models.js'; +import { + resolveField, + resolveOptionalField, + layer, + envLayer, + cliSource, + settingsSource, + modelProvidersSource, + defaultSource, + computedSource, + type ConfigSource, + type ConfigSources, + type ConfigLayer, +} from '../utils/configResolver.js'; +import { + AUTH_ENV_MAPPINGS, + DEFAULT_MODELS, + QWEN_OAUTH_ALLOWED_MODELS, + MODEL_GENERATION_CONFIG_FIELDS, +} from './constants.js'; +import type { ResolvedModelConfig } from './types.js'; +export { + validateModelConfig, + type ModelConfigValidationResult, +} from '../core/contentGenerator.js'; + +/** + * CLI-provided configuration values + */ +export interface ModelConfigCliInput { + model?: string; + apiKey?: string; + baseUrl?: string; +} + +/** + * Settings-provided configuration values + */ +export interface ModelConfigSettingsInput { + /** Model name from settings.model.name */ + model?: string; + /** API key from 
settings.security.auth.apiKey */ + apiKey?: string; + /** Base URL from settings.security.auth.baseUrl */ + baseUrl?: string; + /** Generation config from settings.model.generationConfig */ + generationConfig?: Partial; +} + +/** + * All input sources for model configuration resolution + */ +export interface ModelConfigSourcesInput { + /** Authentication type */ + authType: AuthType; + + /** CLI arguments (highest priority for user-provided values) */ + cli?: ModelConfigCliInput; + + /** Settings file configuration */ + settings?: ModelConfigSettingsInput; + + /** Environment variables (injected for testability) */ + env: Record; + + /** Resolved model from ModelProviders (explicit selection, highest priority) */ + modelProvider?: ResolvedModelConfig; + + /** Proxy URL (computed from Config) */ + proxy?: string; +} + +/** + * Result of model configuration resolution + */ +export interface ModelConfigResolutionResult { + /** The fully resolved configuration */ + config: ContentGeneratorConfig; + /** Source attribution for each field */ + sources: ConfigSources; + /** Warnings generated during resolution */ + warnings: string[]; +} + +/** + * Resolve model configuration from all input sources. + * + * This is the single entry point for model configuration resolution. 
+ * It replaces the duplicate logic in: + * - packages/cli/src/utils/modelProviderUtils.ts (resolveCliGenerationConfig) + * - packages/core/src/core/contentGenerator.ts (resolveContentGeneratorConfigWithSources) + * + * @param input - All configuration sources + * @returns Resolved configuration with source tracking + */ +export function resolveModelConfig( + input: ModelConfigSourcesInput, +): ModelConfigResolutionResult { + const { authType, cli, settings, env, modelProvider, proxy } = input; + const warnings: string[] = []; + const sources: ConfigSources = {}; + + // Special handling for Qwen OAuth + if (authType === AuthType.QWEN_OAUTH) { + return resolveQwenOAuthConfig(input, warnings); + } + + // Get auth-specific env var mappings + const envMapping = + AUTH_ENV_MAPPINGS[authType] || AUTH_ENV_MAPPINGS[AuthType.USE_OPENAI]; + + // Build layers for each field in priority order + // Priority: modelProvider > cli > env > settings > default + + // ---- Model ---- + const modelLayers: Array> = []; + + if (modelProvider) { + modelLayers.push( + layer( + modelProvider.id, + modelProvidersSource(authType, modelProvider.id, 'model.id'), + ), + ); + } + if (cli?.model) { + modelLayers.push(layer(cli.model, cliSource('--model'))); + } + for (const envKey of envMapping.model) { + modelLayers.push(envLayer(env, envKey)); + } + if (settings?.model) { + modelLayers.push(layer(settings.model, settingsSource('model.name'))); + } + + const defaultModel = DEFAULT_MODELS[authType] || ''; + const modelResult = resolveField( + modelLayers, + defaultModel, + defaultSource(defaultModel), + ); + sources['model'] = modelResult.source; + + // ---- API Key ---- + const apiKeyLayers: Array> = []; + + // For modelProvider, read from the specified envKey + if (modelProvider?.envKey) { + const apiKeyFromEnv = env[modelProvider.envKey]; + if (apiKeyFromEnv) { + apiKeyLayers.push( + layer(apiKeyFromEnv, { + kind: 'env', + envKey: modelProvider.envKey, + via: modelProvidersSource(authType, 
modelProvider.id, 'envKey'), + }), + ); + } + } + if (cli?.apiKey) { + apiKeyLayers.push(layer(cli.apiKey, cliSource('--openaiApiKey'))); + } + for (const envKey of envMapping.apiKey) { + apiKeyLayers.push(envLayer(env, envKey)); + } + if (settings?.apiKey) { + apiKeyLayers.push( + layer(settings.apiKey, settingsSource('security.auth.apiKey')), + ); + } + + const apiKeyResult = resolveOptionalField(apiKeyLayers); + if (apiKeyResult) { + sources['apiKey'] = apiKeyResult.source; + } + + // ---- Base URL ---- + const baseUrlLayers: Array> = []; + + if (modelProvider?.baseUrl) { + baseUrlLayers.push( + layer( + modelProvider.baseUrl, + modelProvidersSource(authType, modelProvider.id, 'baseUrl'), + ), + ); + } + if (cli?.baseUrl) { + baseUrlLayers.push(layer(cli.baseUrl, cliSource('--openaiBaseUrl'))); + } + for (const envKey of envMapping.baseUrl) { + baseUrlLayers.push(envLayer(env, envKey)); + } + if (settings?.baseUrl) { + baseUrlLayers.push( + layer(settings.baseUrl, settingsSource('security.auth.baseUrl')), + ); + } + + const baseUrlResult = resolveOptionalField(baseUrlLayers); + if (baseUrlResult) { + sources['baseUrl'] = baseUrlResult.source; + } + + // ---- API Key Env Key (for error messages) ---- + let apiKeyEnvKey: string | undefined; + if (modelProvider?.envKey) { + apiKeyEnvKey = modelProvider.envKey; + sources['apiKeyEnvKey'] = modelProvidersSource( + authType, + modelProvider.id, + 'envKey', + ); + } + + // ---- Generation Config (from settings or modelProvider) ---- + const generationConfig = resolveGenerationConfig( + settings?.generationConfig, + modelProvider?.generationConfig, + authType, + modelProvider?.id, + sources, + ); + + // Build final config + const config: ContentGeneratorConfig = { + authType, + model: modelResult.value, + apiKey: apiKeyResult?.value, + apiKeyEnvKey, + baseUrl: baseUrlResult?.value, + proxy, + ...generationConfig, + }; + + // Add proxy source + if (proxy) { + sources['proxy'] = computedSource('Config.getProxy()'); + } + + 
// Add authType source + sources['authType'] = computedSource('provided by caller'); + + return { config, sources, warnings }; +} + +/** + * Special resolver for Qwen OAuth authentication. + * Qwen OAuth has fixed model options and uses dynamic tokens. + */ +function resolveQwenOAuthConfig( + input: ModelConfigSourcesInput, + warnings: string[], +): ModelConfigResolutionResult { + const { cli, settings, proxy } = input; + const sources: ConfigSources = {}; + + // Qwen OAuth only allows specific models + const allowedModels = new Set(QWEN_OAUTH_ALLOWED_MODELS); + + // Determine requested model + const requestedModel = cli?.model || settings?.model; + let resolvedModel: string; + let modelSource: ConfigSource; + + if (requestedModel && allowedModels.has(requestedModel)) { + resolvedModel = requestedModel; + modelSource = cli?.model + ? cliSource('--model') + : settingsSource('model.name'); + } else { + if (requestedModel) { + warnings.push( + `Unsupported Qwen OAuth model '${requestedModel}', falling back to '${DEFAULT_QWEN_MODEL}'.`, + ); + } + resolvedModel = DEFAULT_QWEN_MODEL; + modelSource = defaultSource(`fallback to '${DEFAULT_QWEN_MODEL}'`); + } + + sources['model'] = modelSource; + sources['apiKey'] = computedSource('Qwen OAuth dynamic token'); + sources['authType'] = computedSource('provided by caller'); + + if (proxy) { + sources['proxy'] = computedSource('Config.getProxy()'); + } + + // Resolve generation config from settings + const generationConfig = resolveGenerationConfig( + settings?.generationConfig, + undefined, + AuthType.QWEN_OAUTH, + resolvedModel, + sources, + ); + + const config: ContentGeneratorConfig = { + authType: AuthType.QWEN_OAUTH, + model: resolvedModel, + apiKey: 'QWEN_OAUTH_DYNAMIC_TOKEN', + proxy, + ...generationConfig, + }; + + return { config, sources, warnings }; +} + +/** + * Resolve generation config fields (samplingParams, timeout, etc.) 
+ */ +function resolveGenerationConfig( + settingsConfig: Partial | undefined, + modelProviderConfig: Partial | undefined, + authType: AuthType, + modelId: string | undefined, + sources: ConfigSources, +): Partial { + const result: Partial = {}; + + for (const field of MODEL_GENERATION_CONFIG_FIELDS) { + // ModelProvider config takes priority + if (modelProviderConfig && field in modelProviderConfig) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (result as any)[field] = modelProviderConfig[field]; + sources[field] = modelProvidersSource( + authType, + modelId || '', + `generationConfig.${field}`, + ); + } else if (settingsConfig && field in settingsConfig) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (result as any)[field] = settingsConfig[field]; + sources[field] = settingsSource(`model.generationConfig.${field}`); + } + } + + return result; +} diff --git a/packages/core/src/models/modelRegistry.test.ts b/packages/core/src/models/modelRegistry.test.ts new file mode 100644 index 000000000..b2225425c --- /dev/null +++ b/packages/core/src/models/modelRegistry.test.ts @@ -0,0 +1,390 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { ModelRegistry, QWEN_OAUTH_MODELS } from './modelRegistry.js'; +import { AuthType } from '../core/contentGenerator.js'; +import type { ModelProvidersConfig } from './types.js'; + +describe('ModelRegistry', () => { + describe('initialization', () => { + it('should always include hard-coded qwen-oauth models', () => { + const registry = new ModelRegistry(); + + const qwenModels = registry.getModelsForAuthType(AuthType.QWEN_OAUTH); + expect(qwenModels.length).toBe(QWEN_OAUTH_MODELS.length); + expect(qwenModels[0].id).toBe('coder-model'); + expect(qwenModels[1].id).toBe('vision-model'); + }); + + it('should initialize with empty config', () => { + const registry = new 
ModelRegistry(); + expect(registry.getModelsForAuthType(AuthType.QWEN_OAUTH).length).toBe( + QWEN_OAUTH_MODELS.length, + ); + expect(registry.getModelsForAuthType(AuthType.USE_OPENAI).length).toBe(0); + }); + + it('should initialize with custom models config', () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'gpt-4-turbo', + name: 'GPT-4 Turbo', + baseUrl: 'https://api.openai.com/v1', + }, + ], + }; + + const registry = new ModelRegistry(modelProvidersConfig); + + const openaiModels = registry.getModelsForAuthType(AuthType.USE_OPENAI); + expect(openaiModels.length).toBe(1); + expect(openaiModels[0].id).toBe('gpt-4-turbo'); + }); + + it('should ignore qwen-oauth models in config (hard-coded)', () => { + const modelProvidersConfig: ModelProvidersConfig = { + 'qwen-oauth': [ + { + id: 'custom-qwen', + name: 'Custom Qwen', + }, + ], + }; + + const registry = new ModelRegistry(modelProvidersConfig); + + // Should still use hard-coded qwen-oauth models + const qwenModels = registry.getModelsForAuthType(AuthType.QWEN_OAUTH); + expect(qwenModels.length).toBe(QWEN_OAUTH_MODELS.length); + expect(qwenModels.find((m) => m.id === 'custom-qwen')).toBeUndefined(); + }); + }); + + describe('getModelsForAuthType', () => { + let registry: ModelRegistry; + + beforeEach(() => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'gpt-4-turbo', + name: 'GPT-4 Turbo', + description: 'Most capable GPT-4', + baseUrl: 'https://api.openai.com/v1', + capabilities: { vision: true }, + }, + { + id: 'gpt-3.5-turbo', + name: 'GPT-3.5 Turbo', + capabilities: { vision: false }, + }, + ], + }; + registry = new ModelRegistry(modelProvidersConfig); + }); + + it('should return models for existing authType', () => { + const models = registry.getModelsForAuthType(AuthType.USE_OPENAI); + expect(models.length).toBe(2); + }); + + it('should return empty array for non-existent authType', () => { + const models = 
registry.getModelsForAuthType(AuthType.USE_VERTEX_AI); + expect(models.length).toBe(0); + }); + + it('should return AvailableModel format with correct fields', () => { + const models = registry.getModelsForAuthType(AuthType.USE_OPENAI); + const gpt4 = models.find((m) => m.id === 'gpt-4-turbo'); + + expect(gpt4).toBeDefined(); + expect(gpt4?.label).toBe('GPT-4 Turbo'); + expect(gpt4?.description).toBe('Most capable GPT-4'); + expect(gpt4?.isVision).toBe(true); + expect(gpt4?.authType).toBe(AuthType.USE_OPENAI); + }); + }); + + describe('getModel', () => { + let registry: ModelRegistry; + + beforeEach(() => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'gpt-4-turbo', + name: 'GPT-4 Turbo', + baseUrl: 'https://api.openai.com/v1', + generationConfig: { + samplingParams: { + temperature: 0.8, + max_tokens: 4096, + }, + }, + }, + ], + }; + registry = new ModelRegistry(modelProvidersConfig); + }); + + it('should return resolved model config', () => { + const model = registry.getModel(AuthType.USE_OPENAI, 'gpt-4-turbo'); + + expect(model).toBeDefined(); + expect(model?.id).toBe('gpt-4-turbo'); + expect(model?.name).toBe('GPT-4 Turbo'); + expect(model?.authType).toBe(AuthType.USE_OPENAI); + expect(model?.baseUrl).toBe('https://api.openai.com/v1'); + }); + + it('should preserve generationConfig without applying defaults', () => { + const model = registry.getModel(AuthType.USE_OPENAI, 'gpt-4-turbo'); + + expect(model?.generationConfig.samplingParams?.temperature).toBe(0.8); + expect(model?.generationConfig.samplingParams?.max_tokens).toBe(4096); + // No defaults are applied - only the configured values are present + expect(model?.generationConfig.samplingParams?.top_p).toBeUndefined(); + expect(model?.generationConfig.timeout).toBeUndefined(); + }); + + it('should return undefined for non-existent model', () => { + const model = registry.getModel(AuthType.USE_OPENAI, 'non-existent'); + expect(model).toBeUndefined(); + }); + + it('should return 
undefined for non-existent authType', () => { + const model = registry.getModel(AuthType.USE_VERTEX_AI, 'some-model'); + expect(model).toBeUndefined(); + }); + }); + + describe('hasModel', () => { + let registry: ModelRegistry; + + beforeEach(() => { + registry = new ModelRegistry({ + openai: [{ id: 'gpt-4', name: 'GPT-4' }], + }); + }); + + it('should return true for existing model', () => { + expect(registry.hasModel(AuthType.USE_OPENAI, 'gpt-4')).toBe(true); + }); + + it('should return false for non-existent model', () => { + expect(registry.hasModel(AuthType.USE_OPENAI, 'non-existent')).toBe( + false, + ); + }); + + it('should return false for non-existent authType', () => { + expect(registry.hasModel(AuthType.USE_VERTEX_AI, 'gpt-4')).toBe(false); + }); + }); + + describe('getDefaultModelForAuthType', () => { + it('should return coder-model for qwen-oauth', () => { + const registry = new ModelRegistry(); + const defaultModel = registry.getDefaultModelForAuthType( + AuthType.QWEN_OAUTH, + ); + expect(defaultModel?.id).toBe('coder-model'); + }); + + it('should return first model for other authTypes', () => { + const registry = new ModelRegistry({ + openai: [ + { id: 'gpt-4', name: 'GPT-4' }, + { id: 'gpt-3.5', name: 'GPT-3.5' }, + ], + }); + + const defaultModel = registry.getDefaultModelForAuthType( + AuthType.USE_OPENAI, + ); + expect(defaultModel?.id).toBe('gpt-4'); + }); + }); + + describe('validation', () => { + it('should throw error for model without id', () => { + expect( + () => + new ModelRegistry({ + openai: [{ id: '', name: 'No ID' }], + }), + ).toThrow('missing required field: id'); + }); + }); + + describe('default base URLs', () => { + it('should apply default dashscope URL for qwen-oauth', () => { + const registry = new ModelRegistry(); + const model = registry.getModel(AuthType.QWEN_OAUTH, 'coder-model'); + expect(model?.baseUrl).toBe( + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + ); + }); + + it('should apply default openai URL when 
not specified', () => { + const registry = new ModelRegistry({ + openai: [{ id: 'gpt-4', name: 'GPT-4' }], + }); + + const model = registry.getModel(AuthType.USE_OPENAI, 'gpt-4'); + expect(model?.baseUrl).toBe('https://api.openai.com/v1'); + }); + + it('should use custom baseUrl when specified', () => { + const registry = new ModelRegistry({ + openai: [ + { + id: 'deepseek', + name: 'DeepSeek', + baseUrl: 'https://api.deepseek.com/v1', + }, + ], + }); + + const model = registry.getModel(AuthType.USE_OPENAI, 'deepseek'); + expect(model?.baseUrl).toBe('https://api.deepseek.com/v1'); + }); + }); + + describe('authType key validation', () => { + it('should accept valid authType keys', () => { + const registry = new ModelRegistry({ + openai: [{ id: 'gpt-4', name: 'GPT-4' }], + gemini: [{ id: 'gemini-pro', name: 'Gemini Pro' }], + }); + + const openaiModels = registry.getModelsForAuthType(AuthType.USE_OPENAI); + expect(openaiModels.length).toBe(1); + expect(openaiModels[0].id).toBe('gpt-4'); + + const geminiModels = registry.getModelsForAuthType(AuthType.USE_GEMINI); + expect(geminiModels.length).toBe(1); + expect(geminiModels[0].id).toBe('gemini-pro'); + }); + + it('should skip invalid authType keys with warning', () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const registry = new ModelRegistry({ + openai: [{ id: 'gpt-4', name: 'GPT-4' }], + 'invalid-key': [{ id: 'some-model', name: 'Some Model' }], + } as unknown as ModelProvidersConfig); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('[ModelRegistry] Invalid authType key'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('invalid-key'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Expected one of:'), + ); + + // Valid key should be registered + expect(registry.getModelsForAuthType(AuthType.USE_OPENAI).length).toBe(1); + + // Invalid key should be skipped (no crash) + 
const openaiModels = registry.getModelsForAuthType(AuthType.USE_OPENAI); + expect(openaiModels.length).toBe(1); + + consoleWarnSpy.mockRestore(); + }); + + it('should handle mixed valid and invalid keys', () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const registry = new ModelRegistry({ + openai: [{ id: 'gpt-4', name: 'GPT-4' }], + 'bad-key-1': [{ id: 'model-1', name: 'Model 1' }], + gemini: [{ id: 'gemini-pro', name: 'Gemini Pro' }], + 'bad-key-2': [{ id: 'model-2', name: 'Model 2' }], + } as unknown as ModelProvidersConfig); + + // Should warn twice for the two invalid keys + expect(consoleWarnSpy).toHaveBeenCalledTimes(2); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('bad-key-1'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('bad-key-2'), + ); + + // Valid keys should be registered + expect(registry.getModelsForAuthType(AuthType.USE_OPENAI).length).toBe(1); + expect(registry.getModelsForAuthType(AuthType.USE_GEMINI).length).toBe(1); + + // Invalid keys should be skipped + const openaiModels = registry.getModelsForAuthType(AuthType.USE_OPENAI); + expect(openaiModels.length).toBe(1); + + const geminiModels = registry.getModelsForAuthType(AuthType.USE_GEMINI); + expect(geminiModels.length).toBe(1); + + consoleWarnSpy.mockRestore(); + }); + + it('should list all valid AuthType values in warning message', () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + new ModelRegistry({ + 'invalid-auth': [{ id: 'model', name: 'Model' }], + } as unknown as ModelProvidersConfig); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('openai'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('qwen-oauth'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('gemini'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + 
expect.stringContaining('vertex-ai'), + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('anthropic'), + ); + + consoleWarnSpy.mockRestore(); + }); + + it('should work correctly with getModelsForAuthType after validation', () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const registry = new ModelRegistry({ + openai: [ + { id: 'gpt-4', name: 'GPT-4' }, + { id: 'gpt-3.5', name: 'GPT-3.5' }, + ], + 'invalid-key': [{ id: 'invalid-model', name: 'Invalid Model' }], + } as unknown as ModelProvidersConfig); + + const models = registry.getModelsForAuthType(AuthType.USE_OPENAI); + expect(models.length).toBe(2); + expect(models.find((m) => m.id === 'gpt-4')).toBeDefined(); + expect(models.find((m) => m.id === 'gpt-3.5')).toBeDefined(); + expect(models.find((m) => m.id === 'invalid-model')).toBeUndefined(); + + consoleWarnSpy.mockRestore(); + }); + }); +}); diff --git a/packages/core/src/models/modelRegistry.ts b/packages/core/src/models/modelRegistry.ts new file mode 100644 index 000000000..cec6ebb94 --- /dev/null +++ b/packages/core/src/models/modelRegistry.ts @@ -0,0 +1,180 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { AuthType } from '../core/contentGenerator.js'; +import { DEFAULT_OPENAI_BASE_URL } from '../core/openaiContentGenerator/constants.js'; +import { + type ModelConfig, + type ModelProvidersConfig, + type ResolvedModelConfig, + type AvailableModel, +} from './types.js'; +import { DEFAULT_QWEN_MODEL } from '../config/models.js'; +import { QWEN_OAUTH_MODELS } from './constants.js'; + +export { QWEN_OAUTH_MODELS } from './constants.js'; + +/** + * Validates if a string key is a valid AuthType enum value. 
+ * @param key - The key to validate + * @returns The validated AuthType or undefined if invalid + */ +function validateAuthTypeKey(key: string): AuthType | undefined { + // Check if the key is a valid AuthType enum value + if (Object.values(AuthType).includes(key as AuthType)) { + return key as AuthType; + } + + // Invalid key + return undefined; +} + +/** + * Central registry for managing model configurations. + * Models are organized by authType. + */ +export class ModelRegistry { + private modelsByAuthType: Map>; + + private getDefaultBaseUrl(authType: AuthType): string { + switch (authType) { + case AuthType.QWEN_OAUTH: + return 'DYNAMIC_QWEN_OAUTH_BASE_URL'; + case AuthType.USE_OPENAI: + return DEFAULT_OPENAI_BASE_URL; + default: + return ''; + } + } + + constructor(modelProvidersConfig?: ModelProvidersConfig) { + this.modelsByAuthType = new Map(); + + // Always register qwen-oauth models (hard-coded, cannot be overridden) + this.registerAuthTypeModels(AuthType.QWEN_OAUTH, QWEN_OAUTH_MODELS); + + // Register user-configured models for other authTypes + if (modelProvidersConfig) { + for (const [rawKey, models] of Object.entries(modelProvidersConfig)) { + const authType = validateAuthTypeKey(rawKey); + + if (!authType) { + console.warn( + `[ModelRegistry] Invalid authType key "${rawKey}" in modelProviders config. Expected one of: ${Object.values(AuthType).join(', ')}. 
Skipping.`, + ); + continue; + } + + // Skip qwen-oauth as it uses hard-coded models + if (authType === AuthType.QWEN_OAUTH) { + continue; + } + + this.registerAuthTypeModels(authType, models); + } + } + } + + /** + * Register models for an authType + */ + private registerAuthTypeModels( + authType: AuthType, + models: ModelConfig[], + ): void { + const modelMap = new Map(); + + for (const config of models) { + const resolved = this.resolveModelConfig(config, authType); + modelMap.set(config.id, resolved); + } + + this.modelsByAuthType.set(authType, modelMap); + } + + /** + * Get all models for a specific authType. + * This is used by /model command to show only relevant models. + */ + getModelsForAuthType(authType: AuthType): AvailableModel[] { + const models = this.modelsByAuthType.get(authType); + if (!models) return []; + + return Array.from(models.values()).map((model) => ({ + id: model.id, + label: model.name, + description: model.description, + capabilities: model.capabilities, + authType: model.authType, + isVision: model.capabilities?.vision ?? false, + })); + } + + /** + * Get model configuration by authType and modelId + */ + getModel( + authType: AuthType, + modelId: string, + ): ResolvedModelConfig | undefined { + const models = this.modelsByAuthType.get(authType); + return models?.get(modelId); + } + + /** + * Check if model exists for given authType + */ + hasModel(authType: AuthType, modelId: string): boolean { + const models = this.modelsByAuthType.get(authType); + return models?.has(modelId) ?? false; + } + + /** + * Get default model for an authType. + * For qwen-oauth, returns the coder model. + * For others, returns the first configured model. 
+ */ + getDefaultModelForAuthType( + authType: AuthType, + ): ResolvedModelConfig | undefined { + if (authType === AuthType.QWEN_OAUTH) { + return this.getModel(authType, DEFAULT_QWEN_MODEL); + } + const models = this.modelsByAuthType.get(authType); + if (!models || models.size === 0) return undefined; + return Array.from(models.values())[0]; + } + + /** + * Resolve model config by applying defaults + */ + private resolveModelConfig( + config: ModelConfig, + authType: AuthType, + ): ResolvedModelConfig { + this.validateModelConfig(config, authType); + + return { + ...config, + authType, + name: config.name || config.id, + baseUrl: config.baseUrl || this.getDefaultBaseUrl(authType), + generationConfig: config.generationConfig ?? {}, + capabilities: config.capabilities || {}, + }; + } + + /** + * Validate model configuration + */ + private validateModelConfig(config: ModelConfig, authType: AuthType): void { + if (!config.id) { + throw new Error( + `Model config in authType '${authType}' missing required field: id`, + ); + } + } +} diff --git a/packages/core/src/models/modelsConfig.test.ts b/packages/core/src/models/modelsConfig.test.ts new file mode 100644 index 000000000..51c54ea59 --- /dev/null +++ b/packages/core/src/models/modelsConfig.test.ts @@ -0,0 +1,451 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { ModelsConfig } from './modelsConfig.js'; +import { AuthType } from '../core/contentGenerator.js'; +import type { ModelProvidersConfig } from './types.js'; + +describe('ModelsConfig', () => { + function deepClone(value: T): T { + if (value === null || typeof value !== 'object') return value; + if (Array.isArray(value)) return value.map((v) => deepClone(v)) as T; + const out: Record = {}; + for (const key of Object.keys(value as Record)) { + out[key] = deepClone((value as Record)[key]); + } + return out as T; + } + + it('should fully rollback state when 
switchModel fails after applying defaults (authType change)', async () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'openai-a', + name: 'OpenAI A', + baseUrl: 'https://api.openai.example.com/v1', + envKey: 'OPENAI_API_KEY', + generationConfig: { + samplingParams: { temperature: 0.2, max_tokens: 123 }, + timeout: 111, + maxRetries: 1, + }, + }, + ], + anthropic: [ + { + id: 'anthropic-b', + name: 'Anthropic B', + baseUrl: 'https://api.anthropic.example.com/v1', + envKey: 'ANTHROPIC_API_KEY', + generationConfig: { + samplingParams: { temperature: 0.7, max_tokens: 456 }, + timeout: 222, + maxRetries: 2, + }, + }, + ], + }; + + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + modelProvidersConfig, + }); + + // Establish a known baseline state via a successful switch. + await modelsConfig.switchModel(AuthType.USE_OPENAI, 'openai-a'); + const baselineAuthType = modelsConfig.getCurrentAuthType(); + const baselineModel = modelsConfig.getModel(); + const baselineStrict = modelsConfig.isStrictModelProviderSelection(); + const baselineGc = deepClone(modelsConfig.getGenerationConfig()); + const baselineSources = deepClone( + modelsConfig.getGenerationConfigSources(), + ); + + modelsConfig.setOnModelChange(async () => { + throw new Error('refresh failed'); + }); + + await expect( + modelsConfig.switchModel(AuthType.USE_ANTHROPIC, 'anthropic-b'), + ).rejects.toThrow('refresh failed'); + + // Ensure state is fully rolled back (selection + generation config + flags). 
+ expect(modelsConfig.getCurrentAuthType()).toBe(baselineAuthType); + expect(modelsConfig.getModel()).toBe(baselineModel); + expect(modelsConfig.isStrictModelProviderSelection()).toBe(baselineStrict); + + const gc = modelsConfig.getGenerationConfig(); + expect(gc).toMatchObject({ + model: baselineGc.model, + baseUrl: baselineGc.baseUrl, + apiKeyEnvKey: baselineGc.apiKeyEnvKey, + samplingParams: baselineGc.samplingParams, + timeout: baselineGc.timeout, + maxRetries: baselineGc.maxRetries, + }); + + const sources = modelsConfig.getGenerationConfigSources(); + expect(sources).toEqual(baselineSources); + }); + + it('should fully rollback state when switchModel fails after applying defaults', async () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'model-a', + name: 'Model A', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_A', + }, + { + id: 'model-b', + name: 'Model B', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_B', + }, + ], + }; + + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + modelProvidersConfig, + }); + + await modelsConfig.switchModel(AuthType.USE_OPENAI, 'model-a'); + const baselineModel = modelsConfig.getModel(); + const baselineGc = deepClone(modelsConfig.getGenerationConfig()); + const baselineSources = deepClone( + modelsConfig.getGenerationConfigSources(), + ); + + modelsConfig.setOnModelChange(async () => { + throw new Error('hot-update failed'); + }); + + await expect( + modelsConfig.switchModel(AuthType.USE_OPENAI, 'model-b'), + ).rejects.toThrow('hot-update failed'); + + expect(modelsConfig.getModel()).toBe(baselineModel); + expect(modelsConfig.getGenerationConfig()).toMatchObject({ + model: baselineGc.model, + baseUrl: baselineGc.baseUrl, + apiKeyEnvKey: baselineGc.apiKeyEnvKey, + }); + expect(modelsConfig.getGenerationConfigSources()).toEqual(baselineSources); + }); + + it('should preserve an explicit apiKey when switching models if envKey is missing 
in the environment', async () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'model-a', + name: 'Model A', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_SHARED', + }, + { + id: 'model-b', + name: 'Model B', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_SHARED', + }, + ], + }; + + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + initialModelId: 'model-a', + modelProvidersConfig, + }); + + // Simulate key prompt flow / explicit key provided via CLI/settings. + modelsConfig.updateCredentials({ apiKey: 'manual-key', model: 'model-a' }); + + await modelsConfig.switchModel(AuthType.USE_OPENAI, 'model-b'); + + const gc = modelsConfig.getGenerationConfig(); + expect(gc.model).toBe('model-b'); + expect(gc.apiKey).toBe('manual-key'); + expect(gc.apiKeyEnvKey).toBe('API_KEY_SHARED'); + }); + + it('should preserve settings generationConfig when model is updated via updateCredentials even if it matches modelProviders', () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'model-a', + name: 'Model A', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_A', + generationConfig: { + samplingParams: { temperature: 0.1, max_tokens: 123 }, + timeout: 111, + maxRetries: 1, + }, + }, + ], + }; + + // Simulate settings.model.generationConfig being resolved into ModelsConfig.generationConfig + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + initialModelId: 'model-a', + modelProvidersConfig, + generationConfig: { + model: 'model-a', + samplingParams: { temperature: 0.9, max_tokens: 999 }, + timeout: 9999, + maxRetries: 9, + }, + generationConfigSources: { + model: { kind: 'settings', detail: 'settings.model.name' }, + samplingParams: { + kind: 'settings', + detail: 'settings.model.generationConfig.samplingParams', + }, + timeout: { + kind: 'settings', + detail: 'settings.model.generationConfig.timeout', + }, + maxRetries: 
{ + kind: 'settings', + detail: 'settings.model.generationConfig.maxRetries', + }, + }, + }); + + // User manually updates the model via updateCredentials (e.g. key prompt flow). + // Even if the model ID matches a modelProviders entry, we must not apply provider defaults + // that would overwrite settings.model.generationConfig. + modelsConfig.updateCredentials({ model: 'model-a' }); + + modelsConfig.syncAfterAuthRefresh( + AuthType.USE_OPENAI, + modelsConfig.getModel(), + ); + + const gc = modelsConfig.getGenerationConfig(); + expect(gc.model).toBe('model-a'); + expect(gc.samplingParams?.temperature).toBe(0.9); + expect(gc.samplingParams?.max_tokens).toBe(999); + expect(gc.timeout).toBe(9999); + expect(gc.maxRetries).toBe(9); + }); + + it('should preserve settings generationConfig across multiple auth refreshes after updateCredentials', () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'model-a', + name: 'Model A', + baseUrl: 'https://api.example.com/v1', + envKey: 'API_KEY_A', + generationConfig: { + samplingParams: { temperature: 0.1, max_tokens: 123 }, + timeout: 111, + maxRetries: 1, + }, + }, + ], + }; + + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + initialModelId: 'model-a', + modelProvidersConfig, + generationConfig: { + model: 'model-a', + samplingParams: { temperature: 0.9, max_tokens: 999 }, + timeout: 9999, + maxRetries: 9, + }, + generationConfigSources: { + model: { kind: 'settings', detail: 'settings.model.name' }, + samplingParams: { + kind: 'settings', + detail: 'settings.model.generationConfig.samplingParams', + }, + timeout: { + kind: 'settings', + detail: 'settings.model.generationConfig.timeout', + }, + maxRetries: { + kind: 'settings', + detail: 'settings.model.generationConfig.maxRetries', + }, + }, + }); + + modelsConfig.updateCredentials({ + apiKey: 'manual-key', + baseUrl: 'https://manual.example.com/v1', + model: 'model-a', + }); + + // First auth refresh + 
modelsConfig.syncAfterAuthRefresh( + AuthType.USE_OPENAI, + modelsConfig.getModel(), + ); + // Second auth refresh should still preserve settings generationConfig + modelsConfig.syncAfterAuthRefresh( + AuthType.USE_OPENAI, + modelsConfig.getModel(), + ); + + const gc = modelsConfig.getGenerationConfig(); + expect(gc.model).toBe('model-a'); + expect(gc.samplingParams?.temperature).toBe(0.9); + expect(gc.samplingParams?.max_tokens).toBe(999); + expect(gc.timeout).toBe(9999); + expect(gc.maxRetries).toBe(9); + }); + + it('should clear provider-sourced config when updateCredentials is called after switchModel', async () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'provider-model', + name: 'Provider Model', + baseUrl: 'https://provider.example.com/v1', + envKey: 'PROVIDER_API_KEY', + generationConfig: { + samplingParams: { temperature: 0.1, max_tokens: 100 }, + timeout: 1000, + maxRetries: 2, + }, + }, + ], + }; + + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + modelProvidersConfig, + }); + + // Step 1: Switch to a provider model - this applies provider config + await modelsConfig.switchModel(AuthType.USE_OPENAI, 'provider-model'); + + // Verify provider config is applied + let gc = modelsConfig.getGenerationConfig(); + expect(gc.model).toBe('provider-model'); + expect(gc.baseUrl).toBe('https://provider.example.com/v1'); + expect(gc.samplingParams?.temperature).toBe(0.1); + expect(gc.samplingParams?.max_tokens).toBe(100); + expect(gc.timeout).toBe(1000); + expect(gc.maxRetries).toBe(2); + + // Verify sources are from modelProviders + let sources = modelsConfig.getGenerationConfigSources(); + expect(sources['model']?.kind).toBe('modelProviders'); + expect(sources['baseUrl']?.kind).toBe('modelProviders'); + expect(sources['samplingParams']?.kind).toBe('modelProviders'); + expect(sources['timeout']?.kind).toBe('modelProviders'); + expect(sources['maxRetries']?.kind).toBe('modelProviders'); + + // 
Step 2: User manually sets credentials via updateCredentials + // This should clear all provider-sourced config + modelsConfig.updateCredentials({ + apiKey: 'manual-api-key', + model: 'custom-model', + }); + + // Verify provider-sourced config is cleared + gc = modelsConfig.getGenerationConfig(); + expect(gc.model).toBe('custom-model'); // Set by updateCredentials + expect(gc.apiKey).toBe('manual-api-key'); // Set by updateCredentials + expect(gc.baseUrl).toBeUndefined(); // Cleared (was from provider) + expect(gc.samplingParams).toBeUndefined(); // Cleared (was from provider) + expect(gc.timeout).toBeUndefined(); // Cleared (was from provider) + expect(gc.maxRetries).toBeUndefined(); // Cleared (was from provider) + + // Verify sources are updated + sources = modelsConfig.getGenerationConfigSources(); + expect(sources['model']?.kind).toBe('programmatic'); + expect(sources['apiKey']?.kind).toBe('programmatic'); + expect(sources['baseUrl']).toBeUndefined(); // Source cleared + expect(sources['samplingParams']).toBeUndefined(); // Source cleared + expect(sources['timeout']).toBeUndefined(); // Source cleared + expect(sources['maxRetries']).toBeUndefined(); // Source cleared + }); + + it('should preserve non-provider config when updateCredentials clears provider config', async () => { + const modelProvidersConfig: ModelProvidersConfig = { + openai: [ + { + id: 'provider-model', + name: 'Provider Model', + baseUrl: 'https://provider.example.com/v1', + envKey: 'PROVIDER_API_KEY', + generationConfig: { + samplingParams: { temperature: 0.1, max_tokens: 100 }, + timeout: 1000, + maxRetries: 2, + }, + }, + ], + }; + + // Initialize with settings-sourced config + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.USE_OPENAI, + modelProvidersConfig, + generationConfig: { + samplingParams: { temperature: 0.8, max_tokens: 500 }, + timeout: 5000, + }, + generationConfigSources: { + samplingParams: { + kind: 'settings', + detail: 
'settings.model.generationConfig.samplingParams', + }, + timeout: { + kind: 'settings', + detail: 'settings.model.generationConfig.timeout', + }, + }, + }); + + // Switch to provider model - this overwrites with provider config + await modelsConfig.switchModel(AuthType.USE_OPENAI, 'provider-model'); + + // Verify provider config is applied (overwriting settings) + let gc = modelsConfig.getGenerationConfig(); + expect(gc.samplingParams?.temperature).toBe(0.1); + expect(gc.timeout).toBe(1000); + + // User manually sets credentials - clears provider-sourced config + modelsConfig.updateCredentials({ + apiKey: 'manual-key', + }); + + // Provider-sourced config should be cleared + gc = modelsConfig.getGenerationConfig(); + expect(gc.samplingParams).toBeUndefined(); + expect(gc.timeout).toBeUndefined(); + // The original settings-sourced config is NOT restored automatically; + // it should be re-resolved by other layers in refreshAuth + }); + + it('should always force Qwen OAuth apiKey placeholder when applying model defaults', async () => { + // Simulate a stale/explicit apiKey existing before switching models. + const modelsConfig = new ModelsConfig({ + initialAuthType: AuthType.QWEN_OAUTH, + generationConfig: { + apiKey: 'manual-key-should-not-leak', + }, + }); + + // Switching within qwen-oauth triggers applyResolvedModelDefaults(). 
+ await modelsConfig.switchModel(AuthType.QWEN_OAUTH, 'vision-model'); + + const gc = modelsConfig.getGenerationConfig(); + expect(gc.apiKey).toBe('QWEN_OAUTH_DYNAMIC_TOKEN'); + expect(gc.apiKeyEnvKey).toBeUndefined(); + }); +}); diff --git a/packages/core/src/models/modelsConfig.ts b/packages/core/src/models/modelsConfig.ts new file mode 100644 index 000000000..022737074 --- /dev/null +++ b/packages/core/src/models/modelsConfig.ts @@ -0,0 +1,697 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import process from 'node:process'; + +import { AuthType } from '../core/contentGenerator.js'; +import type { ContentGeneratorConfig } from '../core/contentGenerator.js'; +import type { ContentGeneratorConfigSources } from '../core/contentGenerator.js'; +import { DEFAULT_QWEN_MODEL } from '../config/models.js'; + +import { ModelRegistry } from './modelRegistry.js'; +import { + type ModelProvidersConfig, + type ResolvedModelConfig, + type AvailableModel, + type ModelSwitchMetadata, +} from './types.js'; +import { + MODEL_GENERATION_CONFIG_FIELDS, + CREDENTIAL_FIELDS, + PROVIDER_SOURCED_FIELDS, +} from './constants.js'; + +export { + MODEL_GENERATION_CONFIG_FIELDS, + CREDENTIAL_FIELDS, + PROVIDER_SOURCED_FIELDS, +}; + +/** + * Callback for when the model changes. + * Used by Config to refresh auth/ContentGenerator when needed. 
+ */ +export type OnModelChangeCallback = ( + authType: AuthType, + requiresRefresh: boolean, +) => Promise; + +/** + * Options for creating ModelsConfig + */ +export interface ModelsConfigOptions { + /** Initial authType from settings */ + initialAuthType?: AuthType; + /** Initial model ID from settings */ + initialModelId?: string; + /** Model providers configuration */ + modelProvidersConfig?: ModelProvidersConfig; + /** Generation config from CLI/settings */ + generationConfig?: Partial; + /** Source tracking for generation config */ + generationConfigSources?: ContentGeneratorConfigSources; + /** Callback when model changes require refresh */ + onModelChange?: OnModelChangeCallback; +} + +/** + * ModelsConfig manages all model selection logic and state. + * + * This class encapsulates: + * - ModelRegistry for model configuration storage + * - Current authType and modelId selection + * - Generation config management + * - Model switching logic + * + * Config uses this as a thin entry point for all model-related operations. + */ +export class ModelsConfig { + private readonly modelRegistry: ModelRegistry; + + // Current selection state + private currentAuthType: AuthType; + private currentModelId: string; + + // Generation config state + private _generationConfig: Partial; + private generationConfigSources: ContentGeneratorConfigSources; + + // Flag for strict model provider selection + private strictModelProviderSelection: boolean = false; + + // One-shot flag for qwen-oauth credential caching + private requireCachedQwenCredentialsOnce: boolean = false; + + // One-shot flag indicating credentials were manually set via updateCredentials() + // When true, syncAfterAuthRefresh should NOT override these credentials with + // modelProviders defaults (even if the model ID matches a registry entry). + // + // This must be persistent across auth refreshes, because refreshAuth() can be + // triggered multiple times after a credential prompt flow. 
We only clear this + // flag when we explicitly apply modelProvider defaults (i.e. when the user + // switches to a registry model via switchModel). + private hasManualCredentials: boolean = false; + + // Callback for notifying Config of model changes + private onModelChange?: OnModelChangeCallback; + + // Flag indicating whether authType was explicitly provided (not defaulted) + private readonly authTypeWasExplicitlyProvided: boolean; + + private static deepClone(value: T): T { + if (value === null || typeof value !== 'object') { + return value; + } + if (Array.isArray(value)) { + return value.map((v) => ModelsConfig.deepClone(v)) as T; + } + const out: Record = {}; + for (const key of Object.keys(value as Record)) { + out[key] = ModelsConfig.deepClone( + (value as Record)[key], + ); + } + return out as T; + } + + private snapshotState(): { + currentAuthType: AuthType; + currentModelId: string; + generationConfig: Partial; + generationConfigSources: ContentGeneratorConfigSources; + strictModelProviderSelection: boolean; + requireCachedQwenCredentialsOnce: boolean; + hasManualCredentials: boolean; + } { + return { + currentAuthType: this.currentAuthType, + currentModelId: this.currentModelId, + generationConfig: ModelsConfig.deepClone(this._generationConfig), + generationConfigSources: ModelsConfig.deepClone( + this.generationConfigSources, + ), + strictModelProviderSelection: this.strictModelProviderSelection, + requireCachedQwenCredentialsOnce: this.requireCachedQwenCredentialsOnce, + hasManualCredentials: this.hasManualCredentials, + }; + } + + private restoreState( + snapshot: ReturnType, + ): void { + this.currentAuthType = snapshot.currentAuthType; + this.currentModelId = snapshot.currentModelId; + this._generationConfig = snapshot.generationConfig; + this.generationConfigSources = snapshot.generationConfigSources; + this.strictModelProviderSelection = snapshot.strictModelProviderSelection; + this.requireCachedQwenCredentialsOnce = + 
snapshot.requireCachedQwenCredentialsOnce; + this.hasManualCredentials = snapshot.hasManualCredentials; + } + + constructor(options: ModelsConfigOptions = {}) { + this.modelRegistry = new ModelRegistry(options.modelProvidersConfig); + this.onModelChange = options.onModelChange; + + // Initialize generation config + this._generationConfig = { + model: options.initialModelId, + ...(options.generationConfig || {}), + }; + this.generationConfigSources = options.generationConfigSources || {}; + + // Track if authType was explicitly provided + this.authTypeWasExplicitlyProvided = options.initialAuthType !== undefined; + + // Initialize selection state + this.currentAuthType = options.initialAuthType || AuthType.QWEN_OAUTH; + this.currentModelId = options.initialModelId || ''; + + // Validate and initialize default selection + this.initializeDefaultSelection(); + } + + /** + * Initialize default selection based on settings/environment. + * + * Note: The generationConfig passed to ModelsConfig should already be fully + * resolved by ModelConfigResolver, which handles CLI args, env vars, and settings. + * This method primarily validates and sets up internal state. 
+ */ + private initializeDefaultSelection(): void { + // If generationConfig already has a model (resolved by ModelConfigResolver), + // use that as the current selection + if (this._generationConfig.model) { + this.currentModelId = this._generationConfig.model; + return; + } + + // Check if persisted model selection is valid + if ( + this.currentModelId && + this.modelRegistry.hasModel(this.currentAuthType, this.currentModelId) + ) { + return; + } + + // Use registry default + const defaultModel = this.modelRegistry.getDefaultModelForAuthType( + this.currentAuthType, + ); + if (defaultModel) { + this.currentModelId = defaultModel.id; + if (!this._generationConfig.model) { + this._generationConfig.model = defaultModel.id; + } + } + } + + /** + * Get current model ID + */ + getModel(): string { + return ( + this._generationConfig.model || this.currentModelId || DEFAULT_QWEN_MODEL + ); + } + + /** + * Get current authType + */ + getCurrentAuthType(): AuthType { + return this.currentAuthType; + } + + /** + * Check if authType was explicitly provided (via CLI or settings). + * If false, the default QWEN_OAUTH is being used. + */ + wasAuthTypeExplicitlyProvided(): boolean { + return this.authTypeWasExplicitlyProvided; + } + + /** + * Get available models for current authType + */ + getAvailableModels(): AvailableModel[] { + return this.modelRegistry.getModelsForAuthType(this.currentAuthType); + } + + /** + * Get available models for a specific authType + */ + getAvailableModelsForAuthType(authType: AuthType): AvailableModel[] { + return this.modelRegistry.getModelsForAuthType(authType); + } + + /** + * Check if a model exists for the given authType + */ + hasModel(authType: AuthType, modelId: string): boolean { + return this.modelRegistry.hasModel(authType, modelId); + } + + /** + * Set model programmatically (e.g., VLM auto-switch, fallback). + * Supports both registry models and raw model IDs. 
+ */ + async setModel( + newModel: string, + metadata?: ModelSwitchMetadata, + ): Promise { + // Special case: qwen-oauth VLM auto-switch - hot update in place + if ( + this.currentAuthType === AuthType.QWEN_OAUTH && + (newModel === DEFAULT_QWEN_MODEL || newModel === 'vision-model') + ) { + this.strictModelProviderSelection = false; + this._generationConfig.model = newModel; + this.currentModelId = newModel; + this.generationConfigSources['model'] = { + kind: 'programmatic', + detail: metadata?.reason || 'setModel', + }; + return; + } + + // If model exists in registry, use full switch logic + if (this.modelRegistry.hasModel(this.currentAuthType, newModel)) { + await this.switchModel(this.currentAuthType, newModel); + return; + } + + // Raw model override: update generation config in-place + this.strictModelProviderSelection = false; + this._generationConfig.model = newModel; + this.currentModelId = newModel; + this.generationConfigSources['model'] = { + kind: 'programmatic', + detail: metadata?.reason || 'setModel', + }; + } + + /** + * Switch model (and optionally authType) via registry-backed selection. + * This is a superset of the previous split APIs for model-only vs authType+model switching. 
+ */ + async switchModel( + authType: AuthType, + modelId: string, + options?: { requireCachedCredentials?: boolean }, + _metadata?: ModelSwitchMetadata, + ): Promise { + const snapshot = this.snapshotState(); + if (authType === AuthType.QWEN_OAUTH && options?.requireCachedCredentials) { + this.requireCachedQwenCredentialsOnce = true; + } + + try { + const isAuthTypeChange = authType !== this.currentAuthType; + this.currentAuthType = authType; + + const model = this.modelRegistry.getModel(authType, modelId); + if (!model) { + throw new Error( + `Model '${modelId}' not found for authType '${authType}'`, + ); + } + + // Apply model defaults + this.applyResolvedModelDefaults(model); + + // Update selection state + this.currentModelId = modelId; + + const requiresRefresh = isAuthTypeChange + ? true + : this.checkRequiresRefresh(snapshot.currentModelId); + + if (this.onModelChange) { + await this.onModelChange(authType, requiresRefresh); + } + } catch (error) { + // Rollback on error + this.restoreState(snapshot); + throw error; + } + } + + /** + * Get generation config for ContentGenerator creation + */ + getGenerationConfig(): Partial { + return this._generationConfig; + } + + /** + * Get generation config sources for debugging/UI + */ + getGenerationConfigSources(): ContentGeneratorConfigSources { + return this.generationConfigSources; + } + + /** + * Update credentials in generation config. + * Sets a flag to prevent syncAfterAuthRefresh from overriding these credentials. + * + * When credentials are manually set, we clear all provider-sourced configuration + * to maintain provider atomicity (either fully applied or not at all). + * Other layers (CLI, env, settings, defaults) will participate in resolve. 
+ */ + updateCredentials(credentials: { + apiKey?: string; + baseUrl?: string; + model?: string; + }): void { + /** + * If any fields are updated here, we treat the resulting config as manually overridden + * and avoid applying modelProvider defaults during the next auth refresh. + * + * Clear all provider-sourced configuration to maintain provider atomicity. + * This ensures that when user manually sets credentials, the provider config + * is either fully applied (via switchModel) or not at all. + */ + if (credentials.apiKey || credentials.baseUrl || credentials.model) { + this.hasManualCredentials = true; + this.clearProviderSourcedConfig(); + } + + if (credentials.apiKey) { + this._generationConfig.apiKey = credentials.apiKey; + this.generationConfigSources['apiKey'] = { + kind: 'programmatic', + detail: 'updateCredentials', + }; + } + if (credentials.baseUrl) { + this._generationConfig.baseUrl = credentials.baseUrl; + this.generationConfigSources['baseUrl'] = { + kind: 'programmatic', + detail: 'updateCredentials', + }; + } + if (credentials.model) { + this._generationConfig.model = credentials.model; + this.currentModelId = credentials.model; + this.generationConfigSources['model'] = { + kind: 'programmatic', + detail: 'updateCredentials', + }; + } + // When credentials are manually set, disable strict model provider selection + // so validation doesn't require envKey-based credentials + this.strictModelProviderSelection = false; + // Clear apiKeyEnvKey to prevent validation from requiring environment variable + this._generationConfig.apiKeyEnvKey = undefined; + } + + /** + * Clear configuration fields that were sourced from modelProviders. + * This ensures provider config atomicity when user manually sets credentials. + * Other layers (CLI, env, settings, defaults) will participate in resolve. 
+ */ + private clearProviderSourcedConfig(): void { + for (const field of PROVIDER_SOURCED_FIELDS) { + const source = this.generationConfigSources[field]; + if (source?.kind === 'modelProviders') { + // Clear the value - let other layers resolve it + delete (this._generationConfig as Record<string, unknown>)[field]; + delete this.generationConfigSources[field]; + } + } + } + + /** + * Get whether strict model provider selection is enabled + */ + isStrictModelProviderSelection(): boolean { + return this.strictModelProviderSelection; + } + + /** + * Reset strict model provider selection flag + */ + resetStrictModelProviderSelection(): void { + this.strictModelProviderSelection = false; + } + + /** + * Check and consume the one-shot cached credentials flag + */ + consumeRequireCachedCredentialsFlag(): boolean { + const value = this.requireCachedQwenCredentialsOnce; + this.requireCachedQwenCredentialsOnce = false; + return value; + } + + /** + * Apply resolved model config to generation config + */ + private applyResolvedModelDefaults(model: ResolvedModelConfig): void { + this.strictModelProviderSelection = true; + const previousApiKey = this._generationConfig.apiKey; + const previousApiKeyEnvKey = this._generationConfig.apiKeyEnvKey; + const hadManualCredentials = this.hasManualCredentials; + // We're explicitly applying modelProvider defaults now, so manual overrides + // should no longer block syncAfterAuthRefresh from applying provider defaults. + this.hasManualCredentials = false; + + this._generationConfig.model = model.id; + this.generationConfigSources['model'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'model.id', + }; + + // Clear credentials to avoid reusing previous model's API key + + // For Qwen OAuth, apiKey must always be a placeholder. It will be dynamically + // replaced when building requests. Do not preserve any previous key or read + // from envKey.
+ // + // (OpenAI client instantiation requires an apiKey even though it will be + // replaced later.) + if (this.currentAuthType === AuthType.QWEN_OAUTH) { + this._generationConfig.apiKey = 'QWEN_OAUTH_DYNAMIC_TOKEN'; + this.generationConfigSources['apiKey'] = { + kind: 'computed', + detail: 'Qwen OAuth placeholder token', + }; + this._generationConfig.apiKeyEnvKey = undefined; + delete this.generationConfigSources['apiKeyEnvKey']; + } else { + this._generationConfig.apiKey = undefined; + this._generationConfig.apiKeyEnvKey = undefined; + } + + // Read API key from environment variable if envKey is specified + if (model.envKey !== undefined) { + const apiKey = process.env[model.envKey]; + if (apiKey) { + this._generationConfig.apiKey = apiKey; + this.generationConfigSources['apiKey'] = { + kind: 'env', + envKey: model.envKey, + via: { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'envKey', + }, + }; + } else { + // If the user provided an API key via CLI/settings/updateCredentials, keep it. + // We only refuse to reuse a previous key when it is explicitly tied to a + // different envKey (e.g. switching between two configured accounts). + const canPreservePreviousKey = + !!previousApiKey && + (hadManualCredentials || + previousApiKeyEnvKey === undefined || + previousApiKeyEnvKey === model.envKey); + + if (canPreservePreviousKey) { + this._generationConfig.apiKey = previousApiKey; + this.generationConfigSources['apiKey'] = { + kind: 'computed', + detail: `preserved previous apiKey (missing env: ${model.envKey})`, + }; + } else { + console.warn( + `[ModelsConfig] Environment variable '${model.envKey}' is not set for model '${model.id}'. 
` + + `API key will not be available.`, + ); + } + } + this._generationConfig.apiKeyEnvKey = model.envKey; + this.generationConfigSources['apiKeyEnvKey'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'envKey', + }; + } + + // Base URL + this._generationConfig.baseUrl = model.baseUrl; + this.generationConfigSources['baseUrl'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'baseUrl', + }; + + // Generation config + const gc = model.generationConfig; + this._generationConfig.samplingParams = { ...(gc.samplingParams || {}) }; + this.generationConfigSources['samplingParams'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'generationConfig.samplingParams', + }; + + this._generationConfig.timeout = gc.timeout; + this.generationConfigSources['timeout'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'generationConfig.timeout', + }; + + this._generationConfig.maxRetries = gc.maxRetries; + this.generationConfigSources['maxRetries'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'generationConfig.maxRetries', + }; + + this._generationConfig.disableCacheControl = gc.disableCacheControl; + this.generationConfigSources['disableCacheControl'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'generationConfig.disableCacheControl', + }; + + this._generationConfig.schemaCompliance = gc.schemaCompliance; + this.generationConfigSources['schemaCompliance'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'generationConfig.schemaCompliance', + }; + + this._generationConfig.reasoning = gc.reasoning; + this.generationConfigSources['reasoning'] = { + kind: 'modelProviders', + authType: model.authType, + modelId: model.id, + detail: 'generationConfig.reasoning', + }; + } + + /** + * Check if model switch 
requires ContentGenerator refresh. + * + * Note: This method is ONLY called by switchModel() for same-authType model switches. + * Cross-authType switches use switchModel(authType, modelId), which always requires full refresh. + * + * When this method is called: + * - this.currentAuthType is already the target authType + * - We're checking if switching between two models within the SAME authType needs refresh + * + * Examples: + * - Qwen OAuth: coder-model -> vision-model (same authType, hot-update safe) + * - OpenAI: model-a -> model-b with same envKey (same authType, hot-update safe) + * - OpenAI: gpt-4 -> deepseek-chat with different envKey (same authType, needs refresh) + * + * Cross-authType scenarios: + * - OpenAI -> Qwen OAuth: handled by switchModel(authType, modelId), always refreshes + * - Qwen OAuth -> OpenAI: handled by switchModel(authType, modelId), always refreshes + */ + private checkRequiresRefresh(previousModelId: string): boolean { + // For Qwen OAuth, model switches within the same authType can always be hot-updated + // (coder-model <-> vision-model don't require ContentGenerator recreation) + if (this.currentAuthType === AuthType.QWEN_OAUTH) { + return false; + } + + // Get previous and current model configs + const previousModel = this.modelRegistry.getModel( + this.currentAuthType, + previousModelId, + ); + const currentModel = this.modelRegistry.getModel( + this.currentAuthType, + this.currentModelId, + ); + + // If either model is not in registry, require refresh to be safe + if (!previousModel || !currentModel) { + return true; + } + + // Check if critical fields changed that require ContentGenerator recreation + const criticalFieldsChanged = + previousModel.envKey !== currentModel.envKey || + previousModel.baseUrl !== currentModel.baseUrl; + + if (criticalFieldsChanged) { + return true; + } + + // For other auth types with strict model provider selection, + // if no critical fields changed, we can still hot-update + // (e.g., switching 
between two OpenAI models with same envKey and baseUrl) + return false; + } + + /** + * Called by Config.refreshAuth to sync state after auth refresh. + * + * IMPORTANT: If credentials were manually set via updateCredentials(), + * we should NOT override them with modelProvider defaults. + * This handles the case where user inputs credentials via OpenAIKeyPrompt + * after removing environment variables for a previously selected model. + */ + syncAfterAuthRefresh(authType: AuthType, modelId?: string): void { + // Check if we have manually set credentials that should be preserved + const preserveManualCredentials = this.hasManualCredentials; + + // If credentials were manually set, don't apply modelProvider defaults + // Just update the authType and preserve the manually set credentials + if (preserveManualCredentials) { + this.strictModelProviderSelection = false; + this.currentAuthType = authType; + if (modelId) { + this.currentModelId = modelId; + } + return; + } + + this.strictModelProviderSelection = false; + + if (modelId && this.modelRegistry.hasModel(authType, modelId)) { + const resolved = this.modelRegistry.getModel(authType, modelId); + if (resolved) { + this.applyResolvedModelDefaults(resolved); + this.currentAuthType = authType; + this.currentModelId = modelId; + } + } else { + this.currentAuthType = authType; + } + } + + /** + * Update callback for model changes + */ + setOnModelChange(callback: OnModelChangeCallback): void { + this.onModelChange = callback; + } +} diff --git a/packages/core/src/models/types.ts b/packages/core/src/models/types.ts new file mode 100644 index 000000000..b5ce56efa --- /dev/null +++ b/packages/core/src/models/types.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + AuthType, + ContentGeneratorConfig, +} from '../core/contentGenerator.js'; + +/** + * Model capabilities configuration + */ +export interface ModelCapabilities { + /** Supports 
image/vision inputs */ + vision?: boolean; +} + +/** + * Model-scoped generation configuration. + * + * Keep this consistent with {@link ContentGeneratorConfig} so modelProviders can + * feed directly into content generator resolution without shape conversion. + */ +export type ModelGenerationConfig = Pick< + ContentGeneratorConfig, + | 'samplingParams' + | 'timeout' + | 'maxRetries' + | 'disableCacheControl' + | 'schemaCompliance' + | 'reasoning' +>; + +/** + * Model configuration for a single model within an authType + */ +export interface ModelConfig { + /** Unique model ID within authType (e.g., "qwen-coder", "gpt-4-turbo") */ + id: string; + /** Display name (defaults to id) */ + name?: string; + /** Model description */ + description?: string; + /** Environment variable name to read API key from (e.g., "OPENAI_API_KEY") */ + envKey?: string; + /** API endpoint override */ + baseUrl?: string; + /** Model capabilities, reserve for future use. Now we do not read this to determine multi-modal support or other capabilities. 
*/ + capabilities?: ModelCapabilities; + /** Generation configuration (sampling parameters) */ + generationConfig?: ModelGenerationConfig; +} + +/** + * Model providers configuration grouped by authType + */ +export type ModelProvidersConfig = { + [authType: string]: ModelConfig[]; +}; + +/** + * Resolved model config with all defaults applied + */ +export interface ResolvedModelConfig extends ModelConfig { + /** AuthType this model belongs to (always present from map key) */ + authType: AuthType; + /** Display name (always present, defaults to id) */ + name: string; + /** Environment variable name to read API key from (optional, provider-specific) */ + envKey?: string; + /** API base URL (always present, has default per authType) */ + baseUrl: string; + /** Generation config (always present, merged with defaults) */ + generationConfig: ModelGenerationConfig; + /** Capabilities (always present, defaults to {}) */ + capabilities: ModelCapabilities; +} + +/** + * Model info for UI display + */ +export interface AvailableModel { + id: string; + label: string; + description?: string; + capabilities?: ModelCapabilities; + authType: AuthType; + isVision?: boolean; +} + +/** + * Metadata for model switch operations + */ +export interface ModelSwitchMetadata { + /** Reason for the switch */ + reason?: string; + /** Additional context */ + context?: string; +} diff --git a/packages/core/src/subagents/subagent.test.ts b/packages/core/src/subagents/subagent.test.ts index 742813cdb..a23e787ef 100644 --- a/packages/core/src/subagents/subagent.test.ts +++ b/packages/core/src/subagents/subagent.test.ts @@ -22,10 +22,11 @@ import { type Mock, } from 'vitest'; import { Config, type ConfigParameters } from '../config/config.js'; -import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; +import { DEFAULT_QWEN_MODEL } from '../config/models.js'; import { createContentGenerator, createContentGeneratorConfig, + resolveContentGeneratorConfigWithSources, AuthType, } from 
'../core/contentGenerator.js'; import { GeminiChat } from '../core/geminiChat.js'; @@ -42,7 +43,33 @@ import type { import { SubagentTerminateMode } from './types.js'; vi.mock('../core/geminiChat.js'); -vi.mock('../core/contentGenerator.js'); +vi.mock('../core/contentGenerator.js', async (importOriginal) => { + const actual = + await importOriginal<typeof import('../core/contentGenerator.js')>(); + const { DEFAULT_QWEN_MODEL } = await import('../config/models.js'); + return { + ...actual, + createContentGenerator: vi.fn().mockResolvedValue({ + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }), + embedContent: vi.fn(), + useSummarizedThinking: vi.fn().mockReturnValue(false), + }), + createContentGeneratorConfig: vi.fn().mockReturnValue({ + model: DEFAULT_QWEN_MODEL, + authType: actual.AuthType.USE_GEMINI, + }), + resolveContentGeneratorConfigWithSources: vi.fn().mockReturnValue({ + config: { + model: DEFAULT_QWEN_MODEL, + authType: actual.AuthType.USE_GEMINI, + apiKey: 'test-api-key', + }, + sources: {}, + }), + }; +}); vi.mock('../utils/environmentContext.js', () => ({ getEnvironmentContext: vi.fn().mockResolvedValue([{ text: 'Env Context' }]), getInitialChatHistory: vi.fn(async (_config, extraHistory) => [ @@ -65,7 +92,7 @@ async function createMockConfig( toolRegistryMocks = {}, ): Promise<{ config: Config; toolRegistry: ToolRegistry }> { const configParams: ConfigParameters = { - model: DEFAULT_GEMINI_MODEL, + model: DEFAULT_QWEN_MODEL, targetDir: '.', debugMode: false, cwd: process.cwd(), @@ -89,7 +116,7 @@ async function createMockConfig( // Mock getContentGeneratorConfig to return a valid config vi.spyOn(config, 'getContentGeneratorConfig').mockReturnValue({ - model: DEFAULT_GEMINI_MODEL, + model: DEFAULT_QWEN_MODEL, authType: AuthType.USE_GEMINI, }); @@ -192,9 +219,17 @@ describe('subagent.ts', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any } as any); vi.mocked(createContentGeneratorConfig).mockReturnValue({
- model: DEFAULT_GEMINI_MODEL, + model: DEFAULT_QWEN_MODEL, authType: undefined, }); + vi.mocked(resolveContentGeneratorConfigWithSources).mockReturnValue({ + config: { + model: DEFAULT_QWEN_MODEL, + authType: AuthType.USE_GEMINI, + apiKey: 'test-api-key', + }, + sources: {}, + }); mockSendMessageStream = vi.fn(); vi.mocked(GeminiChat).mockImplementation( diff --git a/packages/core/src/utils/configResolver.test.ts b/packages/core/src/utils/configResolver.test.ts new file mode 100644 index 000000000..ee992cd67 --- /dev/null +++ b/packages/core/src/utils/configResolver.test.ts @@ -0,0 +1,141 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { + resolveField, + resolveOptionalField, + layer, + envLayer, + cliSource, + settingsSource, + defaultSource, +} from './configResolver.js'; + +describe('configResolver', () => { + describe('resolveField', () => { + it('returns first present value from layers', () => { + const result = resolveField( + [ + layer(undefined, cliSource('--model')), + envLayer({ MODEL: 'from-env' }, 'MODEL'), + layer('from-settings', settingsSource('model.name')), + ], + 'default-model', + ); + + expect(result.value).toBe('from-env'); + expect(result.source).toEqual({ kind: 'env', envKey: 'MODEL' }); + }); + + it('returns default when all layers are undefined', () => { + const result = resolveField( + [layer(undefined, cliSource('--model')), envLayer({}, 'MODEL')], + 'default-model', + defaultSource('default-model'), + ); + + expect(result.value).toBe('default-model'); + expect(result.source).toEqual({ + kind: 'default', + detail: 'default-model', + }); + }); + + it('respects layer priority order', () => { + const result = resolveField( + [ + layer('cli-value', cliSource('--model')), + envLayer({ MODEL: 'env-value' }, 'MODEL'), + layer('settings-value', settingsSource('model.name')), + ], + 'default', + ); + + expect(result.value).toBe('cli-value'); 
+ expect(result.source.kind).toBe('cli'); + }); + + it('skips empty strings', () => { + const result = resolveField( + [ + layer('', cliSource('--model')), + envLayer({ MODEL: 'env-value' }, 'MODEL'), + ], + 'default', + ); + + expect(result.value).toBe('env-value'); + }); + }); + + describe('resolveOptionalField', () => { + it('returns undefined when no value present', () => { + const result = resolveOptionalField([ + layer(undefined, cliSource('--key')), + envLayer({}, 'KEY'), + ]); + + expect(result).toBeUndefined(); + }); + + it('returns first present value', () => { + const result = resolveOptionalField([ + layer(undefined, cliSource('--key')), + envLayer({ KEY: 'found' }, 'KEY'), + ]); + + expect(result).toBeDefined(); + expect(result!.value).toBe('found'); + expect(result!.source.kind).toBe('env'); + }); + }); + + describe('envLayer', () => { + it('creates layer from environment variable', () => { + const env = { MY_VAR: 'my-value' }; + const result = envLayer(env, 'MY_VAR'); + + expect(result.value).toBe('my-value'); + expect(result.source).toEqual({ kind: 'env', envKey: 'MY_VAR' }); + }); + + it('handles missing environment variable', () => { + const env = {}; + const result = envLayer(env, 'MISSING_VAR'); + + expect(result.value).toBeUndefined(); + expect(result.source).toEqual({ kind: 'env', envKey: 'MISSING_VAR' }); + }); + + it('supports transform function', () => { + const env = { PORT: '3000' }; + const result = envLayer(env, 'PORT', (v) => parseInt(v, 10)); + + expect(result.value).toBe(3000); + }); + }); + + describe('source factory functions', () => { + it('creates CLI source', () => { + expect(cliSource('--model')).toEqual({ kind: 'cli', detail: '--model' }); + }); + + it('creates settings source', () => { + expect(settingsSource('model.name')).toEqual({ + kind: 'settings', + settingsPath: 'model.name', + }); + }); + + it('creates default source', () => { + expect(defaultSource('my-default')).toEqual({ + kind: 'default', + detail: 'my-default', + 
}); + }); + }); +}); diff --git a/packages/core/src/utils/configResolver.ts b/packages/core/src/utils/configResolver.ts new file mode 100644 index 000000000..209052f5a --- /dev/null +++ b/packages/core/src/utils/configResolver.ts @@ -0,0 +1,222 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Generic multi-source configuration resolver utilities. + * + * This module provides reusable tools for resolving configuration values + * from multiple sources (CLI, env, settings, etc.) with priority ordering + * and source tracking. + */ + +/** + * Known source kinds for configuration values. + * Extensible for domain-specific needs. + */ +export type ConfigSourceKind = + | 'cli' + | 'env' + | 'settings' + | 'modelProviders' + | 'default' + | 'computed' + | 'programmatic' + | 'unknown'; + +/** + * Source metadata for a configuration value. + * Tracks where the value came from for debugging and UI display. + */ +export interface ConfigSource { + /** The kind/category of the source */ + kind: ConfigSourceKind; + /** Additional detail about the source (e.g., '--model' for CLI) */ + detail?: string; + /** Environment variable key if kind is 'env' */ + envKey?: string; + /** Settings path if kind is 'settings' (e.g., 'model.name') */ + settingsPath?: string; + /** Auth type if relevant (for modelProviders) */ + authType?: string; + /** Model ID if relevant (for modelProviders) */ + modelId?: string; + /** Indirect source - when a value is derived via another source */ + via?: Omit<ConfigSource, 'via'>; +} + +/** + * Map of field names to their sources + */ +export type ConfigSources = Record<string, ConfigSource>; + +/** + * A configuration layer represents a potential source for a value. + * Layers are evaluated in priority order (first non-undefined wins).
+ */ +export interface ConfigLayer<T> { + /** The value from this layer (undefined means not present) */ + value: T | undefined; + /** Source metadata for this layer */ + source: ConfigSource; +} + +/** + * Result of resolving a single field + */ +export interface ResolvedField<T> { + /** The resolved value */ + value: T; + /** Source metadata indicating where the value came from */ + source: ConfigSource; +} + +/** + * Resolve a single configuration field from multiple layers. + * + * Layers are evaluated in order. The first layer with a defined, + * non-empty value wins. If no layer has a value, the default is used. + * + * @param layers - Configuration layers in priority order (highest first) + * @param defaultValue - Default value if no layer provides one + * @param defaultSource - Source metadata for the default value + * @returns The resolved value and its source + * + * @example + * ```typescript + * const model = resolveField( + * [ + * { value: argv.model, source: { kind: 'cli', detail: '--model' } }, + * { value: env['OPENAI_MODEL'], source: { kind: 'env', envKey: 'OPENAI_MODEL' } }, + * { value: settings.model, source: { kind: 'settings', settingsPath: 'model.name' } }, + * ], + * 'default-model', + * { kind: 'default', detail: 'default-model' } + * ); + * ``` + */ +export function resolveField<T>( + layers: Array<ConfigLayer<T>>, + defaultValue: T, + defaultSource: ConfigSource = { kind: 'default' }, +): ResolvedField<T> { + for (const layer of layers) { + if (isValuePresent(layer.value)) { + return { value: layer.value, source: layer.source }; + } + } + return { value: defaultValue, source: defaultSource }; +} + +/** + * Resolve a field that may not have a default (optional field).
+ * + * @param layers - Configuration layers in priority order + * @returns The resolved value and source, or undefined if not found + */ +export function resolveOptionalField<T>( + layers: Array<ConfigLayer<T>>, +): ResolvedField<T> | undefined { + for (const layer of layers) { + if (isValuePresent(layer.value)) { + return { value: layer.value, source: layer.source }; + } + } + return undefined; +} + +/** + * Check if a value is "present" (not undefined, not null, not empty string). + * + * @param value - The value to check + * @returns true if the value should be considered present + */ +function isValuePresent<T>(value: T | undefined | null): value is T { + if (value === undefined || value === null) { + return false; + } + // Treat empty strings as not present + if (typeof value === 'string' && value.trim() === '') { + return false; + } + return true; +} + +/** + * Create a CLI source descriptor + */ +export function cliSource(detail: string): ConfigSource { + return { kind: 'cli', detail }; +} + +/** + * Create an environment variable source descriptor + */ +function envSource(envKey: string): ConfigSource { + return { kind: 'env', envKey }; +} + +/** + * Create a settings source descriptor + */ +export function settingsSource(settingsPath: string): ConfigSource { + return { kind: 'settings', settingsPath }; +} + +/** + * Create a modelProviders source descriptor + */ +export function modelProvidersSource( + authType: string, + modelId: string, + detail?: string, +): ConfigSource { + return { kind: 'modelProviders', authType, modelId, detail }; +} + +/** + * Create a default value source descriptor + */ +export function defaultSource(detail?: string): ConfigSource { + return { kind: 'default', detail }; +} + +/** + * Create a computed value source descriptor + */ +export function computedSource(detail?: string): ConfigSource { + return { kind: 'computed', detail }; +} + +/** + * Create a layer from an environment variable + */ +export function envLayer<T = string>( + env: Record<string, string | undefined>, + key: string, + 
transform?: (value: string) => T, +): ConfigLayer<T> { + const rawValue = env[key]; + const value = + rawValue !== undefined + ? transform + ? transform(rawValue) + : (rawValue as unknown as T) + : undefined; + return { + value, + source: envSource(key), + }; +} + +/** + * Create a layer with a static value and source + */ +export function layer<T>( + value: T | undefined, + source: ConfigSource, +): ConfigLayer<T> { + return { value, source }; +} diff --git a/packages/core/src/utils/flashFallback.test.ts b/packages/core/src/utils/flashFallback.test.ts deleted file mode 100644 index 184cb2037..000000000 --- a/packages/core/src/utils/flashFallback.test.ts +++ /dev/null @@ -1,75 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { describe, it, expect, beforeEach, vi } from 'vitest'; -import { Config } from '../config/config.js'; -import fs from 'node:fs'; -import { - setSimulate429, - disableSimulationAfterFallback, - shouldSimulate429, - resetRequestCounter, -} from './testUtils.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; -// Import the new types (Assuming this test file is in packages/core/src/utils/) -import type { FallbackModelHandler } from '../fallback/types.js'; - -vi.mock('node:fs'); - -// Update the description to reflect that this tests the retry utility's integration -describe('Retry Utility Fallback Integration', () => { - let config: Config; - - beforeEach(() => { - vi.mocked(fs.existsSync).mockReturnValue(true); - vi.mocked(fs.statSync).mockReturnValue({ - isDirectory: () => true, - } as fs.Stats); - config = new Config({ - targetDir: '/test', - debugMode: false, - cwd: '/test', - model: 'gemini-2.5-pro', - }); - - // Reset simulation state for each test - setSimulate429(false); - resetRequestCounter(); - }); - - // This test validates the Config's ability to store and execute the handler contract.
- it('should execute the injected FallbackHandler contract correctly', async () => { - // Set up a minimal handler for testing, ensuring it matches the new type. - const fallbackHandler: FallbackModelHandler = async () => 'retry'; - - // Use the generalized setter - config.setFallbackModelHandler(fallbackHandler); - - // Call the handler directly via the config property - const result = await config.fallbackModelHandler!( - 'gemini-2.5-pro', - DEFAULT_GEMINI_FLASH_MODEL, - ); - - // Verify it returns the correct intent - expect(result).toBe('retry'); - }); - - // This test validates the test utilities themselves. - it('should properly disable simulation state after fallback (Test Utility)', () => { - // Enable simulation - setSimulate429(true); - - // Verify simulation is enabled - expect(shouldSimulate429()).toBe(true); - - // Disable simulation after fallback - disableSimulationAfterFallback(); - - // Verify simulation is now disabled - expect(shouldSimulate429()).toBe(false); - }); -}); diff --git a/packages/core/src/utils/llm-edit-fixer.ts b/packages/core/src/utils/llm-edit-fixer.ts index 467fdbdb9..bc81c8e62 100644 --- a/packages/core/src/utils/llm-edit-fixer.ts +++ b/packages/core/src/utils/llm-edit-fixer.ts @@ -8,7 +8,7 @@ import { createHash } from 'node:crypto'; import { type Content, Type } from '@google/genai'; import { type BaseLlmClient } from '../core/baseLlmClient.js'; import { LruCache } from './LruCache.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js'; import { promptIdContext } from './promptIdContext.js'; const MAX_CACHE_SIZE = 50; @@ -149,7 +149,7 @@ export async function FixLLMEditWithInstruction( contents, schema: SearchReplaceEditSchema, abortSignal, - model: DEFAULT_GEMINI_FLASH_MODEL, + model: DEFAULT_QWEN_FLASH_MODEL, systemInstruction: EDIT_SYS_PROMPT, promptId, maxAttempts: 1, diff --git a/packages/core/src/utils/summarizer.ts 
b/packages/core/src/utils/summarizer.ts index 14076b5c2..c5290cfa2 100644 --- a/packages/core/src/utils/summarizer.ts +++ b/packages/core/src/utils/summarizer.ts @@ -11,7 +11,7 @@ import type { GenerateContentResponse, } from '@google/genai'; import type { GeminiClient } from '../core/client.js'; -import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js'; +import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js'; import { getResponseText, partToString } from './partUtils.js'; /** @@ -86,7 +86,7 @@ export async function summarizeToolOutput( contents, toolOutputSummarizerConfig, abortSignal, - DEFAULT_GEMINI_FLASH_LITE_MODEL, + DEFAULT_QWEN_FLASH_MODEL, )) as unknown as GenerateContentResponse; return getResponseText(parsedResponse) || textToSummarize; } catch (error) {