diff --git a/packages/cli/src/commands/hooks.tsx b/packages/cli/src/commands/hooks.tsx new file mode 100644 index 000000000..c747c61c2 --- /dev/null +++ b/packages/cli/src/commands/hooks.tsx @@ -0,0 +1,25 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { CommandModule } from 'yargs'; +import { enableCommand } from './hooks/enable.js'; +import { disableCommand } from './hooks/disable.js'; + +export const hooksCommand: CommandModule = { + command: 'hooks ', + aliases: ['hook'], + describe: 'Manage Qwen Code hooks.', + builder: (yargs) => + yargs + .command(enableCommand) + .command(disableCommand) + .demandCommand(1, 'You need at least one command before continuing.') + .version(false), + handler: () => { + // This handler is not called when a subcommand is provided. + // Yargs will show the help menu. + }, +}; diff --git a/packages/cli/src/commands/hooks/disable.ts b/packages/cli/src/commands/hooks/disable.ts new file mode 100644 index 000000000..8d1324cdb --- /dev/null +++ b/packages/cli/src/commands/hooks/disable.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { CommandModule } from 'yargs'; +import { createDebugLogger, getErrorMessage } from '@qwen-code/qwen-code-core'; +import { loadSettings, SettingScope } from '../../config/settings.js'; + +const debugLogger = createDebugLogger('HOOKS_DISABLE'); + +interface DisableArgs { + hookName: string; +} + +/** + * Disable a hook by adding it to the disabled list + */ +export async function handleDisableHook(hookName: string): Promise { + const workingDir = process.cwd(); + const settings = loadSettings(workingDir); + + try { + // Get current hooks settings + const mergedSettings = settings.merged as + | Record + | undefined; + const hooksSettings = (mergedSettings?.['hooks'] || {}) as Record< + string, + unknown + >; + const disabledHooks = (hooksSettings['disabled'] || []) as string[]; + + // Check if hook is already disabled + if (disabledHooks.includes(hookName)) { + debugLogger.info(`Hook "${hookName}" is already disabled.`); + return; + } + + // Add hook to disabled list + const newDisabledHooks = [...disabledHooks, hookName]; + const newHooksSettings = { + ...hooksSettings, + disabled: newDisabledHooks, + }; + + // Save updated settings + settings.setValue( + SettingScope.Workspace, + 'hooks' as keyof typeof settings.merged, + newHooksSettings as never, + ); + + debugLogger.info(`✓ Hook "${hookName}" has been disabled.`); + } catch (error) { + debugLogger.error(`Error disabling hook: ${getErrorMessage(error)}`); + } +} + +export const disableCommand: CommandModule = { + command: 'disable ', + describe: 'Disable an active hook', + builder: (yargs) => + yargs.positional('hook-name', { + describe: 'Name of the hook to disable', + type: 'string', + demandOption: true, + }), + handler: async (argv) => { + const args = argv as unknown as DisableArgs; + await handleDisableHook(args.hookName); + process.exit(0); + }, +}; diff --git a/packages/cli/src/commands/hooks/enable.ts b/packages/cli/src/commands/hooks/enable.ts new file mode 100644 index 000000000..863b5b32c --- /dev/null +++ b/packages/cli/src/commands/hooks/enable.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { CommandModule } from 'yargs'; +import { createDebugLogger, getErrorMessage } from '@qwen-code/qwen-code-core'; +import { loadSettings, SettingScope } from '../../config/settings.js'; + +const debugLogger = createDebugLogger('HOOKS_ENABLE'); + +interface EnableArgs { + hookName: string; +} + +/** + * Enable a hook by removing it from the disabled list + */ +export async function handleEnableHook(hookName: string): Promise { + const workingDir = process.cwd(); + const settings = loadSettings(workingDir); + + try { + // Get current hooks settings + const mergedSettings = settings.merged as + | Record + | undefined; + const hooksSettings = (mergedSettings?.['hooks'] || {}) as Record< + string, + unknown + >; + const disabledHooks = (hooksSettings['disabled'] || []) as string[]; + + // Check if hook is in disabled list + if (!disabledHooks.includes(hookName)) { + debugLogger.info(`Hook "${hookName}" is not disabled.`); + return; + } + + // Remove hook from disabled list + const newDisabledHooks = disabledHooks.filter((h) => h !== hookName); + const newHooksSettings = { + ...hooksSettings, + disabled: newDisabledHooks, + }; + + // Save updated settings + settings.setValue( + SettingScope.Workspace, + 'hooks' as keyof typeof settings.merged, + newHooksSettings as never, + ); + + debugLogger.info(`✓ Hook "${hookName}" has been enabled.`); + } catch (error) { + debugLogger.error(`Error enabling hook: ${getErrorMessage(error)}`); + } +} + +export const enableCommand: CommandModule = { + command: 'enable ', + describe: 'Enable a disabled hook', + builder: (yargs) => + yargs.positional('hook-name', { + describe: 'Name of the hook to enable', + type: 'string', + demandOption: true, + }), + handler: async (argv) => { + const args = argv as unknown as EnableArgs; + await handleEnableHook(args.hookName); + process.exit(0); + }, +}; diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index c31ffa216..2805c32a2 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -33,6 +33,7 @@ import { NativeLspService, } from '@qwen-code/qwen-code-core'; import { extensionsCommand } from '../commands/extensions.js'; +import { hooksCommand } from '../commands/hooks.js'; import type { Settings } from './settings.js'; import { resolveCliGenerationConfig, @@ -569,7 +570,9 @@ export async function parseArguments(): Promise { // Register MCP subcommands .command(mcpCommand) // Register Extension subcommands - .command(extensionsCommand); + .command(extensionsCommand) + // Register Hooks subcommands + .command(hooksCommand); yargsInstance .version(await getCliVersion()) // This will enable the --version flag based on package.json @@ -588,9 +591,11 @@ export async function parseArguments(): Promise { // and not return to main CLI logic if ( result._.length > 0 && - (result._[0] === 'mcp' || result._[0] === 'extensions') + (result._[0] === 'mcp' || + result._[0] === 'extensions' || + result._[0] === 'hooks') ) { - // MCP commands handle their own execution and process exit + // MCP/Extensions/Hooks commands handle their own execution and process exit process.exit(0); } @@ -1025,6 +1030,7 @@ export async function loadCliConfig( output: { format: outputSettingsFormat, }, + hooks: settings.hooks, channel: argv.channel, // Precedence: explicit CLI flag > settings file > default(true). // NOTE: do NOT set a yargs default for `chat-recording`, otherwise argv will diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 283baee26..87a521e75 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -1177,6 +1177,118 @@ const SETTINGS_SCHEMA = { showInDialog: false, }, + hooks: { + type: 'object', + label: 'Hooks', + category: 'Advanced', + requiresRestart: false, + default: {}, + description: + 'Hook configurations for extending CLI behavior at various lifecycle points.', + showInDialog: false, + properties: { + disabled: { + type: 'array', + label: 'Disabled Hooks', + category: 'Advanced', + requiresRestart: false, + default: [] as string[], + description: + 'List of hook names to disable. Hooks in this list will not be executed.', + showInDialog: false, + mergeStrategy: MergeStrategy.UNION, + }, + PreToolUse: { + type: 'array', + label: 'PreTool Use Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute before tool invocations. Can validate, modify, or block tool calls.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + PostToolUse: { + type: 'array', + label: 'PostTool Use Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute after tool invocations. Can process results or trigger follow-up actions.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + BeforeAgent: { + type: 'array', + label: 'Before Agent Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute before agent processing. Can modify prompts or inject context.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + AfterAgent: { + type: 'array', + label: 'After Agent Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute after agent processing. Can post-process responses or log interactions.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + SessionStart: { + type: 'array', + label: 'Session Start Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute when a session starts. Can initialize state or load context.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + SessionEnd: { + type: 'array', + label: 'Session End Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute when a session ends. Can perform cleanup or persist session data.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + PreCompact: { + type: 'array', + label: 'PreCompact Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute before chat history compression. Can back up or analyze conversation before compression.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + Notification: { + type: 'array', + label: 'Notification Hooks', + category: 'Advanced', + requiresRestart: false, + default: [], + description: + 'Hooks that execute when notifications are triggered. Can handle alerts or status updates.', + showInDialog: false, + mergeStrategy: MergeStrategy.CONCAT, + }, + }, + }, + experimental: { type: 'object', label: 'Experimental', diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index dc4c1f8d9..9b2983be3 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -21,6 +21,7 @@ import { editorCommand } from '../ui/commands/editorCommand.js'; import { exportCommand } from '../ui/commands/exportCommand.js'; import { extensionsCommand } from '../ui/commands/extensionsCommand.js'; import { helpCommand } from '../ui/commands/helpCommand.js'; +import { hooksCommand } from '../ui/commands/hooksCommand.js'; import { ideCommand } from '../ui/commands/ideCommand.js'; import { initCommand } from '../ui/commands/initCommand.js'; import { languageCommand } from '../ui/commands/languageCommand.js'; @@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader { exportCommand, extensionsCommand, helpCommand, + hooksCommand, await ideCommand(), initCommand, languageCommand, diff --git a/packages/cli/src/ui/commands/hooksCommand.ts b/packages/cli/src/ui/commands/hooksCommand.ts new file mode 100644 index 000000000..926b01a95 --- /dev/null +++ b/packages/cli/src/ui/commands/hooksCommand.ts @@ -0,0 +1,320 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + SlashCommand, + SlashCommandActionReturn, + CommandContext, + MessageActionReturn, +} from './types.js'; +import { CommandKind } from './types.js'; +import { t } from '../../i18n/index.js'; +import type { HookRegistryEntry } from '@qwen-code/qwen-code-core'; + +/** + * Format hook source for display + */ +function formatHookSource(source: string): string { + switch (source) { + case 'project': + return 'Project'; + case 'user': + return 'User'; + case 'system': + return 'System'; + case 'extensions': + return 'Extension'; + default: + return source; + } +} + +/** + * Format hook status for display + */ +function formatHookStatus(enabled: boolean): string { + return enabled ? '✓ Enabled' : '✗ Disabled'; +} + +const listCommand: SlashCommand = { + name: 'list', + get description() { + return t('List all configured hooks'); + }, + kind: CommandKind.BUILT_IN, + action: async ( + context: CommandContext, + _args: string, + ): Promise => { + const { config } = context.services; + if (!config) { + return { + type: 'message', + messageType: 'error', + content: t('Config not loaded.'), + }; + } + + const hookSystem = config.getHookSystem(); + if (!hookSystem) { + return { + type: 'message', + messageType: 'info', + content: t( + 'Hooks are not enabled. Enable hooks in settings to use this feature.', + ), + }; + } + + const registry = hookSystem.getRegistry(); + const allHooks = registry.getAllHooks(); + + if (allHooks.length === 0) { + return { + type: 'message', + messageType: 'info', + content: t( + 'No hooks configured. Add hooks in your settings.json file.', + ), + }; + } + + // Group hooks by event + const hooksByEvent = new Map(); + for (const hook of allHooks) { + const eventName = hook.eventName; + if (!hooksByEvent.has(eventName)) { + hooksByEvent.set(eventName, []); + } + hooksByEvent.get(eventName)!.push(hook); + } + + let output = `**Configured Hooks (${allHooks.length} total)**\n\n`; + + for (const [eventName, hooks] of hooksByEvent) { + output += `### ${eventName}\n`; + for (const hook of hooks) { + const name = hook.config.name || hook.config.command || 'unnamed'; + const source = formatHookSource(hook.source); + const status = formatHookStatus(hook.enabled); + const matcher = hook.matcher ? ` (matcher: ${hook.matcher})` : ''; + output += `- **${name}** [${source}] ${status}${matcher}\n`; + } + output += '\n'; + } + + return { + type: 'message', + messageType: 'info', + content: output, + }; + }, +}; + +const enableCommand: SlashCommand = { + name: 'enable', + get description() { + return t('Enable a disabled hook'); + }, + kind: CommandKind.BUILT_IN, + action: async ( + context: CommandContext, + args: string, + ): Promise => { + const hookName = args.trim(); + if (!hookName) { + return { + type: 'message', + messageType: 'error', + content: t( + 'Please specify a hook name. Usage: /hooks enable ', + ), + }; + } + + const { config } = context.services; + if (!config) { + return { + type: 'message', + messageType: 'error', + content: t('Config not loaded.'), + }; + } + + const hookSystem = config.getHookSystem(); + if (!hookSystem) { + return { + type: 'message', + messageType: 'error', + content: t('Hooks are not enabled.'), + }; + } + + const registry = hookSystem.getRegistry(); + registry.setHookEnabled(hookName, true); + + return { + type: 'message', + messageType: 'info', + content: t('Hook "{{name}}" has been enabled for this session.', { + name: hookName, + }), + }; + }, + completion: async (context: CommandContext, partialArg: string) => { + const { config } = context.services; + if (!config) return []; + + const hookSystem = config.getHookSystem(); + if (!hookSystem) return []; + + const registry = hookSystem.getRegistry(); + const allHooks = registry.getAllHooks(); + + // Return disabled hooks for enable command + return allHooks + .filter((hook) => !hook.enabled) + .map((hook) => hook.config.name || hook.config.command || '') + .filter((name) => name && name.startsWith(partialArg)); + }, +}; + +const disableCommand: SlashCommand = { + name: 'disable', + get description() { + return t('Disable an active hook'); + }, + kind: CommandKind.BUILT_IN, + action: async ( + context: CommandContext, + args: string, + ): Promise => { + const hookName = args.trim(); + if (!hookName) { + return { + type: 'message', + messageType: 'error', + content: t( + 'Please specify a hook name. Usage: /hooks disable ', + ), + }; + } + + const { config } = context.services; + if (!config) { + return { + type: 'message', + messageType: 'error', + content: t('Config not loaded.'), + }; + } + + const hookSystem = config.getHookSystem(); + if (!hookSystem) { + return { + type: 'message', + messageType: 'error', + content: t('Hooks are not enabled.'), + }; + } + + const registry = hookSystem.getRegistry(); + registry.setHookEnabled(hookName, false); + + return { + type: 'message', + messageType: 'info', + content: t('Hook "{{name}}" has been disabled for this session.', { + name: hookName, + }), + }; + }, + completion: async (context: CommandContext, partialArg: string) => { + const { config } = context.services; + if (!config) return []; + + const hookSystem = config.getHookSystem(); + if (!hookSystem) return []; + + const registry = hookSystem.getRegistry(); + const allHooks = registry.getAllHooks(); + + // Return enabled hooks for disable command + return allHooks + .filter((hook) => hook.enabled) + .map((hook) => hook.config.name || hook.config.command || '') + .filter((name) => name && name.startsWith(partialArg)); + }, +}; + +export const hooksCommand: SlashCommand = { + name: 'hooks', + get description() { + return t('Manage Qwen Code hooks'); + }, + kind: CommandKind.BUILT_IN, + subCommands: [listCommand, enableCommand, disableCommand], + action: async ( + context: CommandContext, + args: string, + ): Promise => { + // If no subcommand provided, show list + if (!args.trim()) { + const result = await listCommand.action?.(context, ''); + return result ?? { type: 'message', messageType: 'info', content: '' }; + } + + const [subcommand, ...rest] = args.trim().split(/\s+/); + const subArgs = rest.join(' '); + + let result: SlashCommandActionReturn | void; + switch (subcommand.toLowerCase()) { + case 'list': + result = await listCommand.action?.(context, subArgs); + break; + case 'enable': + result = await enableCommand.action?.(context, subArgs); + break; + case 'disable': + result = await disableCommand.action?.(context, subArgs); + break; + default: + return { + type: 'message', + messageType: 'error', + content: t( + 'Unknown subcommand: {{cmd}}. Available: list, enable, disable', + { + cmd: subcommand, + }, + ), + }; + } + return result ?? { type: 'message', messageType: 'info', content: '' }; + }, + completion: async (context: CommandContext, partialArg: string) => { + const subcommands = ['list', 'enable', 'disable']; + const parts = partialArg.split(/\s+/); + + if (parts.length <= 1) { + // Complete subcommand + return subcommands.filter((cmd) => cmd.startsWith(partialArg)); + } + + // Complete subcommand arguments + const [subcommand, ...rest] = parts; + const subArgs = rest.join(' '); + + switch (subcommand.toLowerCase()) { + case 'enable': + return enableCommand.completion?.(context, subArgs) ?? []; + case 'disable': + return disableCommand.completion?.(context, subArgs) ?? []; + default: + return []; + } + }, +}; diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 5bebbac7e..04006fabc 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -945,6 +945,15 @@ export const useGeminiStream = ( clearRetryCountdown(); } break; + case ServerGeminiEventType.HookSystemMessage: + // Display system message from hooks (e.g., Ralph Loop iteration info) + // This is handled as a content event to show in the UI + geminiMessageBuffer = handleContentEvent( + event.value + '\n', + geminiMessageBuffer, + userMessageTimestamp, + ); + break; default: { // enforces exhaustive switch-case const unreachable: never = event; diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index e1598a641..54a14b4bd 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -84,6 +84,14 @@ import { ExtensionManager, type Extension, } from '../extension/extensionManager.js'; +import { HookSystem } from '../hooks/index.js'; +import { MessageBus } from '../confirmation-bus/message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import { + MessageBusType, + type HookExecutionRequest, + type HookExecutionResponse, +} from '../confirmation-bus/types.js'; // Utils import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; @@ -378,6 +386,10 @@ export interface ConfigParameters { channel?: string; /** Model providers configuration grouped by authType */ modelProvidersConfig?: ModelProvidersConfig; + /** Enable hook system for lifecycle events */ + enableHooks?: boolean; + /** Hooks configuration from settings */ + hooks?: Record; } function normalizeConfigOutputFormat( @@ -518,6 +530,11 @@ export class Config { private readonly eventEmitter?: EventEmitter; private readonly channel: string | undefined; private readonly defaultFileEncoding: FileEncodingType; + private readonly enableHooks: boolean; + private readonly hooks?: Record; + private hookSystem?: HookSystem; + private messageBus?: MessageBus; + private policyEngine?: PolicyEngine; constructor(params: ConfigParameters) { this.sessionId = params.sessionId ?? randomUUID(); @@ -672,6 +689,8 @@ export class Config { enabledExtensionOverrides: this.overrideExtensions, isWorkspaceTrusted: this.isTrustedFolder(), }); + this.enableHooks = params.enableHooks ?? true; + this.hooks = params.hooks; } /** @@ -695,6 +714,77 @@ export class Config { await this.extensionManager.refreshCache(); this.debugLogger.debug('Extension manager initialized'); + // Initialize hook system if enabled + if (this.enableHooks) { + this.hookSystem = new HookSystem(this); + await this.hookSystem.initialize(); + this.debugLogger.debug('Hook system initialized'); + + // Initialize PolicyEngine and MessageBus for hook execution + this.policyEngine = new PolicyEngine(); + this.messageBus = new MessageBus(this.policyEngine); + + // Subscribe to HOOK_EXECUTION_REQUEST to execute hooks + this.messageBus.subscribe( + MessageBusType.HOOK_EXECUTION_REQUEST, + async (request: HookExecutionRequest) => { + try { + const hookSystem = this.hookSystem; + if (!hookSystem) { + this.messageBus?.publish({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: request.correlationId, + success: false, + error: new Error('Hook system not initialized'), + } as HookExecutionResponse); + return; + } + + // Execute the appropriate hook based on eventName + let result; + const input = request.input || {}; + switch (request.eventName) { + case 'UserPromptSubmit': + result = await hookSystem.fireUserPromptSubmitEvent( + (input['prompt'] as string) || '', + ); + break; + case 'Stop': + result = await hookSystem.fireStopEvent( + (input['prompt'] as string) || '', + (input['prompt_response'] as string) || '', + (input['stop_hook_active'] as boolean) || false, + ); + break; + default: + this.debugLogger.warn( + `Unknown hook event: ${request.eventName}`, + ); + result = undefined; + } + + // Send response + this.messageBus?.publish({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: request.correlationId, + success: true, + output: result, + } as HookExecutionResponse); + } catch (error) { + this.debugLogger.warn(`Hook execution failed: ${error}`); + this.messageBus?.publish({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: request.correlationId, + success: false, + error: error instanceof Error ? error : new Error(String(error)), + } as HookExecutionResponse); + } + }, + ); + + this.debugLogger.debug('MessageBus initialized with hook subscription'); + } + this.subagentManager = new SubagentManager(this); this.skillManager = new SkillManager(this); await this.skillManager.startWatching(); @@ -1374,6 +1464,81 @@ export class Config { return this.extensionManager; } + /** + * Get the hook system instance if hooks are enabled. + * Returns undefined if hooks are not enabled. + */ + getHookSystem(): HookSystem | undefined { + return this.hookSystem; + } + + /** + * Check if hooks are enabled. + */ + getEnableHooks(): boolean { + return this.enableHooks; + } + + /** + * Get the message bus instance. + * Returns undefined if not set. + */ + getMessageBus(): MessageBus | undefined { + return this.messageBus; + } + + /** + * Set the message bus instance. + * This is called by the CLI layer to inject the MessageBus. + */ + setMessageBus(messageBus: MessageBus): void { + this.messageBus = messageBus; + } + + /** + * Get the policy engine instance. + * Returns undefined if not set. + */ + getPolicyEngine(): PolicyEngine | undefined { + return this.policyEngine; + } + + /** + * Set the policy engine instance. + * This is called by the CLI layer to inject the PolicyEngine. + */ + setPolicyEngine(policyEngine: PolicyEngine): void { + this.policyEngine = policyEngine; + } + + /** + * Get the list of disabled hook names. + * This is used by the HookRegistry to filter out disabled hooks. + */ + getDisabledHooks(): string[] { + // This will be populated from settings by the CLI layer + // The core Config doesn't have direct access to settings + return []; + } + + /** + * Get project-level hooks configuration. + * This is used by the HookRegistry to load project-specific hooks. + */ + getProjectHooks(): Record | undefined { + // This will be populated from settings by the CLI layer + // The core Config doesn't have direct access to settings + return undefined; + } + + /** + * Get all hooks configuration (merged from all sources). + * This is used by the HookRegistry to load hooks. + */ + getHooks(): Record | undefined { + return this.hooks; + } + getExtensions(): Extension[] { const extensions = this.extensionManager.getLoadedExtensions(); if (this.overrideExtensions) { @@ -1614,6 +1779,21 @@ export class Config { return this.chatRecordingService; } + /** + * Returns the transcript file path for the current session. + * This is the path to the JSONL file where the conversation is recorded. + * Returns empty string if chat recording is disabled. + */ + getTranscriptPath(): string { + if (!this.chatRecordingEnabled) { + return ''; + } + const projectDir = this.storage.getProjectDir(); + const sessionId = this.getSessionId(); + const safeFilename = `${sessionId}.jsonl`; + return path.join(projectDir, 'chats', safeFilename); + } + /** * Gets or creates a SessionService for managing chat sessions. */ diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts new file mode 100644 index 000000000..235ef53d6 --- /dev/null +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -0,0 +1,206 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { randomUUID } from 'node:crypto'; +import { EventEmitter } from 'node:events'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import { PolicyDecision, getHookSource } from '../policy/types.js'; +import { + MessageBusType, + type Message, + type HookExecutionRequest, + type HookPolicyDecision, +} from './types.js'; +import { safeJsonStringify } from '../utils/safeJsonStringify.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; + +const debugLogger = createDebugLogger('TRUSTED_HOOKS'); + +export class MessageBus extends EventEmitter { + constructor( + private readonly policyEngine: PolicyEngine, + private readonly debug = false, + ) { + super(); + this.debug = debug; + } + + private isValidMessage(message: Message): boolean { + if (!message || !message.type) { + return false; + } + + if ( + message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST && + !('correlationId' in message) + ) { + return false; + } + + return true; + } + + private emitMessage(message: Message): void { + this.emit(message.type, message); + } + + async publish(message: Message): Promise { + if (this.debug) { + debugLogger.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`); + } + try { + if (!this.isValidMessage(message)) { + throw new Error( + `Invalid message structure: ${safeJsonStringify(message)}`, + ); + } + + if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + const { decision } = await this.policyEngine.check( + message.toolCall, + message.serverName, + ); + + switch (decision) { + case PolicyDecision.ALLOW: + // Directly emit the response instead of recursive publish + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: true, + }); + break; + case PolicyDecision.DENY: + // Emit both rejection and response messages + this.emitMessage({ + type: MessageBusType.TOOL_POLICY_REJECTION, + toolCall: message.toolCall, + }); + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: false, + }); + break; + case PolicyDecision.ASK_USER: + // Pass through to UI for user confirmation if any listeners exist. + // If no listeners are registered (e.g., headless/ACP flows), + // immediately request user confirmation to avoid long timeouts. + if ( + this.listenerCount(MessageBusType.TOOL_CONFIRMATION_REQUEST) > 0 + ) { + this.emitMessage(message); + } else { + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: false, + requiresUserConfirmation: true, + }); + } + break; + default: + throw new Error(`Unknown policy decision: ${decision}`); + } + } else if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) { + // Handle hook execution requests through policy evaluation + const hookRequest = message as HookExecutionRequest; + const decision = await this.policyEngine.checkHook(hookRequest); + + // Map decision to allow/deny for observability (ASK_USER treated as deny for hooks) + const effectiveDecision = + decision === PolicyDecision.ALLOW ? 'allow' : 'deny'; + + // Emit policy decision for observability + this.emitMessage({ + type: MessageBusType.HOOK_POLICY_DECISION, + eventName: hookRequest.eventName, + hookSource: getHookSource(hookRequest.input), + decision: effectiveDecision, + reason: + decision !== PolicyDecision.ALLOW + ? 'Hook execution denied by policy' + : undefined, + } as HookPolicyDecision); + + // If allowed, emit the request for hook system to handle + if (decision === PolicyDecision.ALLOW) { + this.emitMessage(message); + } else { + // If denied or ASK_USER, emit error response (hooks don't support interactive confirmation) + this.emitMessage({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: hookRequest.correlationId, + success: false, + error: new Error('Hook execution denied by policy'), + }); + } + } else { + // For all other message types, just emit them + this.emitMessage(message); + } + } catch (error) { + this.emit('error', error); + } + } + + subscribe( + type: T['type'], + listener: (message: T) => void, + ): void { + this.on(type, listener); + } + + unsubscribe( + type: T['type'], + listener: (message: T) => void, + ): void { + this.off(type, listener); + } + + /** + * Request-response pattern: Publish a message and wait for a correlated response + * This enables synchronous-style communication over the async MessageBus + * The correlation ID is generated internally and added to the request + */ + async request( + request: Omit, + responseType: TResponse['type'], + timeoutMs: number = 60000, + ): Promise { + const correlationId = randomUUID(); + + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + cleanup(); + reject(new Error(`Request timed out waiting for ${responseType}`)); + }, timeoutMs); + + const cleanup = () => { + clearTimeout(timeoutId); + this.unsubscribe(responseType, responseHandler); + }; + + const responseHandler = (response: TResponse) => { + // Check if this response matches our request + if ( + 'correlationId' in response && + response.correlationId === correlationId + ) { + cleanup(); + resolve(response); + } + }; + + // Subscribe to responses + this.subscribe(responseType, responseHandler); + + // Publish the request with correlation ID + + this.publish({ ...request, correlationId } as TRequest); + }); + } +} diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts new file mode 100644 index 000000000..824fdd4d7 --- /dev/null +++ b/packages/core/src/confirmation-bus/types.ts @@ -0,0 +1,212 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionCall } from '@google/genai'; +import type { + ToolConfirmationOutcome, + ToolConfirmationPayload, +} from '../tools/tools.js'; +import type { ToolCall } from '../core/coreToolScheduler.js'; + +export enum MessageBusType { + TOOL_CONFIRMATION_REQUEST = 'tool-confirmation-request', + TOOL_CONFIRMATION_RESPONSE = 'tool-confirmation-response', + TOOL_POLICY_REJECTION = 'tool-policy-rejection', + TOOL_EXECUTION_SUCCESS = 'tool-execution-success', + TOOL_EXECUTION_FAILURE = 'tool-execution-failure', + UPDATE_POLICY = 'update-policy', + TOOL_CALLS_UPDATE = 'tool-calls-update', + ASK_USER_REQUEST = 'ask-user-request', + ASK_USER_RESPONSE = 'ask-user-response', + HOOK_EXECUTION_REQUEST = 'hook-execution-request', + HOOK_EXECUTION_RESPONSE = 'hook-execution-response', + HOOK_POLICY_DECISION = 'hook-policy-decision', +} + +export interface ToolCallsUpdateMessage { + type: MessageBusType.TOOL_CALLS_UPDATE; + toolCalls: ToolCall[]; + schedulerId: string; +} + +export interface ToolConfirmationRequest { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST; + toolCall: FunctionCall; + correlationId: string; + serverName?: string; + /** + * Optional rich details for the confirmation UI (diffs, counts, etc.) + */ + details?: SerializableConfirmationDetails; +} + +export interface ToolConfirmationResponse { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE; + correlationId: string; + confirmed: boolean; + /** + * The specific outcome selected by the user. + * + * TODO: Make required after migration. + */ + outcome?: ToolConfirmationOutcome; + /** + * Optional payload (e.g., modified content for 'modify_with_editor'). + */ + payload?: ToolConfirmationPayload; + /** + * When true, indicates that policy decision was ASK_USER and the tool should + * show its legacy confirmation UI instead of auto-proceeding. + */ + requiresUserConfirmation?: boolean; +} + +/** + * Data-only versions of ToolCallConfirmationDetails for bus transmission. + */ +export type SerializableConfirmationDetails = + | { + type: 'info'; + title: string; + prompt: string; + urls?: string[]; + } + | { + type: 'edit'; + title: string; + fileName: string; + filePath: string; + fileDiff: string; + originalContent: string | null; + newContent: string; + isModifying?: boolean; + } + | { + type: 'exec'; + title: string; + command: string; + rootCommand: string; + rootCommands: string[]; + commands?: string[]; + } + | { + type: 'mcp'; + title: string; + serverName: string; + toolName: string; + toolDisplayName: string; + } + | { + type: 'ask_user'; + title: string; + questions: Question[]; + } + | { + type: 'exit_plan_mode'; + title: string; + planPath: string; + }; + +export interface UpdatePolicy { + type: MessageBusType.UPDATE_POLICY; + toolName: string; + persist?: boolean; + argsPattern?: string; + commandPrefix?: string | string[]; + mcpName?: string; +} + +export interface ToolPolicyRejection { + type: MessageBusType.TOOL_POLICY_REJECTION; + toolCall: FunctionCall; +} + +export interface ToolExecutionSuccess { + type: MessageBusType.TOOL_EXECUTION_SUCCESS; + toolCall: FunctionCall; + result: T; +} + +export interface ToolExecutionFailure { + type: MessageBusType.TOOL_EXECUTION_FAILURE; + toolCall: FunctionCall; + error: E; +} + +export interface HookExecutionRequest { + type: MessageBusType.HOOK_EXECUTION_REQUEST; + eventName: string; + input: Record; + correlationId: string; +} + +export interface HookExecutionResponse { + type: MessageBusType.HOOK_EXECUTION_RESPONSE; + correlationId: string; + success: boolean; + output?: Record; + error?: Error; +} + +export interface HookPolicyDecision { + type: MessageBusType.HOOK_POLICY_DECISION; + eventName: string; + hookSource: 'project' | 'user' | 'system' | 'extension'; + decision: 'allow' | 'deny'; + reason?: string; +} + +export interface QuestionOption { + label: string; + description: string; +} + +export enum QuestionType { + CHOICE = 'choice', + TEXT = 'text', + YESNO = 'yesno', +} + +export interface Question { + question: string; + header: string; + /** Question type: 'choice' renders selectable options, 'text' renders free-form input, 'yesno' renders a binary Yes/No choice. */ + type: QuestionType; + /** Selectable choices. REQUIRED when type='choice'. IGNORED for 'text' and 'yesno'. */ + options?: QuestionOption[]; + /** Allow multiple selections. Only applies when type='choice'. */ + multiSelect?: boolean; + /** Placeholder hint text. For type='text', shown in the input field. For type='choice', shown in the "Other" custom input. */ + placeholder?: string; +} + +export interface AskUserRequest { + type: MessageBusType.ASK_USER_REQUEST; + questions: Question[]; + correlationId: string; +} + +export interface AskUserResponse { + type: MessageBusType.ASK_USER_RESPONSE; + correlationId: string; + answers: { [questionIndex: string]: string }; + /** When true, indicates the user cancelled the dialog without submitting answers */ + cancelled?: boolean; +} + +export type Message = + | ToolConfirmationRequest + | ToolConfirmationResponse + | ToolPolicyRejection + | ToolExecutionSuccess + | ToolExecutionFailure + | UpdatePolicy + | AskUserRequest + | AskUserResponse + | ToolCallsUpdateMessage + | HookExecutionRequest + | HookExecutionResponse + | HookPolicyDecision; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 9f3625c38..4d77c0626 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -69,6 +69,12 @@ import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; import { flatMapTextParts } from '../utils/partUtils.js'; import { retryWithBackoff } from '../utils/retry.js'; +// Hook triggers +import { + fireUserPromptSubmitHook, + fireStopHook, +} from './clientHookTriggers.js'; + // IDE integration import { ideContextStore } from '../ide/ideContext.js'; import { type File, type IdeContext } from '../ide/types.js'; @@ -407,6 +413,35 @@ export class GeminiClient { options?: { isContinuation: boolean }, turns: number = MAX_TURNS, ): AsyncGenerator { + // Fire BeforeAgent hook through MessageBus (only if hooks are enabled) + const hooksEnabled = this.config.getEnableHooks(); + const messageBus = this.config.getMessageBus(); + if (hooksEnabled && messageBus) { + const hookOutput = await fireUserPromptSubmitHook(messageBus, request); + + if ( + hookOutput?.isBlockingDecision() || + hookOutput?.shouldStopExecution() + ) { + yield { + type: GeminiEventType.Error, + value: { + error: new Error( + `BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`, + ), + }, + }; + return new Turn(this.getChat(), prompt_id); + } + + // Add additional context from hooks to the request + const additionalContext = hookOutput?.getAdditionalContext(); + if (additionalContext) { + const requestArray = Array.isArray(request) ? request : [request]; + request = [...requestArray, { text: additionalContext }]; + } + } + if (!options?.isContinuation) { this.loopDetector.reset(prompt_id); this.lastPromptId = prompt_id; @@ -536,6 +571,50 @@ export class GeminiClient { return turn; } } + // Fire AfterAgent hook through MessageBus (only if hooks are enabled) + // This must be done before any early returns to ensure hooks are always triggered + if (hooksEnabled && messageBus && !turn.pendingToolCalls.length) { + // Get response text from the chat history + const history = this.getHistory(); + const lastModelMessage = history + .filter((msg) => msg.role === 'model') + .pop(); + const responseText = + lastModelMessage?.parts + ?.filter((p): p is { text: string } => 'text' in p) + .map((p) => p.text) + .join('') || '[no response text]'; + + const hookOutput = await fireStopHook(messageBus, request, responseText); + + // For AfterAgent hooks, blocking/stop execution should force continuation (like Stop Hook) + // This enables Ralph Loop functionality where the hook can: + // 1. Return {"decision": "block", "reason": ""} to continue with a new prompt + // 2. Optionally include "systemMessage" to display a status message + if ( + hookOutput?.isBlockingDecision() || + hookOutput?.shouldStopExecution() + ) { + // Emit system message if provided (e.g., "🔄 Ralph iteration 5") + if (hookOutput.systemMessage) { + yield { + type: GeminiEventType.HookSystemMessage, + value: hookOutput.systemMessage, + }; + } + + const continueReason = hookOutput.getEffectiveReason(); + const continueRequest = [{ text: continueReason }]; + return yield* this.sendMessageStream( + continueRequest, + signal, + prompt_id, + { isContinuation: true }, + boundedTurns - 1, + ); + } + } + if (!turn.pendingToolCalls.length && signal && !signal.aborted) { if (this.config.getSkipNextSpeakerCheck()) { return turn; @@ -557,9 +636,9 @@ export class GeminiClient { ); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; - // This recursive call's events will be yielded out, but the final - // turn object will be from the top-level call. - yield* this.sendMessageStream( + // This recursive call's events will be yielded out, and the final + // turn object from the recursive call will be returned. + return yield* this.sendMessageStream( nextRequest, signal, prompt_id, @@ -568,6 +647,7 @@ export class GeminiClient { ); } } + return turn; } diff --git a/packages/core/src/core/clientHookTriggers.ts b/packages/core/src/core/clientHookTriggers.ts new file mode 100644 index 000000000..02fce7621 --- /dev/null +++ b/packages/core/src/core/clientHookTriggers.ts @@ -0,0 +1,107 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { PartListUnion } from '@google/genai'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + MessageBusType, + type HookExecutionRequest, + type HookExecutionResponse, +} from '../confirmation-bus/types.js'; +import { createHookOutput, type DefaultHookOutput } from '../hooks/types.js'; +import { partToString } from '../utils/partUtils.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; + +const debugLogger = createDebugLogger('HOOK_TRIGGERS'); + +/** + * Fires the UserPromptSubmit hook and returns the hook output. + * This should be called before processing a user prompt. + * + * The caller can use the returned DefaultHookOutput methods: + * - isBlockingDecision() / shouldStopExecution() to check if blocked + * - getEffectiveReason() to get the blocking reason + * - getAdditionalContext() to get additional context to add + * + * @param messageBus The message bus to use for hook communication + * @param request The user's request (prompt) + * @returns The hook output, or undefined if no hook was executed or on error + */ +export async function fireUserPromptSubmitHook( + messageBus: MessageBus, + request: PartListUnion, +): Promise { + try { + const promptText = partToString(request); + + const response = await messageBus.request< + HookExecutionRequest, + HookExecutionResponse + >( + { + type: MessageBusType.HOOK_EXECUTION_REQUEST, + eventName: 'UserPromptSubmit', + input: { + prompt: promptText, + }, + }, + MessageBusType.HOOK_EXECUTION_RESPONSE, + ); + + return response.output + ? createHookOutput('UserPromptSubmit', response.output) + : undefined; + } catch (error) { + debugLogger.warn(`UserPromptSubmit hook failed: ${error}`); + return undefined; + } +} + +/** + * Fires the Stop hook and returns the hook output. + * This should be called after the agent has generated a response. + * + * The caller can use the returned DefaultHookOutput methods: + * - isBlockingDecision() / shouldStopExecution() to check if continuation is requested + * - getEffectiveReason() to get the continuation reason + * + * @param messageBus The message bus to use for hook communication + * @param request The original user's request (prompt) + * @param responseText The agent's response text + * @returns The hook output, or undefined if no hook was executed or on error + */ +export async function fireStopHook( + messageBus: MessageBus, + request: PartListUnion, + responseText: string, +): Promise { + try { + const promptText = partToString(request); + + const response = await messageBus.request< + HookExecutionRequest, + HookExecutionResponse + >( + { + type: MessageBusType.HOOK_EXECUTION_REQUEST, + eventName: 'Stop', + input: { + prompt: promptText, + prompt_response: responseText, + stop_hook_active: false, + }, + }, + MessageBusType.HOOK_EXECUTION_RESPONSE, + ); + + return response.output + ? createHookOutput('Stop', response.output) + : undefined; + } catch (error) { + debugLogger.warn(`Stop hook failed: ${error}`); + return undefined; + } +} diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 17c6c47de..3115cb425 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -64,6 +64,7 @@ export enum GeminiEventType { LoopDetected = 'loop_detected', Citation = 'citation', Retry = 'retry', + HookSystemMessage = 'hook_system_message', } export type ServerGeminiRetryEvent = { @@ -200,6 +201,11 @@ export type ServerGeminiCitationEvent = { value: string; }; +export type ServerGeminiHookSystemMessageEvent = { + type: GeminiEventType.HookSystemMessage; + value: string; +}; + // The original union type, now composed of the individual types export type ServerGeminiStreamEvent = | ServerGeminiChatCompressedEvent @@ -207,6 +213,7 @@ export type ServerGeminiStreamEvent = | ServerGeminiContentEvent | ServerGeminiErrorEvent | ServerGeminiFinishedEvent + | ServerGeminiHookSystemMessageEvent | ServerGeminiLoopDetectedEvent | ServerGeminiMaxSessionTurnsEvent | ServerGeminiThoughtEvent diff --git a/packages/core/src/extension/extensionManager.ts b/packages/core/src/extension/extensionManager.ts index 2da26995a..dd781d62b 100644 --- a/packages/core/src/extension/extensionManager.ts +++ b/packages/core/src/extension/extensionManager.ts @@ -100,6 +100,7 @@ export interface Extension { commands?: string[]; skills?: SkillConfig[]; agents?: SubagentConfig[]; + hooks?: Record; } export interface ExtensionConfig { diff --git a/packages/core/src/hooks/hookAggregator.ts b/packages/core/src/hooks/hookAggregator.ts new file mode 100644 index 000000000..467904427 --- /dev/null +++ b/packages/core/src/hooks/hookAggregator.ts @@ -0,0 +1,227 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + HookEventName, + DefaultHookOutput, + PreToolUseHookOutput, + StopHookOutput, +} from './types.js'; +import type { HookOutput, HookExecutionResult } from './types.js'; + +/** + * Aggregated result from multiple hook executions + */ +export interface AggregatedHookResult { + success: boolean; + allOutputs: HookOutput[]; + errors: Error[]; + totalDuration: number; + finalOutput?: HookOutput; +} + +/** + * HookAggregator merges multiple hook outputs using event-specific rules. + * + * Different events have different merging strategies: + * - PreToolUse/PostToolUse: OR logic for decisions, concatenation for messages + */ +export class HookAggregator { + /** + * Aggregate results from multiple hook executions + */ + aggregateResults( + results: HookExecutionResult[], + eventName: HookEventName, + ): AggregatedHookResult { + const allOutputs: HookOutput[] = []; + const errors: Error[] = []; + let totalDuration = 0; + + for (const result of results) { + totalDuration += result.duration; + + if (!result.success && result.error) { + errors.push(result.error); + } + + if (result.output) { + allOutputs.push(result.output); + } + } + + const success = errors.length === 0; + const finalOutput = this.mergeOutputs(allOutputs, eventName); + + return { + success, + allOutputs, + errors, + totalDuration, + finalOutput, + }; + } + + /** + * Merge multiple hook outputs based on event type + */ + private mergeOutputs( + outputs: HookOutput[], + eventName: HookEventName, + ): HookOutput | undefined { + if (outputs.length === 0) { + return undefined; + } + + if (outputs.length === 1) { + return this.createSpecificHookOutput(outputs[0], eventName); + } + + let merged: HookOutput; + + switch (eventName) { + case HookEventName.PreToolUse: + case HookEventName.PostToolUse: + merged = this.mergeWithOrLogic(outputs); + break; + + default: + merged = this.mergeSimple(outputs); + } + + return this.createSpecificHookOutput(merged, eventName); + } + + /** + * Merge outputs using OR logic for decisions and concatenation for messages. + * + * Rules: + * - Any "block" or "deny" decision results in blocking (most restrictive wins) + * - Reasons are concatenated with newlines + * - continue=false takes precedence over continue=true + * - Additional context is concatenated + */ + private mergeWithOrLogic(outputs: HookOutput[]): HookOutput { + const merged: HookOutput = {}; + const reasons: string[] = []; + const additionalContexts: string[] = []; + let hasBlock = false; + let hasContinueFalse = false; + let stopReason: string | undefined; + + for (const output of outputs) { + // Check for blocking decisions + if (output.decision === 'block' || output.decision === 'deny') { + hasBlock = true; + } + + // Collect reasons + if (output.reason) { + reasons.push(output.reason); + } + + // Check continue flag + if (output.continue === false) { + hasContinueFalse = true; + if (output.stopReason) { + stopReason = output.stopReason; + } + } + + // Extract additional context + this.extractAdditionalContext(output, additionalContexts); + + // Copy other fields (later values win for simple fields) + if (output.suppressOutput !== undefined) { + merged.suppressOutput = output.suppressOutput; + } + if (output.systemMessage !== undefined) { + merged.systemMessage = output.systemMessage; + } + } + + // Set merged decision + if (hasBlock) { + merged.decision = 'block'; + } else if (outputs.some((o) => o.decision === 'allow')) { + merged.decision = 'allow'; + } + + // Set merged reason + if (reasons.length > 0) { + merged.reason = reasons.join('\n'); + } + + // Set continue flag + if (hasContinueFalse) { + merged.continue = false; + if (stopReason) { + merged.stopReason = stopReason; + } + } + + // Set additional context if any + if (additionalContexts.length > 0) { + merged.hookSpecificOutput = { + ...merged.hookSpecificOutput, + additionalContext: additionalContexts.join('\n'), + }; + } + + return merged; + } + + /** + * Simple merge for events without special logic + */ + private mergeSimple(outputs: HookOutput[]): HookOutput { + let merged: HookOutput = {}; + + for (const output of outputs) { + merged = { ...merged, ...output }; + } + + return merged; + } + + /** + * Create the appropriate specific hook output class based on event type + */ + private createSpecificHookOutput( + output: HookOutput, + eventName: HookEventName, + ): DefaultHookOutput { + switch (eventName) { + case HookEventName.PreToolUse: + return new PreToolUseHookOutput(output); + case HookEventName.Stop: + return new StopHookOutput(output); + default: + return new DefaultHookOutput(output); + } + } + + /** + * Extract additional context from hook-specific outputs + */ + private extractAdditionalContext( + output: HookOutput, + contexts: string[], + ): void { + const specific = output.hookSpecificOutput; + if (!specific) { + return; + } + + // Extract additionalContext from various hook types + if ( + 'additionalContext' in specific && + typeof specific['additionalContext'] === 'string' + ) { + contexts.push(specific['additionalContext']); + } + } +} diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts new file mode 100644 index 000000000..dcb2cdfb5 --- /dev/null +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -0,0 +1,401 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import type { HookPlanner, HookEventContext } from './hookPlanner.js'; +import type { HookRunner } from './hookRunner.js'; +import type { HookAggregator, AggregatedHookResult } from './hookAggregator.js'; +import { HookEventName } from './types.js'; +import type { + HookConfig, + HookInput, + HookExecutionResult, + PreToolUseInput, + PostToolUseInput, + UserPromptSubmitInput, + NotificationInput, + StopInput, + SessionStartInput, + SessionEndInput, + PreCompactInput, + NotificationType, + SessionStartSource, + SessionEndReason, + PreCompactTrigger, + McpToolContext, +} from './types.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; + +const debugLogger = createDebugLogger('TRUSTED_HOOKS'); + +/** + * Hook event bus that coordinates hook execution across the system + */ +export class HookEventHandler { + private readonly config: Config; + private readonly hookPlanner: HookPlanner; + private readonly hookRunner: HookRunner; + private readonly hookAggregator: HookAggregator; + + /** + * Track reported failures to suppress duplicate warnings during streaming. + * Uses a WeakMap with the original request object as a key to ensure + * failures are only reported once per logical model interaction. + */ + private readonly reportedFailures = new WeakMap>(); + + constructor( + config: Config, + hookPlanner: HookPlanner, + hookRunner: HookRunner, + hookAggregator: HookAggregator, + ) { + this.config = config; + this.hookPlanner = hookPlanner; + this.hookRunner = hookRunner; + this.hookAggregator = hookAggregator; + } + + /** + * Fire a PreToolUse event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async firePreToolUseEvent( + toolName: string, + toolInput: Record, + mcpContext?: McpToolContext, + ): Promise { + const input: PreToolUseInput = { + ...this.createBaseInput(HookEventName.PreToolUse), + tool_name: toolName, + tool_input: toolInput, + ...(mcpContext && { mcp_context: mcpContext }), + }; + + const context: HookEventContext = { toolName }; + return this.executeHooks(HookEventName.PreToolUse, input, context); + } + + /** + * Fire a PostToolUse event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async firePostToolUseEvent( + toolName: string, + toolInput: Record, + toolResponse: Record, + mcpContext?: McpToolContext, + ): Promise { + const input: PostToolUseInput = { + ...this.createBaseInput(HookEventName.PostToolUse), + tool_name: toolName, + tool_input: toolInput, + tool_response: toolResponse, + ...(mcpContext && { mcp_context: mcpContext }), + }; + + const context: HookEventContext = { toolName }; + return this.executeHooks(HookEventName.PostToolUse, input, context); + } + + /** + * Fire a UserPromptSubmit event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireUserPromptSubmitEvent( + prompt: string, + ): Promise { + const input: UserPromptSubmitInput = { + ...this.createBaseInput(HookEventName.UserPromptSubmit), + prompt, + }; + + return this.executeHooks(HookEventName.UserPromptSubmit, input); + } + + /** + * Fire a Notification event + */ + async fireNotificationEvent( + type: NotificationType, + message: string, + details: Record, + ): Promise { + const input: NotificationInput = { + ...this.createBaseInput(HookEventName.Notification), + notification_type: type, + message, + details, + }; + + return this.executeHooks(HookEventName.Notification, input); + } + + /** + * Fire a Stop event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireStopEvent( + prompt: string, + promptResponse: string, + stopHookActive: boolean = false, + ): Promise { + const input: StopInput = { + ...this.createBaseInput(HookEventName.Stop), + prompt, + prompt_response: promptResponse, + stop_hook_active: stopHookActive, + }; + + return this.executeHooks(HookEventName.Stop, input); + } + + /** + * Fire a SessionStart event + */ + async fireSessionStartEvent( + source: SessionStartSource, + ): Promise { + const input: SessionStartInput = { + ...this.createBaseInput(HookEventName.SessionStart), + source, + }; + + const context: HookEventContext = { trigger: source }; + return this.executeHooks(HookEventName.SessionStart, input, context); + } + + /** + * Fire a SessionEnd event + */ + async fireSessionEndEvent( + reason: SessionEndReason, + ): Promise { + const input: SessionEndInput = { + ...this.createBaseInput(HookEventName.SessionEnd), + reason, + }; + + const context: HookEventContext = { trigger: reason }; + return this.executeHooks(HookEventName.SessionEnd, input, context); + } + + /** + * Fire a PreCompact event + */ + async firePreCompactEvent( + trigger: PreCompactTrigger, + ): Promise { + const input: PreCompactInput = { + ...this.createBaseInput(HookEventName.PreCompact), + trigger, + }; + + const context: HookEventContext = { trigger }; + return this.executeHooks(HookEventName.PreCompact, input, context); + } + + /** + * Execute hooks for a specific event (direct execution without MessageBus) + * Used as fallback when MessageBus is not available + */ + private async executeHooks( + eventName: HookEventName, + input: HookInput, + context?: HookEventContext, + requestContext?: object, + ): Promise { + try { + // Create execution plan + const plan = this.hookPlanner.createExecutionPlan(eventName, context); + + if (!plan || plan.hookConfigs.length === 0) { + return { + success: true, + allOutputs: [], + errors: [], + totalDuration: 0, + }; + } + + const onHookStart = (_config: HookConfig, _index: number) => { + // Hook start event (telemetry removed) + }; + + const onHookEnd = (_config: HookConfig, _result: HookExecutionResult) => { + // Hook end event (telemetry removed) + }; + + // Execute hooks according to the plan's strategy + const results = plan.sequential + ? await this.hookRunner.executeHooksSequential( + plan.hookConfigs, + eventName, + input, + onHookStart, + onHookEnd, + ) + : await this.hookRunner.executeHooksParallel( + plan.hookConfigs, + eventName, + input, + onHookStart, + onHookEnd, + ); + + // Aggregate results + const aggregated = this.hookAggregator.aggregateResults( + results, + eventName, + ); + + // Process common hook output fields centrally + this.processCommonHookOutputFields(aggregated); + + // Log hook execution + this.logHookExecution( + eventName, + input, + results, + aggregated, + requestContext, + ); + + return aggregated; + } catch (error) { + debugLogger.error(`Hook event bus error for ${eventName}: ${error}`); + + return { + success: false, + allOutputs: [], + errors: [error instanceof Error ? error : new Error(String(error))], + totalDuration: 0, + }; + } + } + + /** + * Create base hook input with common fields + */ + private createBaseInput(eventName: HookEventName): HookInput { + // Get the transcript path from the Config + const transcriptPath = this.config.getTranscriptPath(); + + return { + session_id: this.config.getSessionId(), + transcript_path: transcriptPath, + cwd: this.config.getWorkingDir(), + hook_event_name: eventName, + timestamp: new Date().toISOString(), + }; + } + + /** + * Log hook execution for observability + */ + private logHookExecution( + eventName: HookEventName, + input: HookInput, + results: HookExecutionResult[], + aggregated: AggregatedHookResult, + requestContext?: object, + ): void { + const failedHooks = results.filter((r) => !r.success); + const successCount = results.length - failedHooks.length; + const errorCount = failedHooks.length; + + if (errorCount > 0) { + const failedNames = failedHooks + .map((r) => this.getHookNameFromResult(r)) + .join(', '); + + let shouldEmit = true; + if (requestContext) { + let reportedSet = this.reportedFailures.get(requestContext); + if (!reportedSet) { + reportedSet = new Set(); + this.reportedFailures.set(requestContext, reportedSet); + } + + const failureKey = `${eventName}:${failedNames}`; + if (reportedSet.has(failureKey)) { + shouldEmit = false; + } else { + reportedSet.add(failureKey); + } + } + + debugLogger.warn( + `Hook execution for ${eventName}: ${successCount} succeeded, ${errorCount} failed (${failedNames}), ` + + `total duration: ${aggregated.totalDuration}ms`, + ); + + if (shouldEmit) { + debugLogger.warn( + `Hook(s) [${failedNames}] failed for event ${eventName}. Check debug logs for more details.`, + ); + } + } else { + debugLogger.debug( + `Hook execution for ${eventName}: ${successCount} hooks executed successfully, ` + + `total duration: ${aggregated.totalDuration}ms`, + ); + } + + // Log individual errors + for (const error of aggregated.errors) { + debugLogger.warn(`Hook execution error: ${error.message}`); + } + } + + /** + * Process common hook output fields centrally + */ + private processCommonHookOutputFields( + aggregated: AggregatedHookResult, + ): void { + if (!aggregated.finalOutput) { + return; + } + + // Handle systemMessage - show to user in transcript mode (not to agent) + const systemMessage = aggregated.finalOutput.systemMessage; + if (systemMessage && !aggregated.finalOutput.suppressOutput) { + debugLogger.warn(`Hook system message: ${systemMessage}`); + } + + // Handle suppressOutput - already handled by not logging above when true + + // Handle continue=false - this should stop the entire agent execution + if (aggregated.finalOutput.continue === false) { + const stopReason = + aggregated.finalOutput.stopReason || + aggregated.finalOutput.reason || + 'No reason provided'; + debugLogger.debug(`Hook requested to stop execution: ${stopReason}`); + + // Note: The actual stopping of execution must be handled by integration points + // as they need to interpret this signal in the context of their specific workflow + // This is just logging the request centrally + } + + // Other common fields like decision/reason are handled by specific hook output classes + } + + /** + * Get hook name from config for display or telemetry + */ + private getHookName(config: HookConfig): string { + return config.name || config.command || 'unknown-command'; + } + + /** + * Get hook name from execution result for telemetry + */ + private getHookNameFromResult(result: HookExecutionResult): string { + return this.getHookName(result.hookConfig); + } +} diff --git a/packages/core/src/hooks/hookPlanner.ts b/packages/core/src/hooks/hookPlanner.ts new file mode 100644 index 000000000..d460390c3 --- /dev/null +++ b/packages/core/src/hooks/hookPlanner.ts @@ -0,0 +1,140 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { HookRegistry, HookRegistryEntry } from './hookRegistry.js'; +import type { HookExecutionPlan } from './types.js'; +import { getHookKey, type HookEventName } from './types.js'; + +/** + * Hook planner that selects matching hooks and creates execution plans + */ +export class HookPlanner { + private readonly hookRegistry: HookRegistry; + + constructor(hookRegistry: HookRegistry) { + this.hookRegistry = hookRegistry; + } + + /** + * Create execution plan for a hook event + */ + createExecutionPlan( + eventName: HookEventName, + context?: HookEventContext, + ): HookExecutionPlan | null { + const hookEntries = this.hookRegistry.getHooksForEvent(eventName); + + if (hookEntries.length === 0) { + return null; + } + + // Filter hooks by matcher + const matchingEntries = hookEntries.filter((entry) => + this.matchesContext(entry, context), + ); + + if (matchingEntries.length === 0) { + return null; + } + + // Deduplicate identical hooks + const deduplicatedEntries = this.deduplicateHooks(matchingEntries); + + // Extract hook configs + const hookConfigs = deduplicatedEntries.map((entry) => entry.config); + + // Determine execution strategy - if ANY hook definition has sequential=true, run all sequentially + const sequential = deduplicatedEntries.some( + (entry) => entry.sequential === true, + ); + + const plan: HookExecutionPlan = { + eventName, + hookConfigs, + sequential, + }; + + return plan; + } + + /** + * Check if a hook entry matches the given context + */ + private matchesContext( + entry: HookRegistryEntry, + context?: HookEventContext, + ): boolean { + if (!entry.matcher || !context) { + return true; // No matcher means match all + } + + const matcher = entry.matcher.trim(); + + if (matcher === '' || matcher === '*') { + return true; // Empty string or wildcard matches all + } + + // For tool events, match against tool name + if (context.toolName) { + return this.matchesToolName(matcher, context.toolName); + } + + // For other events, match against trigger/source + if (context.trigger) { + return this.matchesTrigger(matcher, context.trigger); + } + + return true; + } + + /** + * Match tool name against matcher pattern + */ + private matchesToolName(matcher: string, toolName: string): boolean { + try { + // Attempt to treat the matcher as a regular expression. + const regex = new RegExp(matcher); + return regex.test(toolName); + } catch { + // If it's not a valid regex, treat it as a literal string for an exact match. + return matcher === toolName; + } + } + + /** + * Match trigger/source against matcher pattern + */ + private matchesTrigger(matcher: string, trigger: string): boolean { + return matcher === trigger; + } + + /** + * Deduplicate identical hook configurations + */ + private deduplicateHooks(entries: HookRegistryEntry[]): HookRegistryEntry[] { + const seen = new Set(); + const deduplicated: HookRegistryEntry[] = []; + + for (const entry of entries) { + const key = getHookKey(entry.config); + + if (!seen.has(key)) { + seen.add(key); + deduplicated.push(entry); + } + } + + return deduplicated; + } +} + +/** + * Context information for hook event matching + */ +export interface HookEventContext { + toolName?: string; + trigger?: string; +} diff --git a/packages/core/src/hooks/hookRegistry.ts b/packages/core/src/hooks/hookRegistry.ts new file mode 100644 index 000000000..548da5c44 --- /dev/null +++ b/packages/core/src/hooks/hookRegistry.ts @@ -0,0 +1,337 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { HookDefinition, HookConfig } from './types.js'; +import { + HookEventName, + HooksConfigSource, + HOOKS_CONFIG_FIELDS, +} from './types.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; +import { TrustedHooksManager } from './trustedHooks.js'; + +const debugLogger = createDebugLogger('HOOK_REGISTRY'); + +/** + * Extension with hooks support + */ +export interface ExtensionWithHooks { + isActive: boolean; + hooks?: { [K in HookEventName]?: HookDefinition[] }; +} + +/** + * Configuration interface for HookRegistry + * This abstracts the Config dependency to make the registry more flexible + */ +export interface HookRegistryConfig { + getProjectRoot(): string; + isTrustedFolder(): boolean; + getHooks(): { [K in HookEventName]?: HookDefinition[] } | undefined; + getProjectHooks(): { [K in HookEventName]?: HookDefinition[] } | undefined; + getDisabledHooks(): string[]; + getExtensions(): ExtensionWithHooks[]; +} + +/** + * Feedback emitter interface for warning/info messages + */ +export interface FeedbackEmitter { + emitFeedback(type: 'warning' | 'info' | 'error', message: string): void; +} + +/** + * Hook registry entry with source information + */ +export interface HookRegistryEntry { + config: HookConfig; + source: HooksConfigSource; + eventName: HookEventName; + matcher?: string; + sequential?: boolean; + enabled: boolean; +} + +/** + * Hook registry that loads and validates hook definitions from multiple sources + */ +export class HookRegistry { + private readonly config: HookRegistryConfig; + private readonly feedbackEmitter?: FeedbackEmitter; + private entries: HookRegistryEntry[] = []; + + constructor(config: HookRegistryConfig, feedbackEmitter?: FeedbackEmitter) { + this.config = config; + this.feedbackEmitter = feedbackEmitter; + } + + /** + * Initialize the registry by processing hooks from config + */ + async initialize(): Promise { + this.entries = []; + this.processHooksFromConfig(); + + debugLogger.debug( + `Hook registry initialized with ${this.entries.length} hook entries`, + ); + } + + /** + * Get all hook entries for a specific event + */ + getHooksForEvent(eventName: HookEventName): HookRegistryEntry[] { + return this.entries + .filter((entry) => entry.eventName === eventName && entry.enabled) + .sort( + (a, b) => + this.getSourcePriority(a.source) - this.getSourcePriority(b.source), + ); + } + + /** + * Get all registered hooks + */ + getAllHooks(): HookRegistryEntry[] { + return [...this.entries]; + } + + /** + * Enable or disable a specific hook + */ + setHookEnabled(hookName: string, enabled: boolean): void { + const updated = this.entries.filter((entry) => { + const name = this.getHookName(entry); + if (name === hookName) { + entry.enabled = enabled; + return true; + } + return false; + }); + + if (updated.length > 0) { + debugLogger.info( + `${enabled ? 'Enabled' : 'Disabled'} ${updated.length} hook(s) matching "${hookName}"`, + ); + } else { + debugLogger.warn(`No hooks found matching "${hookName}"`); + } + } + + /** + * Get hook name for identification and display purposes + */ + private getHookName( + entry: HookRegistryEntry | { config: HookConfig }, + ): string { + return entry.config.name || entry.config.command || 'unknown-command'; + } + + /** + * Check for untrusted project hooks and warn the user + */ + private checkProjectHooksTrust(): void { + const projectHooks = this.config.getProjectHooks(); + if (!projectHooks) return; + + try { + const trustedHooksManager = new TrustedHooksManager(); + const untrusted = trustedHooksManager.getUntrustedHooks( + this.config.getProjectRoot(), + projectHooks, + ); + + if (untrusted.length > 0) { + const message = `WARNING: The following project-level hooks have been detected in this workspace: +${untrusted.map((h: string) => ` - ${h}`).join('\n')} + +These hooks will be executed. If you did not configure these hooks or do not trust this project, +please review the project settings (.qwen/settings.json) and remove them.`; + this.feedbackEmitter?.emitFeedback('warning', message); + + // Trust them so we don't warn again + trustedHooksManager.trustHooks( + this.config.getProjectRoot(), + projectHooks, + ); + } + } catch { + debugLogger.warn('Failed to check project hooks trust'); + } + } + + /** + * Process hooks from the config that was already loaded by the CLI + */ + private processHooksFromConfig(): void { + if (this.config.isTrustedFolder()) { + this.checkProjectHooksTrust(); + } + + // Get hooks from the main config (this comes from the merged settings) + const configHooks = this.config.getHooks(); + if (configHooks) { + if (this.config.isTrustedFolder()) { + this.processHooksConfiguration(configHooks, HooksConfigSource.Project); + } else { + debugLogger.warn( + 'Project hooks disabled because the folder is not trusted.', + ); + } + } + + // Get hooks from extensions + const extensions = this.config.getExtensions() || []; + for (const extension of extensions) { + if (extension.isActive && extension.hooks) { + this.processHooksConfiguration( + extension.hooks, + HooksConfigSource.Extensions, + ); + } + } + } + + /** + * Process hooks configuration and add entries + */ + private processHooksConfiguration( + hooksConfig: { [K in HookEventName]?: HookDefinition[] }, + source: HooksConfigSource, + ): void { + for (const [eventName, definitions] of Object.entries(hooksConfig)) { + if (HOOKS_CONFIG_FIELDS.includes(eventName)) { + continue; + } + + if (!this.isValidEventName(eventName)) { + this.feedbackEmitter?.emitFeedback( + 'warning', + `Invalid hook event name: "${eventName}" from ${source} config. Skipping.`, + ); + continue; + } + + const typedEventName = eventName; + + if (!Array.isArray(definitions)) { + debugLogger.warn( + `Hook definitions for event "${eventName}" from source "${source}" is not an array. Skipping.`, + ); + continue; + } + + for (const definition of definitions) { + this.processHookDefinition(definition, typedEventName, source); + } + } + } + + /** + * Process a single hook definition + */ + private processHookDefinition( + definition: HookDefinition, + eventName: HookEventName, + source: HooksConfigSource, + ): void { + if ( + !definition || + typeof definition !== 'object' || + !Array.isArray(definition.hooks) + ) { + debugLogger.warn( + `Discarding invalid hook definition for ${eventName} from ${source}:`, + definition, + ); + return; + } + + // Get disabled hooks list from settings + const disabledHooks = this.config.getDisabledHooks(); + + for (const hookConfig of definition.hooks) { + if ( + hookConfig && + typeof hookConfig === 'object' && + this.validateHookConfig(hookConfig, eventName, source) + ) { + // Check if this hook is in the disabled list + const hookName = this.getHookName({ config: hookConfig }); + const isDisabled = disabledHooks.includes(hookName); + + // Add source to hook config + hookConfig.source = source; + + this.entries.push({ + config: hookConfig, + source, + eventName, + matcher: definition.matcher, + sequential: definition.sequential, + enabled: !isDisabled, + }); + } else { + // Invalid hooks are logged and discarded here, they won't reach HookRunner + debugLogger.warn( + `Discarding invalid hook configuration for ${eventName} from ${source}:`, + hookConfig, + ); + } + } + } + + /** + * Validate a hook configuration + */ + private validateHookConfig( + config: HookConfig, + eventName: HookEventName, + source: HooksConfigSource, + ): boolean { + if (!config.type || !['command', 'plugin'].includes(config.type)) { + debugLogger.warn( + `Invalid hook ${eventName} from ${source} type: ${config.type}`, + ); + return false; + } + + if (config.type === 'command' && !config.command) { + debugLogger.warn( + `Command hook ${eventName} from ${source} missing command field`, + ); + return false; + } + + return true; + } + + /** + * Check if an event name is valid + */ + private isValidEventName(eventName: string): eventName is HookEventName { + const validEventNames: string[] = Object.values(HookEventName); + return validEventNames.includes(eventName); + } + + /** + * Get source priority (lower number = higher priority) + */ + private getSourcePriority(source: HooksConfigSource): number { + switch (source) { + case HooksConfigSource.Project: + return 1; + case HooksConfigSource.User: + return 2; + case HooksConfigSource.System: + return 3; + case HooksConfigSource.Extensions: + return 4; + default: + return 999; + } + } +} diff --git a/packages/core/src/hooks/hookRunner.ts b/packages/core/src/hooks/hookRunner.ts new file mode 100644 index 000000000..c314b9015 --- /dev/null +++ b/packages/core/src/hooks/hookRunner.ts @@ -0,0 +1,451 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { spawn } from 'node:child_process'; +import { HookEventName, HooksConfigSource } from './types.js'; +import type { Config } from '../config/config.js'; +import type { + HookConfig, + HookInput, + HookOutput, + HookExecutionResult, + PreToolUseInput, + UserPromptSubmitInput, +} from './types.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; +import { + escapeShellArg, + getShellConfiguration, + type ShellType, +} from '../utils/shell-utils.js'; + +const debugLogger = createDebugLogger('TRUSTED_HOOKS'); + +/** + * Default timeout for hook execution (60 seconds) + */ +const DEFAULT_HOOK_TIMEOUT = 60000; + +/** + * Exit code constants for hook execution + */ +const EXIT_CODE_SUCCESS = 0; +const EXIT_CODE_NON_BLOCKING_ERROR = 1; + +/** + * Hook runner that executes command hooks + */ +export class HookRunner { + private readonly config: Config; + + constructor(config: Config) { + this.config = config; + } + + /** + * Execute a single hook + */ + async executeHook( + hookConfig: HookConfig, + eventName: HookEventName, + input: HookInput, + ): Promise { + const startTime = Date.now(); + + // Secondary security check: Ensure project hooks are not executed in untrusted folders + if ( + hookConfig.source === HooksConfigSource.Project && + !this.config.isTrustedFolder() + ) { + const errorMessage = + 'Security: Blocked execution of project hook in untrusted folder'; + debugLogger.warn(errorMessage); + return { + hookConfig, + eventName, + success: false, + error: new Error(errorMessage), + duration: 0, + }; + } + + try { + return await this.executeCommandHook( + hookConfig, + eventName, + input, + startTime, + ); + } catch (error) { + const duration = Date.now() - startTime; + const hookId = hookConfig.name || hookConfig.command || 'unknown'; + const errorMessage = `Hook execution failed for event '${eventName}' (hook: ${hookId}): ${error}`; + debugLogger.warn(`Hook execution error (non-fatal): ${errorMessage}`); + + return { + hookConfig, + eventName, + success: false, + error: error instanceof Error ? error : new Error(errorMessage), + duration, + }; + } + } + + /** + * Execute multiple hooks in parallel + */ + async executeHooksParallel( + hookConfigs: HookConfig[], + eventName: HookEventName, + input: HookInput, + onHookStart?: (config: HookConfig, index: number) => void, + onHookEnd?: (config: HookConfig, result: HookExecutionResult) => void, + ): Promise { + const promises = hookConfigs.map(async (config, index) => { + onHookStart?.(config, index); + const result = await this.executeHook(config, eventName, input); + onHookEnd?.(config, result); + return result; + }); + + return Promise.all(promises); + } + + /** + * Execute multiple hooks sequentially + */ + async executeHooksSequential( + hookConfigs: HookConfig[], + eventName: HookEventName, + input: HookInput, + onHookStart?: (config: HookConfig, index: number) => void, + onHookEnd?: (config: HookConfig, result: HookExecutionResult) => void, + ): Promise { + const results: HookExecutionResult[] = []; + let currentInput = input; + + for (let i = 0; i < hookConfigs.length; i++) { + const config = hookConfigs[i]; + onHookStart?.(config, i); + const result = await this.executeHook(config, eventName, currentInput); + onHookEnd?.(config, result); + results.push(result); + + // If the hook succeeded and has output, use it to modify the input for the next hook + if (result.success && result.output) { + currentInput = this.applyHookOutputToInput( + currentInput, + result.output, + eventName, + ); + } + } + + return results; + } + + /** + * Apply hook output to modify input for the next hook in sequential execution + */ + private applyHookOutputToInput( + originalInput: HookInput, + hookOutput: HookOutput, + eventName: HookEventName, + ): HookInput { + // Create a copy of the original input + const modifiedInput = { ...originalInput }; + + // Apply modifications based on hook output and event type + if (hookOutput.hookSpecificOutput) { + switch (eventName) { + case HookEventName.UserPromptSubmit: + if ('additionalContext' in hookOutput.hookSpecificOutput) { + // For UserPromptSubmit, we could modify the prompt with additional context + const additionalContext = + hookOutput.hookSpecificOutput['additionalContext']; + if ( + typeof additionalContext === 'string' && + 'prompt' in modifiedInput + ) { + (modifiedInput as UserPromptSubmitInput).prompt += + '\n\n' + additionalContext; + } + } + break; + + case HookEventName.PreToolUse: + if ('tool_input' in hookOutput.hookSpecificOutput) { + const newToolInput = hookOutput.hookSpecificOutput[ + 'tool_input' + ] as Record; + if (newToolInput && 'tool_input' in modifiedInput) { + (modifiedInput as PreToolUseInput).tool_input = { + ...(modifiedInput as PreToolUseInput).tool_input, + ...newToolInput, + }; + } + } + break; + + default: + // For other events, no special input modification is needed + break; + } + } + + return modifiedInput; + } + + /** + * Execute a command hook + */ + private async executeCommandHook( + hookConfig: HookConfig, + eventName: HookEventName, + input: HookInput, + startTime: number, + ): Promise { + const timeout = hookConfig.timeout ?? DEFAULT_HOOK_TIMEOUT; + + return new Promise((resolve) => { + if (!hookConfig.command) { + const errorMessage = 'Command hook missing command'; + debugLogger.warn( + `Hook configuration error (non-fatal): ${errorMessage}`, + ); + resolve({ + hookConfig, + eventName, + success: false, + error: new Error(errorMessage), + duration: Date.now() - startTime, + }); + return; + } + + let stdout = ''; + let stderr = ''; + let timedOut = false; + + const shellConfig = getShellConfiguration(); + const command = this.expandCommand( + hookConfig.command, + input, + shellConfig.shell, + ); + + // Set up environment variables + // Extract hook-specific fields from input to expose as environment variables + const hookEnvVars: Record = {}; + if ('prompt' in input && typeof input.prompt === 'string') { + hookEnvVars['PROMPT'] = input.prompt; + } + if ( + 'prompt_response' in input && + typeof input.prompt_response === 'string' + ) { + hookEnvVars['PROMPT_RESPONSE'] = input.prompt_response; + } + if ('tool_name' in input && typeof input.tool_name === 'string') { + hookEnvVars['TOOL_NAME'] = input.tool_name; + } + if ('session_id' in input && typeof input.session_id === 'string') { + hookEnvVars['SESSION_ID'] = input.session_id; + } + if ( + 'transcript_path' in input && + typeof input.transcript_path === 'string' + ) { + hookEnvVars['TRANSCRIPT_PATH'] = input.transcript_path; + } + if ( + 'stop_hook_active' in input && + typeof input.stop_hook_active === 'boolean' + ) { + hookEnvVars['STOP_HOOK_ACTIVE'] = input.stop_hook_active + ? 'true' + : 'false'; + } + + const env = { + ...process.env, + GEMINI_PROJECT_DIR: input.cwd, + CLAUDE_PROJECT_DIR: input.cwd, // For compatibility + QWEN_PROJECT_DIR: input.cwd, // For Qwen Code compatibility + ...hookEnvVars, + ...hookConfig.env, + }; + + const child = spawn( + shellConfig.executable, + [...shellConfig.argsPrefix, command], + { + env, + cwd: input.cwd, + stdio: ['pipe', 'pipe', 'pipe'], + shell: false, + }, + ); + + // Set up timeout + const timeoutHandle = setTimeout(() => { + timedOut = true; + child.kill('SIGTERM'); + + // Force kill after 5 seconds + setTimeout(() => { + if (!child.killed) { + child.kill('SIGKILL'); + } + }, 5000); + }, timeout); + + // Send input to stdin + if (child.stdin) { + child.stdin.on('error', (err: NodeJS.ErrnoException) => { + // Ignore EPIPE errors which happen when the child process closes stdin early + if (err.code !== 'EPIPE') { + debugLogger.debug(`Hook stdin error: ${err}`); + } + }); + + // Wrap write operations in try-catch to handle synchronous EPIPE errors + // that occur when the child process exits before we finish writing + try { + child.stdin.write(JSON.stringify(input)); + child.stdin.end(); + } catch (err) { + // Ignore EPIPE errors which happen when the child process closes stdin early + if (err instanceof Error && 'code' in err && err.code !== 'EPIPE') { + debugLogger.debug(`Hook stdin write error: ${err}`); + } + } + } + + // Collect stdout + child.stdout?.on('data', (data: Buffer) => { + stdout += data.toString(); + }); + + // Collect stderr + child.stderr?.on('data', (data: Buffer) => { + stderr += data.toString(); + }); + + // Handle process exit + child.on('close', (exitCode) => { + clearTimeout(timeoutHandle); + const duration = Date.now() - startTime; + + if (timedOut) { + resolve({ + hookConfig, + eventName, + success: false, + error: new Error(`Hook timed out after ${timeout}ms`), + stdout, + stderr, + duration, + }); + return; + } + + // Parse output + let output: HookOutput | undefined; + + const textToParse = stdout.trim() || stderr.trim(); + if (textToParse) { + try { + let parsed = JSON.parse(textToParse); + if (typeof parsed === 'string') { + parsed = JSON.parse(parsed); + } + if (parsed && typeof parsed === 'object') { + output = parsed as HookOutput; + } + } catch { + // Not JSON, convert plain text to structured output + output = this.convertPlainTextToHookOutput( + textToParse, + exitCode || EXIT_CODE_SUCCESS, + ); + } + } + + resolve({ + hookConfig, + eventName, + success: exitCode === EXIT_CODE_SUCCESS, + output, + stdout, + stderr, + exitCode: exitCode || EXIT_CODE_SUCCESS, + duration, + }); + }); + + // Handle process errors + child.on('error', (error) => { + clearTimeout(timeoutHandle); + const duration = Date.now() - startTime; + + resolve({ + hookConfig, + eventName, + success: false, + error, + stdout, + stderr, + duration, + }); + }); + }); + } + + /** + * Expand command with environment variables and input context + */ + private expandCommand( + command: string, + input: HookInput, + shellType: ShellType, + ): string { + debugLogger.debug(`Expanding hook command: ${command} (cwd: ${input.cwd})`); + const escapedCwd = escapeShellArg(input.cwd, shellType); + return command + .replace(/\$GEMINI_PROJECT_DIR/g, () => escapedCwd) + .replace(/\$CLAUDE_PROJECT_DIR/g, () => escapedCwd); // For compatibility + } + + /** + * Convert plain text output to structured HookOutput + */ + private convertPlainTextToHookOutput( + text: string, + exitCode: number, + ): HookOutput { + if (exitCode === EXIT_CODE_SUCCESS) { + // Success - treat as system message or additional context + return { + decision: 'allow', + systemMessage: text, + }; + } else if (exitCode === EXIT_CODE_NON_BLOCKING_ERROR) { + // Non-blocking error (EXIT_CODE_NON_BLOCKING_ERROR = 1) + return { + decision: 'allow', + systemMessage: `Warning: ${text}`, + }; + } else { + // All other non-zero exit codes (including 2) are blocking + return { + decision: 'deny', + reason: text, + }; + } + } +} diff --git a/packages/core/src/hooks/hookSystem.ts b/packages/core/src/hooks/hookSystem.ts new file mode 100644 index 000000000..fabe0cd23 --- /dev/null +++ b/packages/core/src/hooks/hookSystem.ts @@ -0,0 +1,270 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import { HookRegistry } from './hookRegistry.js'; +import { HookRunner } from './hookRunner.js'; +import { HookAggregator } from './hookAggregator.js'; +import { HookPlanner } from './hookPlanner.js'; +import { HookEventHandler } from './hookEventHandler.js'; +import type { HookRegistryEntry } from './hookRegistry.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; +import type { + SessionStartSource, + SessionEndReason, + PreCompactTrigger, + DefaultHookOutput, + McpToolContext, +} from './types.js'; +import { NotificationType, createHookOutput } from './types.js'; +import type { AggregatedHookResult } from './hookAggregator.js'; +import type { ToolCallConfirmationDetails } from '../tools/tools.js'; + +const debugLogger = createDebugLogger('TRUSTED_HOOKS'); + +/** + * Main hook system that coordinates all hook-related functionality + */ + +/** + * Converts ToolCallConfirmationDetails to a serializable format for hooks. + * Excludes function properties (onConfirm, ideConfirmation) that can't be serialized. + */ +function toSerializableDetails( + details: ToolCallConfirmationDetails, +): Record { + const base: Record = { + type: details.type, + title: details.title, + }; + + switch (details.type) { + case 'edit': + return { + ...base, + fileName: details.fileName, + filePath: details.filePath, + fileDiff: details.fileDiff, + originalContent: details.originalContent, + newContent: details.newContent, + isModifying: details.isModifying, + }; + case 'exec': + return { + ...base, + command: details.command, + rootCommand: details.rootCommand, + }; + case 'mcp': + return { + ...base, + serverName: details.serverName, + toolName: details.toolName, + toolDisplayName: details.toolDisplayName, + }; + case 'info': + return { + ...base, + prompt: details.prompt, + urls: details.urls, + }; + default: + return base; + } +} + +/** + * Gets the message to display in the notification hook for tool confirmation. + */ +function getNotificationMessage( + confirmationDetails: ToolCallConfirmationDetails, +): string { + switch (confirmationDetails.type) { + case 'edit': + return `Tool ${confirmationDetails.title} requires editing`; + case 'exec': + return `Tool ${confirmationDetails.title} requires execution`; + case 'mcp': + return `Tool ${confirmationDetails.title} requires MCP`; + case 'info': + return `Tool ${confirmationDetails.title} requires information`; + default: + return `Tool requires confirmation`; + } +} + +export class HookSystem { + private readonly hookRegistry: HookRegistry; + private readonly hookRunner: HookRunner; + private readonly hookAggregator: HookAggregator; + private readonly hookPlanner: HookPlanner; + private readonly hookEventHandler: HookEventHandler; + + constructor(config: Config) { + // Initialize components + this.hookRegistry = new HookRegistry(config); + this.hookRunner = new HookRunner(config); + this.hookAggregator = new HookAggregator(); + this.hookPlanner = new HookPlanner(this.hookRegistry); + this.hookEventHandler = new HookEventHandler( + config, + this.hookPlanner, + this.hookRunner, + this.hookAggregator, + ); + } + + /** + * Initialize the hook system + */ + async initialize(): Promise { + await this.hookRegistry.initialize(); + debugLogger.debug('Hook system initialized successfully'); + } + + /** + * Get the hook event bus for firing events + */ + getEventHandler(): HookEventHandler { + return this.hookEventHandler; + } + + /** + * Get hook registry for management operations + */ + getRegistry(): HookRegistry { + return this.hookRegistry; + } + + /** + * Enable or disable a hook + */ + setHookEnabled(hookName: string, enabled: boolean): void { + this.hookRegistry.setHookEnabled(hookName, enabled); + } + + /** + * Get all registered hooks for display/management + */ + getAllHooks(): HookRegistryEntry[] { + return this.hookRegistry.getAllHooks(); + } + + /** + * Fire hook events directly + */ + async fireSessionStartEvent( + source: SessionStartSource, + ): Promise { + const result = await this.hookEventHandler.fireSessionStartEvent(source); + return result.finalOutput + ? createHookOutput('SessionStart', result.finalOutput) + : undefined; + } + + async fireSessionEndEvent( + reason: SessionEndReason, + ): Promise { + return this.hookEventHandler.fireSessionEndEvent(reason); + } + + async firePreCompactEvent( + trigger: PreCompactTrigger, + ): Promise { + return this.hookEventHandler.firePreCompactEvent(trigger); + } + + async fireUserPromptSubmitEvent( + prompt: string, + ): Promise { + const result = + await this.hookEventHandler.fireUserPromptSubmitEvent(prompt); + return result.finalOutput + ? createHookOutput('UserPromptSubmit', result.finalOutput) + : undefined; + } + + async fireStopEvent( + prompt: string, + response: string, + stopHookActive: boolean = false, + ): Promise { + const result = await this.hookEventHandler.fireStopEvent( + prompt, + response, + stopHookActive, + ); + return result.finalOutput + ? createHookOutput('Stop', result.finalOutput) + : undefined; + } + + async firePreToolUseEvent( + toolName: string, + toolInput: Record, + mcpContext?: McpToolContext, + ): Promise { + try { + const result = await this.hookEventHandler.firePreToolUseEvent( + toolName, + toolInput, + mcpContext, + ); + return result.finalOutput + ? createHookOutput('PreToolUse', result.finalOutput) + : undefined; + } catch (error) { + debugLogger.debug(`PreToolUseEvent failed for ${toolName}:`, error); + return undefined; + } + } + + async firePostToolUseEvent( + toolName: string, + toolInput: Record, + toolResponse: { + llmContent: unknown; + returnDisplay: unknown; + error: unknown; + }, + mcpContext?: McpToolContext, + ): Promise { + try { + const result = await this.hookEventHandler.firePostToolUseEvent( + toolName, + toolInput, + toolResponse as Record, + mcpContext, + ); + return result.finalOutput + ? createHookOutput('PostToolUse', result.finalOutput) + : undefined; + } catch (error) { + debugLogger.debug(`PostToolUseEvent failed for ${toolName}:`, error); + return undefined; + } + } + + async fireToolNotificationEvent( + confirmationDetails: ToolCallConfirmationDetails, + ): Promise { + try { + const message = getNotificationMessage(confirmationDetails); + const serializedDetails = toSerializableDetails(confirmationDetails); + + await this.hookEventHandler.fireNotificationEvent( + NotificationType.ToolPermission, + message, + serializedDetails, + ); + } catch (error) { + debugLogger.debug( + `NotificationEvent failed for ${confirmationDetails.title}:`, + error, + ); + } + } +} diff --git a/packages/core/src/hooks/index.ts b/packages/core/src/hooks/index.ts new file mode 100644 index 000000000..620130d9f --- /dev/null +++ b/packages/core/src/hooks/index.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +// Export types +export * from './types.js'; + +// Export core components +export { HookSystem } from './hookSystem.js'; +export { HookRegistry } from './hookRegistry.js'; +export { HookRunner } from './hookRunner.js'; +export { HookAggregator } from './hookAggregator.js'; +export { HookPlanner } from './hookPlanner.js'; +export { HookEventHandler } from './hookEventHandler.js'; + +// Export interfaces and enums +export type { HookRegistryEntry } from './hookRegistry.js'; +export { HooksConfigSource as ConfigSource } from './types.js'; +export type { AggregatedHookResult } from './hookAggregator.js'; +export type { HookEventContext } from './hookPlanner.js'; diff --git a/packages/core/src/hooks/trustedHooks.ts b/packages/core/src/hooks/trustedHooks.ts new file mode 100644 index 000000000..04e93500f --- /dev/null +++ b/packages/core/src/hooks/trustedHooks.ts @@ -0,0 +1,118 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'node:fs'; +import * as path from 'node:path'; +import { Storage } from '../config/storage.js'; +import { + getHookKey, + type HookDefinition, + type HookEventName, +} from './types.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; + +const debugLogger = createDebugLogger('TRUSTED_HOOKS'); + +interface TrustedHooksConfig { + [projectPath: string]: string[]; // Array of trusted hook keys (name:command) +} + +export class TrustedHooksManager { + private configPath: string; + private trustedHooks: TrustedHooksConfig = {}; + + constructor() { + this.configPath = path.join( + Storage.getGlobalQwenDir(), + 'trusted_hooks.json', + ); + this.load(); + } + + private load(): void { + try { + if (fs.existsSync(this.configPath)) { + const content = fs.readFileSync(this.configPath, 'utf-8'); + this.trustedHooks = JSON.parse(content); + } + } catch (error) { + debugLogger.warn('Failed to load trusted hooks config', error); + this.trustedHooks = {}; + } + } + + private save(): void { + try { + const dir = path.dirname(this.configPath); + if (!fs.existsSync(dir)) { + fs.mkdirSync(dir, { recursive: true }); + } + fs.writeFileSync( + this.configPath, + JSON.stringify(this.trustedHooks, null, 2), + ); + } catch (error) { + debugLogger.warn('Failed to save trusted hooks config', error); + } + } + + /** + * Get untrusted hooks for a project + * @param projectPath Absolute path to the project root + * @param hooks The hooks configuration to check + * @returns List of untrusted hook commands/names + */ + getUntrustedHooks( + projectPath: string, + hooks: { [K in HookEventName]?: HookDefinition[] }, + ): string[] { + const trustedKeys = new Set(this.trustedHooks[projectPath] || []); + const untrusted: string[] = []; + + for (const eventName of Object.keys(hooks)) { + const definitions = hooks[eventName as HookEventName]; + if (!Array.isArray(definitions)) continue; + + for (const def of definitions) { + if (!def || !Array.isArray(def.hooks)) continue; + for (const hook of def.hooks) { + const key = getHookKey(hook); + if (!trustedKeys.has(key)) { + // Return friendly name or command + untrusted.push(hook.name || hook.command || 'unknown-hook'); + } + } + } + } + + return Array.from(new Set(untrusted)); // Deduplicate + } + + /** + * Trust all provided hooks for a project + */ + trustHooks( + projectPath: string, + hooks: { [K in HookEventName]?: HookDefinition[] }, + ): void { + const currentTrusted = new Set(this.trustedHooks[projectPath] || []); + + for (const eventName of Object.keys(hooks)) { + const definitions = hooks[eventName as HookEventName]; + if (!Array.isArray(definitions)) continue; + + for (const def of definitions) { + if (!def || !Array.isArray(def.hooks)) continue; + for (const hook of def.hooks) { + currentTrusted.add(getHookKey(hook)); + } + } + } + + this.trustedHooks[projectPath] = Array.from(currentTrusted); + this.save(); + } +} diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts new file mode 100644 index 000000000..45404eee0 --- /dev/null +++ b/packages/core/src/hooks/types.ts @@ -0,0 +1,461 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + ToolConfig as GenAIToolConfig, + ToolListUnion, +} from '@google/genai'; +export enum HooksConfigSource { + Project = 'project', + User = 'user', + System = 'system', + Extensions = 'extensions', +} + +/** + * Event names for the hook system + */ +export enum HookEventName { + PreToolUse = 'PreToolUse', + PostToolUse = 'PostToolUse', + UserPromptSubmit = 'UserPromptSubmit', + Notification = 'Notification', + Stop = 'Stop', + SessionStart = 'SessionStart', + SessionEnd = 'SessionEnd', + PreCompact = 'PreCompact', + SubagentStop = 'SubagentStop', + PermissionRequest = 'PermissionRequest', +} + +/** + * Fields in the hooks configuration that are not hook event names + */ +export const HOOKS_CONFIG_FIELDS = ['enabled', 'disabled', 'notifications']; + +/** + * Hook configuration entry + */ +export interface CommandHookConfig { + type: HookType.Command; + command: string; + name?: string; + description?: string; + timeout?: number; + source?: HooksConfigSource; + env?: Record; +} + +export type HookConfig = CommandHookConfig; + +/** + * Hook definition with matcher + */ +export interface HookDefinition { + matcher?: string; + sequential?: boolean; + hooks: HookConfig[]; +} + +/** + * Hook implementation types + */ +export enum HookType { + Command = 'command', +} + +/** + * Generate a unique key for a hook configuration + */ +export function getHookKey(hook: HookConfig): string { + const name = hook.name || ''; + const command = hook.command || ''; + return `${name}:${command}`; +} + +/** + * Decision types for hook outputs + */ +export type HookDecision = + | 'ask' + | 'block' + | 'deny' + | 'approve' + | 'allow' + | undefined; + +/** + * Base hook input - common fields for all events + */ +export interface HookInput { + session_id: string; + transcript_path: string; + cwd: string; + hook_event_name: string; + timestamp: string; +} + +/** + * Base hook output - common fields for all events + */ +export interface HookOutput { + continue?: boolean; + stopReason?: string; + suppressOutput?: boolean; + systemMessage?: string; + decision?: HookDecision; + reason?: string; + hookSpecificOutput?: Record; +} + +/** + * Factory function to create the appropriate hook output class based on event name + * Returns DefaultHookOutput for all events since it contains all necessary methods + */ +export function createHookOutput( + eventName: string, + data: Partial, +): DefaultHookOutput { + switch (eventName) { + case 'PreToolUse': + return new PreToolUseHookOutput(data); + case 'Stop': + return new StopHookOutput(data); + default: + return new DefaultHookOutput(data); + } +} + +/** + * Default implementation of HookOutput with utility methods + */ +export class DefaultHookOutput implements HookOutput { + continue?: boolean; + stopReason?: string; + suppressOutput?: boolean; + systemMessage?: string; + decision?: HookDecision; + reason?: string; + hookSpecificOutput?: Record; + + constructor(data: Partial = {}) { + this.continue = data.continue; + this.stopReason = data.stopReason; + this.suppressOutput = data.suppressOutput; + this.systemMessage = data.systemMessage; + this.decision = data.decision; + this.reason = data.reason; + this.hookSpecificOutput = data.hookSpecificOutput; + } + + /** + * Check if this output represents a blocking decision + */ + isBlockingDecision(): boolean { + return this.decision === 'block' || this.decision === 'deny'; + } + + /** + * Check if this output requests to stop execution + */ + shouldStopExecution(): boolean { + return this.continue === false; + } + + /** + * Get the effective reason for blocking or stopping + */ + getEffectiveReason(): string { + return this.stopReason || this.reason || 'No reason provided'; + } + + /** + * Apply tool config modifications (specific method for BeforeToolSelection hooks) + */ + applyToolConfigModifications(target: { + toolConfig?: GenAIToolConfig; + tools?: ToolListUnion; + }): { + toolConfig?: GenAIToolConfig; + tools?: ToolListUnion; + } { + // Base implementation - overridden by BeforeToolSelectionHookOutput + return target; + } + + /** + * Get sanitized additional context for adding to responses. + */ + getAdditionalContext(): string | undefined { + if ( + this.hookSpecificOutput && + 'additionalContext' in this.hookSpecificOutput + ) { + const context = this.hookSpecificOutput['additionalContext']; + if (typeof context !== 'string') { + return undefined; + } + + // Sanitize by escaping < and > to prevent tag injection + return context.replace(//g, '>'); + } + return undefined; + } + + /** + * Check if execution should be blocked and return error info + */ + getBlockingError(): { blocked: boolean; reason: string } { + if (this.isBlockingDecision()) { + return { + blocked: true, + reason: this.getEffectiveReason(), + }; + } + return { blocked: false, reason: '' }; + } + + /** + * Check if context clearing was requested by hook. + */ + shouldClearContext(): boolean { + return false; + } +} + +/** + * Specific hook output class for BeforeTool events. + */ +export class PreToolUseHookOutput extends DefaultHookOutput { + /** + * Get modified tool input if provided by hook + */ + getModifiedToolInput(): Record | undefined { + if (this.hookSpecificOutput && 'tool_input' in this.hookSpecificOutput) { + const input = this.hookSpecificOutput['tool_input']; + if ( + typeof input === 'object' && + input !== null && + !Array.isArray(input) + ) { + return input as Record; + } + } + return undefined; + } +} +export class StopHookOutput extends DefaultHookOutput { + override stopReason?: string; + + constructor(data: Partial = {}) { + super(data); + this.stopReason = data.stopReason; + } + + /** + * Get the stop reason if provided + */ + getStopReason(): string | undefined { + return this.stopReason; + } + + /** + * Check if context clearing was requested by hook + */ + override shouldClearContext(): boolean { + if (this.hookSpecificOutput && 'clearContext' in this.hookSpecificOutput) { + return this.hookSpecificOutput['clearContext'] === true; + } + return false; + } +} +/** + * Context for MCP tool executions. + * Contains non-sensitive connection information about the MCP server + * identity. Since server_name is user controlled and arbitrary, we + * also include connection information (e.g., command or url) to + * help identify the MCP server. + * + * NOTE: In the future, consider defining a shared sanitized interface + * from MCPServerConfig to avoid duplication and ensure consistency. + */ +export interface McpToolContext { + server_name: string; + tool_name: string; // Original tool name from the MCP server + + // Connection info (mutually exclusive based on transport type) + command?: string; // For stdio transport + args?: string[]; // For stdio transport + cwd?: string; // For stdio transport + + url?: string; // For SSE/HTTP transport + + tcp?: string; // For WebSocket transport +} + +export interface PreToolUseInput extends HookInput { + tool_name: string; + tool_input: Record; + mcp_context?: McpToolContext; +} + +/** + * BeforeTool hook output + */ +export interface BeforeToolOutput extends HookOutput { + hookSpecificOutput?: { + hookEventName: 'BeforeTool'; + tool_input?: Record; + }; +} +export interface PostToolUseInput extends HookInput { + tool_name: string; + tool_input: Record; + tool_response: Record; + mcp_context?: McpToolContext; +} +export interface PostToolUseOutput extends HookOutput { + hookEventName: 'PostToolUse'; +} +/** + * BeforeAgent hook input + */ +export interface UserPromptSubmitInput extends HookInput { + prompt: string; +} +export interface UserPromptSubmitOutput extends HookOutput { + additionalContext?: string; +} +/** + * Notification types + */ +export enum NotificationType { + ToolPermission = 'ToolPermission', +} + +/** + * Notification hook input + */ +export interface NotificationInput extends HookInput { + notification_type: NotificationType; + message: string; + details: Record; +} + +/** + * Notification hook output + */ +export interface NotificationOutput { + suppressOutput?: boolean; + systemMessage?: string; +} + +/** + * AfterAgent hook input + */ +export interface StopInput extends HookInput { + prompt: string; + prompt_response: string; + stop_hook_active: boolean; +} + +/** + * Stop hook output + */ +export interface StopOutput extends HookOutput { + stopReason?: string; +} + +/** + * SessionStart source types + */ +export enum SessionStartSource { + Startup = 'startup', + Resume = 'resume', + Clear = 'clear', +} + +/** + * SessionStart hook input + */ +export interface SessionStartInput extends HookInput { + source: SessionStartSource; +} + +/** + * SessionStart hook output + */ +export interface SessionStartOutput extends HookOutput { + hookSpecificOutput?: { + hookEventName: 'SessionStart'; + additionalContext?: string; + }; +} + +/** + * SessionEnd reason types + */ +export enum SessionEndReason { + Exit = 'exit', + Clear = 'clear', + Logout = 'logout', + PromptInputExit = 'prompt_input_exit', + Other = 'other', +} + +/** + * SessionEnd hook input + */ +export interface SessionEndInput extends HookInput { + reason: SessionEndReason; +} + +/** + * PreCompress trigger types + */ +export enum PreCompactTrigger { + Manual = 'manual', + Auto = 'auto', +} + +/** + * PreCompress hook input + */ +export interface PreCompactInput extends HookInput { + trigger: PreCompactTrigger; +} + +/** + * PreCompress hook output + */ +export interface PreCompressOutput { + suppressOutput?: boolean; + systemMessage?: string; +} + +/** + * Hook execution result + */ +export interface HookExecutionResult { + hookConfig: HookConfig; + eventName: HookEventName; + success: boolean; + output?: HookOutput; + stdout?: string; + stderr?: string; + exitCode?: number; + duration: number; + error?: Error; +} + +/** + * Hook execution plan for an event + */ +export interface HookExecutionPlan { + eventName: HookEventName; + hookConfigs: HookConfig[]; + sequential: boolean; +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index c76fd2f8d..0d961ba5e 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -298,3 +298,8 @@ export * from './qwen/qwenOAuth2.js'; export { makeFakeConfig } from './test-utils/config.js'; export * from './test-utils/index.js'; + +// Export hook types and components +export * from './hooks/types.js'; +export { HookSystem, HookRegistry } from './hooks/index.js'; +export type { HookRegistryEntry } from './hooks/index.js'; diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts new file mode 100644 index 000000000..131deac00 --- /dev/null +++ b/packages/core/src/policy/policy-engine.ts @@ -0,0 +1,541 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { FunctionCall } from '@google/genai'; +import stableStringify from 'fast-json-stable-stringify'; +import type { CheckerRunner } from '../safety/checker-runner.js'; +import { SafetyCheckDecision } from '../safety/protocol.js'; +import { + ApprovalMode, + PolicyDecision, + type CheckResult, + type HookCheckerRule, + type PolicyEngineConfig, + type PolicyRule, + type SafetyCheckerRule, +} from './types.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; + +const debugLogger = createDebugLogger('POLICY_ENGINE'); + +/** + * List of tool names that are considered shell commands. + */ +const SHELL_TOOL_NAMES = ['run_shell_command', 'shell', 'execute_command']; + +/** + * Check if a pattern is a wildcard pattern (contains * or ?). + */ +function isWildcardPattern(pattern: string): boolean { + return pattern.includes('*') || pattern.includes('?'); +} + +/** + * Match a tool name against a wildcard pattern. + */ +function matchesWildcard(pattern: string, toolName: string): boolean { + const regexPattern = pattern + .replace(/[.+^${}()|[\]\\]/g, '\\$&') + .replace(/\*/g, '.*') + .replace(/\?/g, '.'); + return new RegExp(`^${regexPattern}$`).test(toolName); +} + +/** + * Get all aliases for a tool name (for backwards compatibility). + */ +function getToolAliases(toolName: string): string[] { + const aliases: string[] = [toolName]; + + // Add common aliases + const aliasMap: Record = { + run_shell_command: ['shell', 'execute_command'], + shell: ['run_shell_command', 'execute_command'], + execute_command: ['run_shell_command', 'shell'], + }; + + if (aliasMap[toolName]) { + aliases.push(...aliasMap[toolName]); + } + + return aliases; +} + +/** + * Check if a rule matches a tool call. + */ +function ruleMatches( + rule: PolicyRule | SafetyCheckerRule, + toolCall: FunctionCall, + stringifiedArgs: string | undefined, + serverName: string | undefined, + approvalMode: ApprovalMode, +): boolean { + // Check approval mode + if ('modes' in rule && rule.modes && rule.modes.length > 0) { + if (!rule.modes.includes(approvalMode)) { + return false; + } + } + + // Check tool name + if (rule.toolName) { + const toolName = toolCall.name || ''; + + if (isWildcardPattern(rule.toolName)) { + if (!matchesWildcard(rule.toolName, toolName)) { + return false; + } + } else if (rule.toolName !== toolName) { + // Also check with server prefix + if (serverName && rule.toolName !== `${serverName}__${toolName}`) { + return false; + } else if (!serverName) { + return false; + } + } + } + + // Check args pattern + if (rule.argsPattern && stringifiedArgs) { + if (!rule.argsPattern.test(stringifiedArgs)) { + return false; + } + } + + return true; +} + +/** + * Policy engine for managing tool execution permissions. + */ +export class PolicyEngine { + private rules: PolicyRule[] = []; + private checkers: SafetyCheckerRule[] = []; + private hookCheckers: HookCheckerRule[] = []; + private readonly defaultDecision: PolicyDecision; + private readonly nonInteractive: boolean; + private readonly approvalMode: ApprovalMode; + private readonly checkerRunner?: CheckerRunner; + + constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) { + this.rules = [...(config.rules ?? [])]; + this.checkers = [...(config.checkers ?? [])]; + this.hookCheckers = [...(config.hookCheckers ?? [])]; + this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER; + this.nonInteractive = config.nonInteractive ?? false; + this.approvalMode = config.approvalMode ?? ApprovalMode.DEFAULT; + this.checkerRunner = checkerRunner; + + // Sort rules by priority (higher first) + this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + this.checkers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + /** + * Check shell command for additional security considerations. + */ + private async checkShellCommand( + toolName: string, + command: string | undefined, + ruleDecision: PolicyDecision, + serverName: string | undefined, + shellDirPath: string | undefined, + allowRedirection?: boolean, + rule?: PolicyRule, + ): Promise { + let aggregateDecision = ruleDecision; + let responsibleRule: PolicyRule | undefined; + + // Check for command redirection + if (command && !allowRedirection) { + const redirectionPatterns = [ + /[|&;`$()]/, + />\s*/, + /<\s*/, + /\$\(/, + /`[^`]*`/, + ]; + + for (const pattern of redirectionPatterns) { + if (pattern.test(command)) { + if (ruleDecision === PolicyDecision.ALLOW) { + debugLogger.debug( + `[PolicyEngine.checkShellCommand] Downgrading ALLOW to ASK_USER due to redirection pattern: ${pattern}`, + ); + aggregateDecision = PolicyDecision.ASK_USER; + break; + } + } + } + } + + return { + decision: this.applyNonInteractiveMode(aggregateDecision), + // If we stayed at ALLOW, we return the original rule (if any). + // If we downgraded, we return the responsible rule (or undefined if implicit). + rule: aggregateDecision === ruleDecision ? rule : responsibleRule, + }; + } + + /** + * Check if a tool call is allowed based on the configured policies. + * Returns the decision and the matching rule (if any). + */ + async check( + toolCall: FunctionCall, + serverName: string | undefined, + ): Promise { + let stringifiedArgs: string | undefined; + // Compute stringified args once before the loop + if ( + toolCall.args && + (this.rules.some((rule) => rule.argsPattern) || + this.checkers.some((checker) => checker.argsPattern)) + ) { + stringifiedArgs = stableStringify(toolCall.args); + } + + debugLogger.debug( + `[PolicyEngine.check] toolCall.name: ${toolCall.name}, stringifiedArgs: ${stringifiedArgs}`, + ); + + // Check for shell commands upfront to handle splitting + let isShellCommand = false; + let command: string | undefined; + let shellDirPath: string | undefined; + + const toolName = toolCall.name; + + if (toolName && SHELL_TOOL_NAMES.includes(toolName)) { + isShellCommand = true; + + const args = toolCall.args as { command?: string; dir_path?: string }; + command = args?.command; + shellDirPath = args?.dir_path; + } + + // Find the first matching rule (already sorted by priority) + let matchedRule: PolicyRule | undefined; + let decision: PolicyDecision | undefined; + + // For tools with a server name, we want to try matching both the + // original name and the fully qualified name (server__tool). + // We also want to check legacy aliases for the tool name. + const toolNamesToTry = toolCall.name ? getToolAliases(toolCall.name) : []; + + const toolCallsToTry: FunctionCall[] = []; + for (const name of toolNamesToTry) { + toolCallsToTry.push({ ...toolCall, name }); + if (serverName && !name.includes('__')) { + toolCallsToTry.push({ + ...toolCall, + name: `${serverName}__${name}`, + }); + } + } + + for (const rule of this.rules) { + const match = toolCallsToTry.some((tc) => + ruleMatches(rule, tc, stringifiedArgs, serverName, this.approvalMode), + ); + + if (match) { + debugLogger.debug( + `[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`, + ); + + if (isShellCommand && toolName) { + const shellResult = await this.checkShellCommand( + toolName, + command, + rule.decision, + serverName, + shellDirPath, + rule.allowRedirection, + rule, + ); + decision = shellResult.decision; + if (shellResult.rule) { + matchedRule = shellResult.rule; + break; + } + } else { + decision = this.applyNonInteractiveMode(rule.decision); + matchedRule = rule; + break; + } + } + } + + // Default if no rule matched + if (decision === undefined) { + debugLogger.debug( + `[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`, + ); + if (toolName && SHELL_TOOL_NAMES.includes(toolName)) { + const shellResult = await this.checkShellCommand( + toolName, + command, + this.defaultDecision, + serverName, + shellDirPath, + ); + decision = shellResult.decision; + matchedRule = shellResult.rule; + } else { + decision = this.applyNonInteractiveMode(this.defaultDecision); + } + } + + // Safety checks + if (decision !== PolicyDecision.DENY && this.checkerRunner) { + for (const checkerRule of this.checkers) { + if ( + ruleMatches( + checkerRule, + toolCall, + stringifiedArgs, + serverName, + this.approvalMode, + ) + ) { + debugLogger.debug( + `[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`, + ); + try { + const result = await this.checkerRunner.runChecker( + toolCall, + checkerRule.checker, + ); + if (result.decision === SafetyCheckDecision.DENY) { + debugLogger.debug( + `[PolicyEngine.check] Safety checker '${checkerRule.checker.name}' denied execution: ${result.reason}`, + ); + return { + decision: PolicyDecision.DENY, + rule: matchedRule, + }; + } else if (result.decision === SafetyCheckDecision.ASK_USER) { + debugLogger.debug( + `[PolicyEngine.check] Safety checker requested ASK_USER: ${result.reason}`, + ); + decision = PolicyDecision.ASK_USER; + } + } catch (error) { + debugLogger.debug( + `[PolicyEngine.check] Safety checker '${checkerRule.checker.name}' threw an error:`, + error, + ); + return { + decision: PolicyDecision.DENY, + rule: matchedRule, + }; + } + } + } + } + + return { + decision: this.applyNonInteractiveMode(decision), + rule: matchedRule, + }; + } + + /** + * Add a new rule to the policy engine. + */ + addRule(rule: PolicyRule): void { + this.rules.push(rule); + // Re-sort rules by priority + this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + addChecker(checker: SafetyCheckerRule): void { + this.checkers.push(checker); + this.checkers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + /** + * Remove rules matching a specific tier (priority band). + */ + removeRulesByTier(tier: number): void { + this.rules = this.rules.filter( + (rule) => Math.floor(rule.priority ?? 0) !== tier, + ); + } + + /** + * Remove checkers matching a specific tier (priority band). + */ + removeCheckersByTier(tier: number): void { + this.checkers = this.checkers.filter( + (checker) => Math.floor(checker.priority ?? 0) !== tier, + ); + } + + /** + * Remove rules for a specific tool. + * If source is provided, only rules matching that source are removed. + */ + removeRulesForTool(toolName: string, source?: string): void { + this.rules = this.rules.filter( + (rule) => + rule.toolName !== toolName || + (source !== undefined && rule.source !== source), + ); + } + + /** + * Get all current rules. + */ + getRules(): readonly PolicyRule[] { + return this.rules; + } + + /** + * Check if a rule for a specific tool already exists. + * If ignoreDynamic is true, it only returns true if a rule exists that was NOT added by AgentRegistry. + */ + hasRuleForTool(toolName: string, ignoreDynamic = false): boolean { + return this.rules.some( + (rule) => + rule.toolName === toolName && + (!ignoreDynamic || rule.source !== 'AgentRegistry (Dynamic)'), + ); + } + + getCheckers(): readonly SafetyCheckerRule[] { + return this.checkers; + } + + /** + * Add a new hook checker to the policy engine. + */ + addHookChecker(checker: HookCheckerRule): void { + this.hookCheckers.push(checker); + this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + /** + * Get all current hook checkers. + */ + getHookCheckers(): readonly HookCheckerRule[] { + return this.hookCheckers; + } + + /** + * Check if a hook execution is allowed based on the configured policies. + * Returns the decision for the hook execution request. + */ + async checkHook(hookRequest: { + eventName: string; + input: Record; + }): Promise { + debugLogger.debug( + `[PolicyEngine.checkHook] eventName: ${hookRequest.eventName}`, + ); + + // For now, allow all hooks by default + // In the future, this can be extended to check hook-specific policies + return this.applyNonInteractiveMode(PolicyDecision.ALLOW); + } + + /** + * Get tools that are effectively denied by the current rules. + * This takes into account: + * 1. Global rules (no argsPattern) + * 2. Priority order (higher priority wins) + * 3. Non-interactive mode (ASK_USER becomes DENY) + */ + getExcludedTools(): Set { + const excludedTools = new Set(); + const processedTools = new Set(); + let globalVerdict: PolicyDecision | undefined; + + for (const rule of this.rules) { + if (rule.argsPattern) { + if (rule.toolName && rule.decision !== PolicyDecision.DENY) { + processedTools.add(rule.toolName); + } + continue; + } + + // Check if rule applies to current approval mode + if (rule.modes && rule.modes.length > 0) { + if (!rule.modes.includes(this.approvalMode)) { + continue; + } + } + + // Handle Global Rules + if (!rule.toolName) { + if (globalVerdict === undefined) { + globalVerdict = rule.decision; + if (globalVerdict !== PolicyDecision.DENY) { + // Global ALLOW/ASK found. + // Since rules are sorted by priority, this overrides any lower-priority rules. + // We can stop processing because nothing else will be excluded. + break; + } + // If Global DENY, we continue to find specific tools to add to excluded set + } + continue; + } + + const toolName = rule.toolName; + + // Check if already processed (exact match) + if (processedTools.has(toolName)) { + continue; + } + + // Check if covered by a processed wildcard + let coveredByWildcard = false; + for (const processed of processedTools) { + if ( + isWildcardPattern(processed) && + matchesWildcard(processed, toolName) + ) { + // It's covered by a higher-priority wildcard rule. + // If that wildcard rule resulted in exclusion, this tool should also be excluded. + if (excludedTools.has(processed)) { + excludedTools.add(toolName); + } + coveredByWildcard = true; + break; + } + } + if (coveredByWildcard) { + continue; + } + + processedTools.add(toolName); + + // Determine decision + let decision: PolicyDecision; + if (globalVerdict !== undefined) { + decision = globalVerdict; + } else { + decision = rule.decision; + } + + if (decision === PolicyDecision.DENY) { + excludedTools.add(toolName); + } + } + return excludedTools; + } + + private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision { + // In non-interactive mode, ASK_USER becomes DENY + if (this.nonInteractive && decision === PolicyDecision.ASK_USER) { + return PolicyDecision.DENY; + } + return decision; + } +} diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts new file mode 100644 index 000000000..817da9788 --- /dev/null +++ b/packages/core/src/policy/types.ts @@ -0,0 +1,293 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { SafetyCheckInput } from '../safety/protocol.js'; + +export enum PolicyDecision { + ALLOW = 'allow', + DENY = 'deny', + ASK_USER = 'ask_user', +} + +/** + * Valid sources for hook execution + */ +export type HookSource = 'project' | 'user' | 'system' | 'extension'; + +/** + * Array of valid hook source values for runtime validation + */ +const VALID_HOOK_SOURCES: HookSource[] = [ + 'project', + 'user', + 'system', + 'extension', +]; + +/** + * Safely extract and validate hook source from input + * Returns 'project' as default if the value is invalid or missing + */ +export function getHookSource(input: Record): HookSource { + const source = input['hook_source']; + if ( + typeof source === 'string' && + VALID_HOOK_SOURCES.includes(source as HookSource) + ) { + return source as HookSource; + } + return 'project'; +} + +export enum ApprovalMode { + DEFAULT = 'default', + AUTO_EDIT = 'autoEdit', + YOLO = 'yolo', + PLAN = 'plan', +} + +/** + * Configuration for the built-in allowed-path checker. + */ +export interface AllowedPathConfig { + /** + * Explicitly include argument keys to be checked as paths. + */ + included_args?: string[]; + + /** + * Explicitly exclude argument keys from being checked as paths. + */ + excluded_args?: string[]; +} + +/** + * Base interface for external checkers. + */ +export interface ExternalCheckerConfig { + type: 'external'; + name: string; + config?: unknown; + required_context?: Array; +} + +export enum InProcessCheckerType { + ALLOWED_PATH = 'allowed-path', +} + +/** + * Base interface for in-process checkers. + */ +export interface InProcessCheckerConfig { + type: 'in-process'; + name: InProcessCheckerType; + config?: AllowedPathConfig; + required_context?: Array; +} + +/** + * A discriminated union for all safety checker configurations. + */ +export type SafetyCheckerConfig = + | ExternalCheckerConfig + | InProcessCheckerConfig; + +export interface PolicyRule { + /** + * A unique name for the policy rule, useful for identification and debugging. + */ + name?: string; + + /** + * The name of the tool this rule applies to. + * If undefined, the rule applies to all tools. + */ + toolName?: string; + + /** + * Pattern to match against tool arguments. + * Can be used for more fine-grained control. + */ + argsPattern?: RegExp; + + /** + * The decision to make when this rule matches. + */ + decision: PolicyDecision; + + /** + * Priority of this rule. Higher numbers take precedence. + * Default is 0. + */ + priority?: number; + + /** + * Approval modes this rule applies to. + * If undefined or empty, it applies to all modes. + */ + modes?: ApprovalMode[]; + + /** + * If true, allows command redirection even if the policy engine would normally + * downgrade ALLOW to ASK_USER for redirected commands. + * Only applies when decision is ALLOW. + */ + allowRedirection?: boolean; + + /** + * Effect of the rule's source. + * e.g. "my-policies.toml", "Settings (MCP Trusted)", etc. + */ + source?: string; + + /** + * Optional message to display when this rule results in a DENY decision. + * This message will be returned to the model/user. + */ + denyMessage?: string; +} + +export interface SafetyCheckerRule { + /** + * The name of the tool this rule applies to. + * If undefined, the rule applies to all tools. + */ + toolName?: string; + + /** + * Pattern to match against tool arguments. + * Can be used for more fine-grained control. + */ + argsPattern?: RegExp; + + /** + * Priority of this checker. Higher numbers run first. + * Default is 0. + */ + priority?: number; + + /** + * Specifies an external or built-in safety checker to execute for + * additional validation of a tool call. + */ + checker: SafetyCheckerConfig; + + /** + * Approval modes this rule applies to. + * If undefined or empty, it applies to all modes. + */ + modes?: ApprovalMode[]; + + /** + * Source of the rule. + * e.g. "my-policies.toml", "Workspace: project.toml", etc. + */ + source?: string; +} + +export interface HookExecutionContext { + eventName: string; + hookSource?: HookSource; + trustedFolder?: boolean; +} + +/** + * Rule for applying safety checkers to hook executions. + * Similar to SafetyCheckerRule but with hook-specific matching criteria. + */ +export interface HookCheckerRule { + /** + * The name of the hook event this rule applies to. + * If undefined, the rule applies to all hook events. + */ + eventName?: string; + + /** + * The source of hooks this rule applies to. + * If undefined, the rule applies to all hook sources. + */ + hookSource?: HookSource; + + /** + * Priority of this checker. Higher numbers run first. + * Default is 0. + */ + priority?: number; + + /** + * Specifies an external or built-in safety checker to execute for + * additional validation of a hook execution. + */ + checker: SafetyCheckerConfig; +} + +export interface PolicyEngineConfig { + /** + * List of policy rules to apply. + */ + rules?: PolicyRule[]; + + /** + * List of safety checkers to apply to tool calls. + */ + checkers?: SafetyCheckerRule[]; + + /** + * List of safety checkers to apply to hook executions. + */ + hookCheckers?: HookCheckerRule[]; + + /** + * Default decision when no rules match. + * Defaults to ASK_USER. + */ + defaultDecision?: PolicyDecision; + + /** + * Whether to allow tools in non-interactive mode. + * When true, ASK_USER decisions become DENY. + */ + nonInteractive?: boolean; + + /** + * Whether to allow hooks to execute. + * When false, all hooks are denied. + * Defaults to true. + */ + allowHooks?: boolean; + + /** + * Current approval mode. + * Used to filter rules that have specific 'modes' defined. + */ + approvalMode?: ApprovalMode; +} + +export interface PolicySettings { + mcp?: { + excluded?: string[]; + allowed?: string[]; + }; + tools?: { + exclude?: string[]; + allowed?: string[]; + }; + mcpServers?: Record; + // User provided policies that will replace the USER level policies in ~/.gemini/policies + policyPaths?: string[]; + workspacePoliciesDir?: string; +} + +export interface CheckResult { + decision: PolicyDecision; + rule?: PolicyRule; +} + +/** + * Priority for subagent tools (registered dynamically). + * Effective priority matching Tier 1 (Default) read-only tools. + */ +export const PRIORITY_SUBAGENT_TOOL = 1.05; diff --git a/packages/core/src/safety/built-in.ts b/packages/core/src/safety/built-in.ts new file mode 100644 index 000000000..72a22b7f6 --- /dev/null +++ b/packages/core/src/safety/built-in.ts @@ -0,0 +1,155 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js'; +import { SafetyCheckDecision } from './protocol.js'; +import type { AllowedPathConfig } from '../policy/types.js'; + +/** + * Interface for all in-process safety checkers. + */ +export interface InProcessChecker { + check(input: SafetyCheckInput): Promise; +} + +/** + * An in-process checker to validate file paths. + */ +export class AllowedPathChecker implements InProcessChecker { + async check(input: SafetyCheckInput): Promise { + const { toolCall, context } = input; + + const config = input.config as AllowedPathConfig | undefined; + + // Build list of allowed directories + const allowedDirs = [ + context.environment.cwd, + ...context.environment.workspaces, + ]; + + // Find all arguments that look like paths + const includedArgs = config?.included_args ?? []; + const excludedArgs = config?.excluded_args ?? []; + + const pathsToCheck = this.collectPathsToCheck( + toolCall.args, + includedArgs, + excludedArgs, + ); + + // Check each path + for (const { path: p, argName } of pathsToCheck) { + const resolvedPath = this.safelyResolvePath(p, context.environment.cwd); + + if (!resolvedPath) { + // If path cannot be resolved, deny it + return { + decision: SafetyCheckDecision.DENY, + reason: `Cannot resolve path "${p}" in argument "${argName}"`, + }; + } + + const isAllowed = allowedDirs.some((dir) => { + // Also resolve allowed directories to handle symlinks + const resolvedDir = this.safelyResolvePath( + dir, + context.environment.cwd, + ); + if (!resolvedDir) return false; + return this.isPathAllowed(resolvedPath, resolvedDir); + }); + + if (!isAllowed) { + return { + decision: SafetyCheckDecision.DENY, + reason: `Path "${p}" in argument "${argName}" is outside of the allowed workspace directories.`, + }; + } + } + + return { decision: SafetyCheckDecision.ALLOW }; + } + + private safelyResolvePath(inputPath: string, cwd: string): string | null { + try { + const resolved = path.resolve(cwd, inputPath); + + // Walk up the directory tree until we find a path that exists + let current = resolved; + // Stop at root (dirname(root) === root on many systems, or it becomes empty/'.' depending on implementation) + while (current && current !== path.dirname(current)) { + if (fs.existsSync(current)) { + const canonical = fs.realpathSync(current); + // Re-construct the full path from this canonical base + const relative = path.relative(current, resolved); + // path.join handles empty relative paths correctly (returns canonical) + return path.join(canonical, relative); + } + current = path.dirname(current); + } + + // Fallback if nothing exists (unlikely if root exists) + return resolved; + } catch (_error) { + return null; + } + } + + private isPathAllowed(targetPath: string, allowedDir: string): boolean { + const relative = path.relative(allowedDir, targetPath); + return ( + relative === '' || + (!relative.startsWith('..') && !path.isAbsolute(relative)) + ); + } + + private collectPathsToCheck( + args: unknown, + includedArgs: string[], + excludedArgs: string[], + prefix = '', + ): Array<{ path: string; argName: string }> { + const paths: Array<{ path: string; argName: string }> = []; + + if (typeof args !== 'object' || args === null) { + return paths; + } + + for (const [key, value] of Object.entries(args)) { + const fullKey = prefix ? `${prefix}.${key}` : key; + + if (excludedArgs.includes(fullKey)) { + continue; + } + + if (typeof value === 'string') { + if ( + includedArgs.includes(fullKey) || + key.includes('path') || + key.includes('directory') || + key.includes('file') || + key === 'source' || + key === 'destination' + ) { + paths.push({ path: value, argName: fullKey }); + } + } else if (typeof value === 'object') { + paths.push( + ...this.collectPathsToCheck( + value, + includedArgs, + excludedArgs, + fullKey, + ), + ); + } + } + + return paths; + } +} diff --git a/packages/core/src/safety/checker-runner.ts b/packages/core/src/safety/checker-runner.ts new file mode 100644 index 000000000..02f824d98 --- /dev/null +++ b/packages/core/src/safety/checker-runner.ts @@ -0,0 +1,305 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { spawn } from 'node:child_process'; +import type { FunctionCall } from '@google/genai'; +import type { + SafetyCheckerConfig, + InProcessCheckerConfig, + ExternalCheckerConfig, +} from '../policy/types.js'; +import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js'; +import { SafetyCheckDecision } from './protocol.js'; +import type { CheckerRegistry } from './registry.js'; +import type { ContextBuilder } from './context-builder.js'; +import { z } from 'zod'; + +const SafetyCheckResultSchema: z.ZodType = + z.discriminatedUnion('decision', [ + z.object({ + decision: z.literal(SafetyCheckDecision.ALLOW), + reason: z.string().optional(), + }), + z.object({ + decision: z.literal(SafetyCheckDecision.DENY), + reason: z.string().min(1), + }), + z.object({ + decision: z.literal(SafetyCheckDecision.ASK_USER), + reason: z.string().min(1), + }), + ]); + +/** + * Configuration for the checker runner. + */ +export interface CheckerRunnerConfig { + /** + * Maximum time (in milliseconds) to wait for a checker to complete. + * Default: 5000 (5 seconds) + */ + timeout?: number; + + /** + * Path to the directory containing external checkers. + */ + checkersPath: string; +} + +/** + * Service for executing safety checker processes. + */ +export class CheckerRunner { + private static readonly DEFAULT_TIMEOUT = 5000; // 5 seconds + + private readonly registry: CheckerRegistry; + private readonly contextBuilder: ContextBuilder; + private readonly timeout: number; + + constructor( + contextBuilder: ContextBuilder, + registry: CheckerRegistry, + config: CheckerRunnerConfig, + ) { + this.contextBuilder = contextBuilder; + this.registry = registry; + this.timeout = config.timeout ?? CheckerRunner.DEFAULT_TIMEOUT; + } + + /** + * Runs a safety checker and returns the result. + */ + async runChecker( + toolCall: FunctionCall, + checkerConfig: SafetyCheckerConfig, + ): Promise { + if (checkerConfig.type === 'in-process') { + return this.runInProcessChecker(toolCall, checkerConfig); + } + return this.runExternalChecker(toolCall, checkerConfig); + } + + private async runInProcessChecker( + toolCall: FunctionCall, + checkerConfig: InProcessCheckerConfig, + ): Promise { + try { + const checker = this.registry.resolveInProcess(checkerConfig.name); + const context = checkerConfig.required_context + ? this.contextBuilder.buildMinimalContext( + checkerConfig.required_context, + ) + : this.contextBuilder.buildFullContext(); + + const input: SafetyCheckInput = { + protocolVersion: '1.0.0', + toolCall, + context, + config: checkerConfig.config, + }; + + // In-process checkers can be async, but we'll also apply a timeout + // for safety, in case of infinite loops or unexpected delays. + return await this.executeWithTimeout(checker.check(input)); + } catch (error) { + return { + decision: SafetyCheckDecision.DENY, + reason: `Failed to run in-process checker "${checkerConfig.name}": ${ + error instanceof Error ? error.message : String(error) + }`, + }; + } + } + + private async runExternalChecker( + toolCall: FunctionCall, + checkerConfig: ExternalCheckerConfig, + ): Promise { + try { + // Resolve the checker executable path + const checkerPath = this.registry.resolveExternal(checkerConfig.name); + + // Build the appropriate context + const context = checkerConfig.required_context + ? this.contextBuilder.buildMinimalContext( + checkerConfig.required_context, + ) + : this.contextBuilder.buildFullContext(); + + // Create the input payload + const input: SafetyCheckInput = { + protocolVersion: '1.0.0', + toolCall, + context, + config: checkerConfig.config, + }; + + // Run the checker process + return await this.executeCheckerProcess( + checkerPath, + input, + checkerConfig.name, + ); + } catch (error) { + // If anything goes wrong, deny the operation + return { + decision: SafetyCheckDecision.DENY, + reason: `Failed to run safety checker "${checkerConfig.name}": ${ + error instanceof Error ? error.message : String(error) + }`, + }; + } + } + + /** + * Executes an external checker process and handles its lifecycle. + */ + private executeCheckerProcess( + checkerPath: string, + input: SafetyCheckInput, + checkerName: string, + ): Promise { + return new Promise((resolve) => { + const child = spawn(checkerPath, [], { + stdio: ['pipe', 'pipe', 'pipe'], + }); + + let stdout = ''; + let stderr = ''; + let timeoutHandle: NodeJS.Timeout | null = null; + let killed = false; + + let exited = false; + + // Set up timeout + timeoutHandle = setTimeout(() => { + killed = true; + child.kill('SIGTERM'); + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Safety checker "${checkerName}" timed out after ${this.timeout}ms`, + }); + + // Fallback: if process doesn't exit after 5s, force kill + setTimeout(() => { + if (!exited) { + child.kill('SIGKILL'); + } + }, 5000).unref(); + }, this.timeout); + + // Collect output + if (child.stdout) { + child.stdout.on('data', (data: Buffer) => { + stdout += data.toString(); + }); + } + + if (child.stderr) { + child.stderr.on('data', (data: Buffer) => { + stderr += data.toString(); + }); + } + + // Handle process completion + child.on('close', (code: number | null) => { + exited = true; + if (timeoutHandle) { + clearTimeout(timeoutHandle); + } + + // If we already killed it due to timeout, don't process the result + if (killed) { + return; + } + + // Non-zero exit code is a failure + if (code !== 0) { + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Safety checker "${checkerName}" exited with code ${code}${ + stderr ? `: ${stderr}` : '' + }`, + }); + return; + } + + // Try to parse the output + try { + const rawResult = JSON.parse(stdout); + const result = SafetyCheckResultSchema.parse(rawResult); + + resolve(result); + } catch (parseError) { + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Failed to parse output from safety checker "${checkerName}": ${ + parseError instanceof Error + ? parseError.message + : String(parseError) + }`, + }); + } + }); + + // Handle process errors + child.on('error', (error: Error) => { + if (timeoutHandle) { + clearTimeout(timeoutHandle); + } + + if (!killed) { + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Failed to spawn safety checker "${checkerName}": ${error.message}`, + }); + } + }); + + // Send input to the checker + try { + if (child.stdin) { + child.stdin.write(JSON.stringify(input)); + child.stdin.end(); + } else { + throw new Error('Failed to open stdin for checker process'); + } + } catch (writeError) { + if (timeoutHandle) { + clearTimeout(timeoutHandle); + } + + child.kill(); + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Failed to write to stdin of safety checker "${checkerName}": ${ + writeError instanceof Error + ? writeError.message + : String(writeError) + }`, + }); + } + }); + } + + /** + * Executes a promise with a timeout. + */ + private executeWithTimeout(promise: Promise): Promise { + return new Promise((resolve, reject) => { + const timeoutHandle = setTimeout(() => { + reject(new Error(`Checker timed out after ${this.timeout}ms`)); + }, this.timeout); + + promise + .then(resolve) + .catch(reject) + .finally(() => { + clearTimeout(timeoutHandle); + }); + }); + } +} diff --git a/packages/core/src/safety/context-builder.ts b/packages/core/src/safety/context-builder.ts new file mode 100644 index 000000000..134c857ad --- /dev/null +++ b/packages/core/src/safety/context-builder.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { SafetyCheckInput, ConversationTurn } from './protocol.js'; +import type { Config } from '../config/config.js'; + +/** + * Builds context objects for safety checkers, ensuring sensitive data is filtered. + */ +export class ContextBuilder { + constructor( + private readonly config: Config, + private readonly conversationHistory: ConversationTurn[] = [], + ) {} + + /** + * Builds the full context object with all available data. + */ + buildFullContext(): SafetyCheckInput['context'] { + return { + environment: { + cwd: process.cwd(), + + workspaces: this.config + .getWorkspaceContext() + .getDirectories() as string[], + }, + history: { + turns: this.conversationHistory, + }, + }; + } + + /** + * Builds a minimal context with only the specified keys. + */ + buildMinimalContext( + requiredKeys: Array, + ): SafetyCheckInput['context'] { + const fullContext = this.buildFullContext(); + const minimalContext: Partial = {}; + + for (const key of requiredKeys) { + if (key in fullContext) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (minimalContext as any)[key] = fullContext[key]; + } + } + + return minimalContext as SafetyCheckInput['context']; + } +} diff --git a/packages/core/src/safety/protocol.ts b/packages/core/src/safety/protocol.ts new file mode 100644 index 000000000..5028bd689 --- /dev/null +++ b/packages/core/src/safety/protocol.ts @@ -0,0 +1,100 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { FunctionCall } from '@google/genai'; + +/** + * Represents a single turn in the conversation between the user and the model. + * This provides semantic context for why a tool call might be happening. + */ +export interface ConversationTurn { + user: { + text: string; + }; + model: { + text?: string; + toolCalls?: FunctionCall[]; + }; +} + +/** + * The data structure passed from the CLI to a safety checker process via stdin. + */ +export interface SafetyCheckInput { + /** + * The semantic version of the protocol (e.g., "1.0.0"). This allows + * for introducing breaking changes in the future while maintaining + * support for older checkers. + */ + protocolVersion: '1.0.0'; + + /** + * The specific tool call that is being validated. + */ + toolCall: FunctionCall; + + /** + * A container for all contextual information from the CLI's internal state. + * By grouping data into categories, we can easily add new context in the + * future without creating a flat, unmanageable object. + */ + context: { + /** + * Information about the user's file system and execution environment. + */ + environment: { + cwd: string; + workspaces: string[]; // A list of user-configured workspace roots + }; + + /** + * The recent history of the conversation. This can be used by checkers + * that need to understand the intent behind a tool call. + */ + history?: { + turns: ConversationTurn[]; + }; + }; + + /** + * Configuration for the safety checker. + * This allows checkers to be parameterized (e.g. allowed paths). + */ + config?: unknown; +} + +/** + * The possible decisions a safety checker can make. + */ +export enum SafetyCheckDecision { + ALLOW = 'allow', + DENY = 'deny', + ASK_USER = 'ask_user', +} + +/** + * The data structure returned by a safety checker process via stdout. + */ +export type SafetyCheckResult = + | { + /** + * The decision made by the safety checker. + */ + decision: SafetyCheckDecision.ALLOW; + /** + * If not allowed, a message explaining why the tool call was blocked. + * This will be shown to the user. + */ + reason?: string; + } + | { + decision: SafetyCheckDecision.DENY; + reason: string; + } + | { + decision: SafetyCheckDecision.ASK_USER; + reason: string; + }; diff --git a/packages/core/src/safety/registry.ts b/packages/core/src/safety/registry.ts new file mode 100644 index 000000000..2775a82fd --- /dev/null +++ b/packages/core/src/safety/registry.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import { type InProcessChecker, AllowedPathChecker } from './built-in.js'; +import { InProcessCheckerType } from '../policy/types.js'; + +/** + * Registry for managing safety checker resolution. + */ +export class CheckerRegistry { + private static readonly BUILT_IN_EXTERNAL_CHECKERS = new Map([ + // No external built-ins for now + ]); + + private static readonly BUILT_IN_IN_PROCESS_CHECKERS = new Map< + string, + InProcessChecker + >([[InProcessCheckerType.ALLOWED_PATH, new AllowedPathChecker()]]); + + // Regex to validate checker names (alphanumeric and hyphens only) + private static readonly VALID_NAME_PATTERN = /^[a-z0-9-]+$/; + + constructor(private readonly checkersPath: string) {} + + /** + * Resolves an external checker name to an absolute executable path. + */ + resolveExternal(name: string): string { + if (!CheckerRegistry.isValidCheckerName(name)) { + throw new Error( + `Invalid checker name "${name}". Checker names must contain only lowercase letters, numbers, and hyphens.`, + ); + } + + const builtInPath = CheckerRegistry.BUILT_IN_EXTERNAL_CHECKERS.get(name); + if (builtInPath) { + const fullPath = path.join(this.checkersPath, builtInPath); + if (!fs.existsSync(fullPath)) { + throw new Error(`Built-in checker "${name}" not found at ${fullPath}`); + } + return fullPath; + } + + // TODO: Phase 5 - Add support for custom external checkers + throw new Error(`Unknown external checker "${name}".`); + } + + /** + * Resolves an in-process checker name to a checker instance. + */ + resolveInProcess(name: string): InProcessChecker { + if (!CheckerRegistry.isValidCheckerName(name)) { + throw new Error(`Invalid checker name "${name}".`); + } + + const checker = CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.get(name); + if (checker) { + return checker; + } + + throw new Error( + `Unknown in-process checker "${name}". Available: ${Array.from( + CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.keys(), + ).join(', ')}`, + ); + } + + private static isValidCheckerName(name: string): boolean { + return this.VALID_NAME_PATTERN.test(name) && !name.includes('..'); + } + + static getBuiltInCheckers(): string[] { + return [ + ...Array.from(this.BUILT_IN_EXTERNAL_CHECKERS.keys()), + ...Array.from(this.BUILT_IN_IN_PROCESS_CHECKERS.keys()), + ]; + } +}