mirror of
https://github.com/QwenLM/qwen-code.git
synced 2026-05-21 10:24:35 +00:00
refactor stop hook
This commit is contained in:
parent
a0a0a70b12
commit
43d64e26ca
32 changed files with 5387 additions and 6 deletions
25
packages/cli/src/commands/hooks.tsx
Normal file
25
packages/cli/src/commands/hooks.tsx
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { CommandModule } from 'yargs';
|
||||
import { enableCommand } from './hooks/enable.js';
|
||||
import { disableCommand } from './hooks/disable.js';
|
||||
|
||||
export const hooksCommand: CommandModule = {
|
||||
command: 'hooks <command>',
|
||||
aliases: ['hook'],
|
||||
describe: 'Manage Qwen Code hooks.',
|
||||
builder: (yargs) =>
|
||||
yargs
|
||||
.command(enableCommand)
|
||||
.command(disableCommand)
|
||||
.demandCommand(1, 'You need at least one command before continuing.')
|
||||
.version(false),
|
||||
handler: () => {
|
||||
// This handler is not called when a subcommand is provided.
|
||||
// Yargs will show the help menu.
|
||||
},
|
||||
};
|
||||
75
packages/cli/src/commands/hooks/disable.ts
Normal file
75
packages/cli/src/commands/hooks/disable.ts
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { CommandModule } from 'yargs';
|
||||
import { createDebugLogger, getErrorMessage } from '@qwen-code/qwen-code-core';
|
||||
import { loadSettings, SettingScope } from '../../config/settings.js';
|
||||
|
||||
const debugLogger = createDebugLogger('HOOKS_DISABLE');
|
||||
|
||||
interface DisableArgs {
|
||||
hookName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable a hook by adding it to the disabled list
|
||||
*/
|
||||
export async function handleDisableHook(hookName: string): Promise<void> {
|
||||
const workingDir = process.cwd();
|
||||
const settings = loadSettings(workingDir);
|
||||
|
||||
try {
|
||||
// Get current hooks settings
|
||||
const mergedSettings = settings.merged as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
const hooksSettings = (mergedSettings?.['hooks'] || {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const disabledHooks = (hooksSettings['disabled'] || []) as string[];
|
||||
|
||||
// Check if hook is already disabled
|
||||
if (disabledHooks.includes(hookName)) {
|
||||
debugLogger.info(`Hook "${hookName}" is already disabled.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Add hook to disabled list
|
||||
const newDisabledHooks = [...disabledHooks, hookName];
|
||||
const newHooksSettings = {
|
||||
...hooksSettings,
|
||||
disabled: newDisabledHooks,
|
||||
};
|
||||
|
||||
// Save updated settings
|
||||
settings.setValue(
|
||||
SettingScope.Workspace,
|
||||
'hooks' as keyof typeof settings.merged,
|
||||
newHooksSettings as never,
|
||||
);
|
||||
|
||||
debugLogger.info(`✓ Hook "${hookName}" has been disabled.`);
|
||||
} catch (error) {
|
||||
debugLogger.error(`Error disabling hook: ${getErrorMessage(error)}`);
|
||||
}
|
||||
}
|
||||
|
||||
export const disableCommand: CommandModule = {
|
||||
command: 'disable <hook-name>',
|
||||
describe: 'Disable an active hook',
|
||||
builder: (yargs) =>
|
||||
yargs.positional('hook-name', {
|
||||
describe: 'Name of the hook to disable',
|
||||
type: 'string',
|
||||
demandOption: true,
|
||||
}),
|
||||
handler: async (argv) => {
|
||||
const args = argv as unknown as DisableArgs;
|
||||
await handleDisableHook(args.hookName);
|
||||
process.exit(0);
|
||||
},
|
||||
};
|
||||
75
packages/cli/src/commands/hooks/enable.ts
Normal file
75
packages/cli/src/commands/hooks/enable.ts
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { CommandModule } from 'yargs';
|
||||
import { createDebugLogger, getErrorMessage } from '@qwen-code/qwen-code-core';
|
||||
import { loadSettings, SettingScope } from '../../config/settings.js';
|
||||
|
||||
const debugLogger = createDebugLogger('HOOKS_ENABLE');
|
||||
|
||||
interface EnableArgs {
|
||||
hookName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a hook by removing it from the disabled list
|
||||
*/
|
||||
export async function handleEnableHook(hookName: string): Promise<void> {
|
||||
const workingDir = process.cwd();
|
||||
const settings = loadSettings(workingDir);
|
||||
|
||||
try {
|
||||
// Get current hooks settings
|
||||
const mergedSettings = settings.merged as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
const hooksSettings = (mergedSettings?.['hooks'] || {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const disabledHooks = (hooksSettings['disabled'] || []) as string[];
|
||||
|
||||
// Check if hook is in disabled list
|
||||
if (!disabledHooks.includes(hookName)) {
|
||||
debugLogger.info(`Hook "${hookName}" is not disabled.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove hook from disabled list
|
||||
const newDisabledHooks = disabledHooks.filter((h) => h !== hookName);
|
||||
const newHooksSettings = {
|
||||
...hooksSettings,
|
||||
disabled: newDisabledHooks,
|
||||
};
|
||||
|
||||
// Save updated settings
|
||||
settings.setValue(
|
||||
SettingScope.Workspace,
|
||||
'hooks' as keyof typeof settings.merged,
|
||||
newHooksSettings as never,
|
||||
);
|
||||
|
||||
debugLogger.info(`✓ Hook "${hookName}" has been enabled.`);
|
||||
} catch (error) {
|
||||
debugLogger.error(`Error enabling hook: ${getErrorMessage(error)}`);
|
||||
}
|
||||
}
|
||||
|
||||
export const enableCommand: CommandModule = {
|
||||
command: 'enable <hook-name>',
|
||||
describe: 'Enable a disabled hook',
|
||||
builder: (yargs) =>
|
||||
yargs.positional('hook-name', {
|
||||
describe: 'Name of the hook to enable',
|
||||
type: 'string',
|
||||
demandOption: true,
|
||||
}),
|
||||
handler: async (argv) => {
|
||||
const args = argv as unknown as EnableArgs;
|
||||
await handleEnableHook(args.hookName);
|
||||
process.exit(0);
|
||||
},
|
||||
};
|
||||
|
|
@ -33,6 +33,7 @@ import {
|
|||
NativeLspService,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { extensionsCommand } from '../commands/extensions.js';
|
||||
import { hooksCommand } from '../commands/hooks.js';
|
||||
import type { Settings } from './settings.js';
|
||||
import {
|
||||
resolveCliGenerationConfig,
|
||||
|
|
@ -569,7 +570,9 @@ export async function parseArguments(): Promise<CliArgs> {
|
|||
// Register MCP subcommands
|
||||
.command(mcpCommand)
|
||||
// Register Extension subcommands
|
||||
.command(extensionsCommand);
|
||||
.command(extensionsCommand)
|
||||
// Register Hooks subcommands
|
||||
.command(hooksCommand);
|
||||
|
||||
yargsInstance
|
||||
.version(await getCliVersion()) // This will enable the --version flag based on package.json
|
||||
|
|
@ -588,9 +591,11 @@ export async function parseArguments(): Promise<CliArgs> {
|
|||
// and not return to main CLI logic
|
||||
if (
|
||||
result._.length > 0 &&
|
||||
(result._[0] === 'mcp' || result._[0] === 'extensions')
|
||||
(result._[0] === 'mcp' ||
|
||||
result._[0] === 'extensions' ||
|
||||
result._[0] === 'hooks')
|
||||
) {
|
||||
// MCP commands handle their own execution and process exit
|
||||
// MCP/Extensions/Hooks commands handle their own execution and process exit
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
|
|
@ -1025,6 +1030,7 @@ export async function loadCliConfig(
|
|||
output: {
|
||||
format: outputSettingsFormat,
|
||||
},
|
||||
hooks: settings.hooks,
|
||||
channel: argv.channel,
|
||||
// Precedence: explicit CLI flag > settings file > default(true).
|
||||
// NOTE: do NOT set a yargs default for `chat-recording`, otherwise argv will
|
||||
|
|
|
|||
|
|
@ -1177,6 +1177,118 @@ const SETTINGS_SCHEMA = {
|
|||
showInDialog: false,
|
||||
},
|
||||
|
||||
hooks: {
|
||||
type: 'object',
|
||||
label: 'Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: {},
|
||||
description:
|
||||
'Hook configurations for extending CLI behavior at various lifecycle points.',
|
||||
showInDialog: false,
|
||||
properties: {
|
||||
disabled: {
|
||||
type: 'array',
|
||||
label: 'Disabled Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [] as string[],
|
||||
description:
|
||||
'List of hook names to disable. Hooks in this list will not be executed.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.UNION,
|
||||
},
|
||||
PreToolUse: {
|
||||
type: 'array',
|
||||
label: 'PreTool Use Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute before tool invocations. Can validate, modify, or block tool calls.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
PostToolUse: {
|
||||
type: 'array',
|
||||
label: 'PostTool Use Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute after tool invocations. Can process results or trigger follow-up actions.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
BeforeAgent: {
|
||||
type: 'array',
|
||||
label: 'Before Agent Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute before agent processing. Can modify prompts or inject context.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
AfterAgent: {
|
||||
type: 'array',
|
||||
label: 'After Agent Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute after agent processing. Can post-process responses or log interactions.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
SessionStart: {
|
||||
type: 'array',
|
||||
label: 'Session Start Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute when a session starts. Can initialize state or load context.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
SessionEnd: {
|
||||
type: 'array',
|
||||
label: 'Session End Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute when a session ends. Can perform cleanup or persist session data.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
PreCompact: {
|
||||
type: 'array',
|
||||
label: 'PreCompact Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute before chat history compression. Can back up or analyze conversation before compression.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
Notification: {
|
||||
type: 'array',
|
||||
label: 'Notification Hooks',
|
||||
category: 'Advanced',
|
||||
requiresRestart: false,
|
||||
default: [],
|
||||
description:
|
||||
'Hooks that execute when notifications are triggered. Can handle alerts or status updates.',
|
||||
showInDialog: false,
|
||||
mergeStrategy: MergeStrategy.CONCAT,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
experimental: {
|
||||
type: 'object',
|
||||
label: 'Experimental',
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import { editorCommand } from '../ui/commands/editorCommand.js';
|
|||
import { exportCommand } from '../ui/commands/exportCommand.js';
|
||||
import { extensionsCommand } from '../ui/commands/extensionsCommand.js';
|
||||
import { helpCommand } from '../ui/commands/helpCommand.js';
|
||||
import { hooksCommand } from '../ui/commands/hooksCommand.js';
|
||||
import { ideCommand } from '../ui/commands/ideCommand.js';
|
||||
import { initCommand } from '../ui/commands/initCommand.js';
|
||||
import { languageCommand } from '../ui/commands/languageCommand.js';
|
||||
|
|
@ -71,6 +72,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
|||
exportCommand,
|
||||
extensionsCommand,
|
||||
helpCommand,
|
||||
hooksCommand,
|
||||
await ideCommand(),
|
||||
initCommand,
|
||||
languageCommand,
|
||||
|
|
|
|||
320
packages/cli/src/ui/commands/hooksCommand.ts
Normal file
320
packages/cli/src/ui/commands/hooksCommand.ts
Normal file
|
|
@ -0,0 +1,320 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
SlashCommand,
|
||||
SlashCommandActionReturn,
|
||||
CommandContext,
|
||||
MessageActionReturn,
|
||||
} from './types.js';
|
||||
import { CommandKind } from './types.js';
|
||||
import { t } from '../../i18n/index.js';
|
||||
import type { HookRegistryEntry } from '@qwen-code/qwen-code-core';
|
||||
|
||||
/**
|
||||
* Format hook source for display
|
||||
*/
|
||||
function formatHookSource(source: string): string {
|
||||
switch (source) {
|
||||
case 'project':
|
||||
return 'Project';
|
||||
case 'user':
|
||||
return 'User';
|
||||
case 'system':
|
||||
return 'System';
|
||||
case 'extensions':
|
||||
return 'Extension';
|
||||
default:
|
||||
return source;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format hook status for display
|
||||
*/
|
||||
function formatHookStatus(enabled: boolean): string {
|
||||
return enabled ? '✓ Enabled' : '✗ Disabled';
|
||||
}
|
||||
|
||||
const listCommand: SlashCommand = {
|
||||
name: 'list',
|
||||
get description() {
|
||||
return t('List all configured hooks');
|
||||
},
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
_args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
const { config } = context.services;
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t('Config not loaded.'),
|
||||
};
|
||||
}
|
||||
|
||||
const hookSystem = config.getHookSystem();
|
||||
if (!hookSystem) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: t(
|
||||
'Hooks are not enabled. Enable hooks in settings to use this feature.',
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
const registry = hookSystem.getRegistry();
|
||||
const allHooks = registry.getAllHooks();
|
||||
|
||||
if (allHooks.length === 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: t(
|
||||
'No hooks configured. Add hooks in your settings.json file.',
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
// Group hooks by event
|
||||
const hooksByEvent = new Map<string, HookRegistryEntry[]>();
|
||||
for (const hook of allHooks) {
|
||||
const eventName = hook.eventName;
|
||||
if (!hooksByEvent.has(eventName)) {
|
||||
hooksByEvent.set(eventName, []);
|
||||
}
|
||||
hooksByEvent.get(eventName)!.push(hook);
|
||||
}
|
||||
|
||||
let output = `**Configured Hooks (${allHooks.length} total)**\n\n`;
|
||||
|
||||
for (const [eventName, hooks] of hooksByEvent) {
|
||||
output += `### ${eventName}\n`;
|
||||
for (const hook of hooks) {
|
||||
const name = hook.config.name || hook.config.command || 'unnamed';
|
||||
const source = formatHookSource(hook.source);
|
||||
const status = formatHookStatus(hook.enabled);
|
||||
const matcher = hook.matcher ? ` (matcher: ${hook.matcher})` : '';
|
||||
output += `- **${name}** [${source}] ${status}${matcher}\n`;
|
||||
}
|
||||
output += '\n';
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: output,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const enableCommand: SlashCommand = {
|
||||
name: 'enable',
|
||||
get description() {
|
||||
return t('Enable a disabled hook');
|
||||
},
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
const hookName = args.trim();
|
||||
if (!hookName) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t(
|
||||
'Please specify a hook name. Usage: /hooks enable <hook-name>',
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
const { config } = context.services;
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t('Config not loaded.'),
|
||||
};
|
||||
}
|
||||
|
||||
const hookSystem = config.getHookSystem();
|
||||
if (!hookSystem) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t('Hooks are not enabled.'),
|
||||
};
|
||||
}
|
||||
|
||||
const registry = hookSystem.getRegistry();
|
||||
registry.setHookEnabled(hookName, true);
|
||||
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: t('Hook "{{name}}" has been enabled for this session.', {
|
||||
name: hookName,
|
||||
}),
|
||||
};
|
||||
},
|
||||
completion: async (context: CommandContext, partialArg: string) => {
|
||||
const { config } = context.services;
|
||||
if (!config) return [];
|
||||
|
||||
const hookSystem = config.getHookSystem();
|
||||
if (!hookSystem) return [];
|
||||
|
||||
const registry = hookSystem.getRegistry();
|
||||
const allHooks = registry.getAllHooks();
|
||||
|
||||
// Return disabled hooks for enable command
|
||||
return allHooks
|
||||
.filter((hook) => !hook.enabled)
|
||||
.map((hook) => hook.config.name || hook.config.command || '')
|
||||
.filter((name) => name && name.startsWith(partialArg));
|
||||
},
|
||||
};
|
||||
|
||||
const disableCommand: SlashCommand = {
|
||||
name: 'disable',
|
||||
get description() {
|
||||
return t('Disable an active hook');
|
||||
},
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<MessageActionReturn> => {
|
||||
const hookName = args.trim();
|
||||
if (!hookName) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t(
|
||||
'Please specify a hook name. Usage: /hooks disable <hook-name>',
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
const { config } = context.services;
|
||||
if (!config) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t('Config not loaded.'),
|
||||
};
|
||||
}
|
||||
|
||||
const hookSystem = config.getHookSystem();
|
||||
if (!hookSystem) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t('Hooks are not enabled.'),
|
||||
};
|
||||
}
|
||||
|
||||
const registry = hookSystem.getRegistry();
|
||||
registry.setHookEnabled(hookName, false);
|
||||
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: t('Hook "{{name}}" has been disabled for this session.', {
|
||||
name: hookName,
|
||||
}),
|
||||
};
|
||||
},
|
||||
completion: async (context: CommandContext, partialArg: string) => {
|
||||
const { config } = context.services;
|
||||
if (!config) return [];
|
||||
|
||||
const hookSystem = config.getHookSystem();
|
||||
if (!hookSystem) return [];
|
||||
|
||||
const registry = hookSystem.getRegistry();
|
||||
const allHooks = registry.getAllHooks();
|
||||
|
||||
// Return enabled hooks for disable command
|
||||
return allHooks
|
||||
.filter((hook) => hook.enabled)
|
||||
.map((hook) => hook.config.name || hook.config.command || '')
|
||||
.filter((name) => name && name.startsWith(partialArg));
|
||||
},
|
||||
};
|
||||
|
||||
export const hooksCommand: SlashCommand = {
|
||||
name: 'hooks',
|
||||
get description() {
|
||||
return t('Manage Qwen Code hooks');
|
||||
},
|
||||
kind: CommandKind.BUILT_IN,
|
||||
subCommands: [listCommand, enableCommand, disableCommand],
|
||||
action: async (
|
||||
context: CommandContext,
|
||||
args: string,
|
||||
): Promise<SlashCommandActionReturn> => {
|
||||
// If no subcommand provided, show list
|
||||
if (!args.trim()) {
|
||||
const result = await listCommand.action?.(context, '');
|
||||
return result ?? { type: 'message', messageType: 'info', content: '' };
|
||||
}
|
||||
|
||||
const [subcommand, ...rest] = args.trim().split(/\s+/);
|
||||
const subArgs = rest.join(' ');
|
||||
|
||||
let result: SlashCommandActionReturn | void;
|
||||
switch (subcommand.toLowerCase()) {
|
||||
case 'list':
|
||||
result = await listCommand.action?.(context, subArgs);
|
||||
break;
|
||||
case 'enable':
|
||||
result = await enableCommand.action?.(context, subArgs);
|
||||
break;
|
||||
case 'disable':
|
||||
result = await disableCommand.action?.(context, subArgs);
|
||||
break;
|
||||
default:
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: t(
|
||||
'Unknown subcommand: {{cmd}}. Available: list, enable, disable',
|
||||
{
|
||||
cmd: subcommand,
|
||||
},
|
||||
),
|
||||
};
|
||||
}
|
||||
return result ?? { type: 'message', messageType: 'info', content: '' };
|
||||
},
|
||||
completion: async (context: CommandContext, partialArg: string) => {
|
||||
const subcommands = ['list', 'enable', 'disable'];
|
||||
const parts = partialArg.split(/\s+/);
|
||||
|
||||
if (parts.length <= 1) {
|
||||
// Complete subcommand
|
||||
return subcommands.filter((cmd) => cmd.startsWith(partialArg));
|
||||
}
|
||||
|
||||
// Complete subcommand arguments
|
||||
const [subcommand, ...rest] = parts;
|
||||
const subArgs = rest.join(' ');
|
||||
|
||||
switch (subcommand.toLowerCase()) {
|
||||
case 'enable':
|
||||
return enableCommand.completion?.(context, subArgs) ?? [];
|
||||
case 'disable':
|
||||
return disableCommand.completion?.(context, subArgs) ?? [];
|
||||
default:
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -945,6 +945,15 @@ export const useGeminiStream = (
|
|||
clearRetryCountdown();
|
||||
}
|
||||
break;
|
||||
case ServerGeminiEventType.HookSystemMessage:
|
||||
// Display system message from hooks (e.g., Ralph Loop iteration info)
|
||||
// This is handled as a content event to show in the UI
|
||||
geminiMessageBuffer = handleContentEvent(
|
||||
event.value + '\n',
|
||||
geminiMessageBuffer,
|
||||
userMessageTimestamp,
|
||||
);
|
||||
break;
|
||||
default: {
|
||||
// enforces exhaustive switch-case
|
||||
const unreachable: never = event;
|
||||
|
|
|
|||
|
|
@ -84,6 +84,14 @@ import {
|
|||
ExtensionManager,
|
||||
type Extension,
|
||||
} from '../extension/extensionManager.js';
|
||||
import { HookSystem } from '../hooks/index.js';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type HookExecutionRequest,
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
|
||||
// Utils
|
||||
import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
|
||||
|
|
@ -378,6 +386,10 @@ export interface ConfigParameters {
|
|||
channel?: string;
|
||||
/** Model providers configuration grouped by authType */
|
||||
modelProvidersConfig?: ModelProvidersConfig;
|
||||
/** Enable hook system for lifecycle events */
|
||||
enableHooks?: boolean;
|
||||
/** Hooks configuration from settings */
|
||||
hooks?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
function normalizeConfigOutputFormat(
|
||||
|
|
@ -518,6 +530,11 @@ export class Config {
|
|||
private readonly eventEmitter?: EventEmitter;
|
||||
private readonly channel: string | undefined;
|
||||
private readonly defaultFileEncoding: FileEncodingType;
|
||||
private readonly enableHooks: boolean;
|
||||
private readonly hooks?: Record<string, unknown>;
|
||||
private hookSystem?: HookSystem;
|
||||
private messageBus?: MessageBus;
|
||||
private policyEngine?: PolicyEngine;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId ?? randomUUID();
|
||||
|
|
@ -672,6 +689,8 @@ export class Config {
|
|||
enabledExtensionOverrides: this.overrideExtensions,
|
||||
isWorkspaceTrusted: this.isTrustedFolder(),
|
||||
});
|
||||
this.enableHooks = params.enableHooks ?? true;
|
||||
this.hooks = params.hooks;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -695,6 +714,77 @@ export class Config {
|
|||
await this.extensionManager.refreshCache();
|
||||
this.debugLogger.debug('Extension manager initialized');
|
||||
|
||||
// Initialize hook system if enabled
|
||||
if (this.enableHooks) {
|
||||
this.hookSystem = new HookSystem(this);
|
||||
await this.hookSystem.initialize();
|
||||
this.debugLogger.debug('Hook system initialized');
|
||||
|
||||
// Initialize PolicyEngine and MessageBus for hook execution
|
||||
this.policyEngine = new PolicyEngine();
|
||||
this.messageBus = new MessageBus(this.policyEngine);
|
||||
|
||||
// Subscribe to HOOK_EXECUTION_REQUEST to execute hooks
|
||||
this.messageBus.subscribe<HookExecutionRequest>(
|
||||
MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
async (request: HookExecutionRequest) => {
|
||||
try {
|
||||
const hookSystem = this.hookSystem;
|
||||
if (!hookSystem) {
|
||||
this.messageBus?.publish({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
success: false,
|
||||
error: new Error('Hook system not initialized'),
|
||||
} as HookExecutionResponse);
|
||||
return;
|
||||
}
|
||||
|
||||
// Execute the appropriate hook based on eventName
|
||||
let result;
|
||||
const input = request.input || {};
|
||||
switch (request.eventName) {
|
||||
case 'UserPromptSubmit':
|
||||
result = await hookSystem.fireUserPromptSubmitEvent(
|
||||
(input['prompt'] as string) || '',
|
||||
);
|
||||
break;
|
||||
case 'Stop':
|
||||
result = await hookSystem.fireStopEvent(
|
||||
(input['prompt'] as string) || '',
|
||||
(input['prompt_response'] as string) || '',
|
||||
(input['stop_hook_active'] as boolean) || false,
|
||||
);
|
||||
break;
|
||||
default:
|
||||
this.debugLogger.warn(
|
||||
`Unknown hook event: ${request.eventName}`,
|
||||
);
|
||||
result = undefined;
|
||||
}
|
||||
|
||||
// Send response
|
||||
this.messageBus?.publish({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
success: true,
|
||||
output: result,
|
||||
} as HookExecutionResponse);
|
||||
} catch (error) {
|
||||
this.debugLogger.warn(`Hook execution failed: ${error}`);
|
||||
this.messageBus?.publish({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
success: false,
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
} as HookExecutionResponse);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
this.debugLogger.debug('MessageBus initialized with hook subscription');
|
||||
}
|
||||
|
||||
this.subagentManager = new SubagentManager(this);
|
||||
this.skillManager = new SkillManager(this);
|
||||
await this.skillManager.startWatching();
|
||||
|
|
@ -1374,6 +1464,81 @@ export class Config {
|
|||
return this.extensionManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the hook system instance if hooks are enabled.
|
||||
* Returns undefined if hooks are not enabled.
|
||||
*/
|
||||
getHookSystem(): HookSystem | undefined {
|
||||
return this.hookSystem;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if hooks are enabled.
|
||||
*/
|
||||
getEnableHooks(): boolean {
|
||||
return this.enableHooks;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message bus instance.
|
||||
* Returns undefined if not set.
|
||||
*/
|
||||
getMessageBus(): MessageBus | undefined {
|
||||
return this.messageBus;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the message bus instance.
|
||||
* This is called by the CLI layer to inject the MessageBus.
|
||||
*/
|
||||
setMessageBus(messageBus: MessageBus): void {
|
||||
this.messageBus = messageBus;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the policy engine instance.
|
||||
* Returns undefined if not set.
|
||||
*/
|
||||
getPolicyEngine(): PolicyEngine | undefined {
|
||||
return this.policyEngine;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the policy engine instance.
|
||||
* This is called by the CLI layer to inject the PolicyEngine.
|
||||
*/
|
||||
setPolicyEngine(policyEngine: PolicyEngine): void {
|
||||
this.policyEngine = policyEngine;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the list of disabled hook names.
|
||||
* This is used by the HookRegistry to filter out disabled hooks.
|
||||
*/
|
||||
getDisabledHooks(): string[] {
|
||||
// This will be populated from settings by the CLI layer
|
||||
// The core Config doesn't have direct access to settings
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get project-level hooks configuration.
|
||||
* This is used by the HookRegistry to load project-specific hooks.
|
||||
*/
|
||||
getProjectHooks(): Record<string, unknown> | undefined {
|
||||
// This will be populated from settings by the CLI layer
|
||||
// The core Config doesn't have direct access to settings
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all hooks configuration (merged from all sources).
|
||||
* This is used by the HookRegistry to load hooks.
|
||||
*/
|
||||
getHooks(): Record<string, unknown> | undefined {
|
||||
return this.hooks;
|
||||
}
|
||||
|
||||
getExtensions(): Extension[] {
|
||||
const extensions = this.extensionManager.getLoadedExtensions();
|
||||
if (this.overrideExtensions) {
|
||||
|
|
@ -1614,6 +1779,21 @@ export class Config {
|
|||
return this.chatRecordingService;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the transcript file path for the current session.
|
||||
* This is the path to the JSONL file where the conversation is recorded.
|
||||
* Returns empty string if chat recording is disabled.
|
||||
*/
|
||||
getTranscriptPath(): string {
|
||||
if (!this.chatRecordingEnabled) {
|
||||
return '';
|
||||
}
|
||||
const projectDir = this.storage.getProjectDir();
|
||||
const sessionId = this.getSessionId();
|
||||
const safeFilename = `${sessionId}.jsonl`;
|
||||
return path.join(projectDir, 'chats', safeFilename);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets or creates a SessionService for managing chat sessions.
|
||||
*/
|
||||
|
|
|
|||
206
packages/core/src/confirmation-bus/message-bus.ts
Normal file
206
packages/core/src/confirmation-bus/message-bus.ts
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { EventEmitter } from 'node:events';
|
||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import { PolicyDecision, getHookSource } from '../policy/types.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type Message,
|
||||
type HookExecutionRequest,
|
||||
type HookPolicyDecision,
|
||||
} from './types.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const debugLogger = createDebugLogger('TRUSTED_HOOKS');
|
||||
|
||||
export class MessageBus extends EventEmitter {
|
||||
constructor(
|
||||
private readonly policyEngine: PolicyEngine,
|
||||
private readonly debug = false,
|
||||
) {
|
||||
super();
|
||||
this.debug = debug;
|
||||
}
|
||||
|
||||
private isValidMessage(message: Message): boolean {
|
||||
if (!message || !message.type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST &&
|
||||
!('correlationId' in message)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private emitMessage(message: Message): void {
|
||||
this.emit(message.type, message);
|
||||
}
|
||||
|
||||
async publish(message: Message): Promise<void> {
|
||||
if (this.debug) {
|
||||
debugLogger.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
|
||||
}
|
||||
try {
|
||||
if (!this.isValidMessage(message)) {
|
||||
throw new Error(
|
||||
`Invalid message structure: ${safeJsonStringify(message)}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
|
||||
const { decision } = await this.policyEngine.check(
|
||||
message.toolCall,
|
||||
message.serverName,
|
||||
);
|
||||
|
||||
switch (decision) {
|
||||
case PolicyDecision.ALLOW:
|
||||
// Directly emit the response instead of recursive publish
|
||||
this.emitMessage({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: message.correlationId,
|
||||
confirmed: true,
|
||||
});
|
||||
break;
|
||||
case PolicyDecision.DENY:
|
||||
// Emit both rejection and response messages
|
||||
this.emitMessage({
|
||||
type: MessageBusType.TOOL_POLICY_REJECTION,
|
||||
toolCall: message.toolCall,
|
||||
});
|
||||
this.emitMessage({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: message.correlationId,
|
||||
confirmed: false,
|
||||
});
|
||||
break;
|
||||
case PolicyDecision.ASK_USER:
|
||||
// Pass through to UI for user confirmation if any listeners exist.
|
||||
// If no listeners are registered (e.g., headless/ACP flows),
|
||||
// immediately request user confirmation to avoid long timeouts.
|
||||
if (
|
||||
this.listenerCount(MessageBusType.TOOL_CONFIRMATION_REQUEST) > 0
|
||||
) {
|
||||
this.emitMessage(message);
|
||||
} else {
|
||||
this.emitMessage({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: message.correlationId,
|
||||
confirmed: false,
|
||||
requiresUserConfirmation: true,
|
||||
});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown policy decision: ${decision}`);
|
||||
}
|
||||
} else if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) {
|
||||
// Handle hook execution requests through policy evaluation
|
||||
const hookRequest = message as HookExecutionRequest;
|
||||
const decision = await this.policyEngine.checkHook(hookRequest);
|
||||
|
||||
// Map decision to allow/deny for observability (ASK_USER treated as deny for hooks)
|
||||
const effectiveDecision =
|
||||
decision === PolicyDecision.ALLOW ? 'allow' : 'deny';
|
||||
|
||||
// Emit policy decision for observability
|
||||
this.emitMessage({
|
||||
type: MessageBusType.HOOK_POLICY_DECISION,
|
||||
eventName: hookRequest.eventName,
|
||||
hookSource: getHookSource(hookRequest.input),
|
||||
decision: effectiveDecision,
|
||||
reason:
|
||||
decision !== PolicyDecision.ALLOW
|
||||
? 'Hook execution denied by policy'
|
||||
: undefined,
|
||||
} as HookPolicyDecision);
|
||||
|
||||
// If allowed, emit the request for hook system to handle
|
||||
if (decision === PolicyDecision.ALLOW) {
|
||||
this.emitMessage(message);
|
||||
} else {
|
||||
// If denied or ASK_USER, emit error response (hooks don't support interactive confirmation)
|
||||
this.emitMessage({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: hookRequest.correlationId,
|
||||
success: false,
|
||||
error: new Error('Hook execution denied by policy'),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// For all other message types, just emit them
|
||||
this.emitMessage(message);
|
||||
}
|
||||
} catch (error) {
|
||||
this.emit('error', error);
|
||||
}
|
||||
}
|
||||
|
||||
subscribe<T extends Message>(
|
||||
type: T['type'],
|
||||
listener: (message: T) => void,
|
||||
): void {
|
||||
this.on(type, listener);
|
||||
}
|
||||
|
||||
unsubscribe<T extends Message>(
|
||||
type: T['type'],
|
||||
listener: (message: T) => void,
|
||||
): void {
|
||||
this.off(type, listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Request-response pattern: Publish a message and wait for a correlated response
|
||||
* This enables synchronous-style communication over the async MessageBus
|
||||
* The correlation ID is generated internally and added to the request
|
||||
*/
|
||||
async request<TRequest extends Message, TResponse extends Message>(
|
||||
request: Omit<TRequest, 'correlationId'>,
|
||||
responseType: TResponse['type'],
|
||||
timeoutMs: number = 60000,
|
||||
): Promise<TResponse> {
|
||||
const correlationId = randomUUID();
|
||||
|
||||
return new Promise<TResponse>((resolve, reject) => {
|
||||
const timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
reject(new Error(`Request timed out waiting for ${responseType}`));
|
||||
}, timeoutMs);
|
||||
|
||||
const cleanup = () => {
|
||||
clearTimeout(timeoutId);
|
||||
this.unsubscribe(responseType, responseHandler);
|
||||
};
|
||||
|
||||
const responseHandler = (response: TResponse) => {
|
||||
// Check if this response matches our request
|
||||
if (
|
||||
'correlationId' in response &&
|
||||
response.correlationId === correlationId
|
||||
) {
|
||||
cleanup();
|
||||
resolve(response);
|
||||
}
|
||||
};
|
||||
|
||||
// Subscribe to responses
|
||||
this.subscribe<TResponse>(responseType, responseHandler);
|
||||
|
||||
// Publish the request with correlation ID
|
||||
|
||||
this.publish({ ...request, correlationId } as TRequest);
|
||||
});
|
||||
}
|
||||
}
|
||||
212
packages/core/src/confirmation-bus/types.ts
Normal file
212
packages/core/src/confirmation-bus/types.ts
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type FunctionCall } from '@google/genai';
|
||||
import type {
|
||||
ToolConfirmationOutcome,
|
||||
ToolConfirmationPayload,
|
||||
} from '../tools/tools.js';
|
||||
import type { ToolCall } from '../core/coreToolScheduler.js';
|
||||
|
||||
export enum MessageBusType {
|
||||
TOOL_CONFIRMATION_REQUEST = 'tool-confirmation-request',
|
||||
TOOL_CONFIRMATION_RESPONSE = 'tool-confirmation-response',
|
||||
TOOL_POLICY_REJECTION = 'tool-policy-rejection',
|
||||
TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
|
||||
TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
|
||||
UPDATE_POLICY = 'update-policy',
|
||||
TOOL_CALLS_UPDATE = 'tool-calls-update',
|
||||
ASK_USER_REQUEST = 'ask-user-request',
|
||||
ASK_USER_RESPONSE = 'ask-user-response',
|
||||
HOOK_EXECUTION_REQUEST = 'hook-execution-request',
|
||||
HOOK_EXECUTION_RESPONSE = 'hook-execution-response',
|
||||
HOOK_POLICY_DECISION = 'hook-policy-decision',
|
||||
}
|
||||
|
||||
export interface ToolCallsUpdateMessage {
|
||||
type: MessageBusType.TOOL_CALLS_UPDATE;
|
||||
toolCalls: ToolCall[];
|
||||
schedulerId: string;
|
||||
}
|
||||
|
||||
export interface ToolConfirmationRequest {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST;
|
||||
toolCall: FunctionCall;
|
||||
correlationId: string;
|
||||
serverName?: string;
|
||||
/**
|
||||
* Optional rich details for the confirmation UI (diffs, counts, etc.)
|
||||
*/
|
||||
details?: SerializableConfirmationDetails;
|
||||
}
|
||||
|
||||
export interface ToolConfirmationResponse {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE;
|
||||
correlationId: string;
|
||||
confirmed: boolean;
|
||||
/**
|
||||
* The specific outcome selected by the user.
|
||||
*
|
||||
* TODO: Make required after migration.
|
||||
*/
|
||||
outcome?: ToolConfirmationOutcome;
|
||||
/**
|
||||
* Optional payload (e.g., modified content for 'modify_with_editor').
|
||||
*/
|
||||
payload?: ToolConfirmationPayload;
|
||||
/**
|
||||
* When true, indicates that policy decision was ASK_USER and the tool should
|
||||
* show its legacy confirmation UI instead of auto-proceeding.
|
||||
*/
|
||||
requiresUserConfirmation?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data-only versions of ToolCallConfirmationDetails for bus transmission.
|
||||
*/
|
||||
export type SerializableConfirmationDetails =
|
||||
| {
|
||||
type: 'info';
|
||||
title: string;
|
||||
prompt: string;
|
||||
urls?: string[];
|
||||
}
|
||||
| {
|
||||
type: 'edit';
|
||||
title: string;
|
||||
fileName: string;
|
||||
filePath: string;
|
||||
fileDiff: string;
|
||||
originalContent: string | null;
|
||||
newContent: string;
|
||||
isModifying?: boolean;
|
||||
}
|
||||
| {
|
||||
type: 'exec';
|
||||
title: string;
|
||||
command: string;
|
||||
rootCommand: string;
|
||||
rootCommands: string[];
|
||||
commands?: string[];
|
||||
}
|
||||
| {
|
||||
type: 'mcp';
|
||||
title: string;
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
toolDisplayName: string;
|
||||
}
|
||||
| {
|
||||
type: 'ask_user';
|
||||
title: string;
|
||||
questions: Question[];
|
||||
}
|
||||
| {
|
||||
type: 'exit_plan_mode';
|
||||
title: string;
|
||||
planPath: string;
|
||||
};
|
||||
|
||||
export interface UpdatePolicy {
|
||||
type: MessageBusType.UPDATE_POLICY;
|
||||
toolName: string;
|
||||
persist?: boolean;
|
||||
argsPattern?: string;
|
||||
commandPrefix?: string | string[];
|
||||
mcpName?: string;
|
||||
}
|
||||
|
||||
export interface ToolPolicyRejection {
|
||||
type: MessageBusType.TOOL_POLICY_REJECTION;
|
||||
toolCall: FunctionCall;
|
||||
}
|
||||
|
||||
export interface ToolExecutionSuccess<T = unknown> {
|
||||
type: MessageBusType.TOOL_EXECUTION_SUCCESS;
|
||||
toolCall: FunctionCall;
|
||||
result: T;
|
||||
}
|
||||
|
||||
export interface ToolExecutionFailure<E = Error> {
|
||||
type: MessageBusType.TOOL_EXECUTION_FAILURE;
|
||||
toolCall: FunctionCall;
|
||||
error: E;
|
||||
}
|
||||
|
||||
export interface HookExecutionRequest {
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST;
|
||||
eventName: string;
|
||||
input: Record<string, unknown>;
|
||||
correlationId: string;
|
||||
}
|
||||
|
||||
export interface HookExecutionResponse {
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE;
|
||||
correlationId: string;
|
||||
success: boolean;
|
||||
output?: Record<string, unknown>;
|
||||
error?: Error;
|
||||
}
|
||||
|
||||
export interface HookPolicyDecision {
|
||||
type: MessageBusType.HOOK_POLICY_DECISION;
|
||||
eventName: string;
|
||||
hookSource: 'project' | 'user' | 'system' | 'extension';
|
||||
decision: 'allow' | 'deny';
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
export interface QuestionOption {
|
||||
label: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export enum QuestionType {
|
||||
CHOICE = 'choice',
|
||||
TEXT = 'text',
|
||||
YESNO = 'yesno',
|
||||
}
|
||||
|
||||
export interface Question {
|
||||
question: string;
|
||||
header: string;
|
||||
/** Question type: 'choice' renders selectable options, 'text' renders free-form input, 'yesno' renders a binary Yes/No choice. */
|
||||
type: QuestionType;
|
||||
/** Selectable choices. REQUIRED when type='choice'. IGNORED for 'text' and 'yesno'. */
|
||||
options?: QuestionOption[];
|
||||
/** Allow multiple selections. Only applies when type='choice'. */
|
||||
multiSelect?: boolean;
|
||||
/** Placeholder hint text. For type='text', shown in the input field. For type='choice', shown in the "Other" custom input. */
|
||||
placeholder?: string;
|
||||
}
|
||||
|
||||
export interface AskUserRequest {
|
||||
type: MessageBusType.ASK_USER_REQUEST;
|
||||
questions: Question[];
|
||||
correlationId: string;
|
||||
}
|
||||
|
||||
export interface AskUserResponse {
|
||||
type: MessageBusType.ASK_USER_RESPONSE;
|
||||
correlationId: string;
|
||||
answers: { [questionIndex: string]: string };
|
||||
/** When true, indicates the user cancelled the dialog without submitting answers */
|
||||
cancelled?: boolean;
|
||||
}
|
||||
|
||||
export type Message =
|
||||
| ToolConfirmationRequest
|
||||
| ToolConfirmationResponse
|
||||
| ToolPolicyRejection
|
||||
| ToolExecutionSuccess
|
||||
| ToolExecutionFailure
|
||||
| UpdatePolicy
|
||||
| AskUserRequest
|
||||
| AskUserResponse
|
||||
| ToolCallsUpdateMessage
|
||||
| HookExecutionRequest
|
||||
| HookExecutionResponse
|
||||
| HookPolicyDecision;
|
||||
|
|
@ -69,6 +69,12 @@ import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
|||
import { flatMapTextParts } from '../utils/partUtils.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
|
||||
// Hook triggers
|
||||
import {
|
||||
fireUserPromptSubmitHook,
|
||||
fireStopHook,
|
||||
} from './clientHookTriggers.js';
|
||||
|
||||
// IDE integration
|
||||
import { ideContextStore } from '../ide/ideContext.js';
|
||||
import { type File, type IdeContext } from '../ide/types.js';
|
||||
|
|
@ -407,6 +413,35 @@ export class GeminiClient {
|
|||
options?: { isContinuation: boolean },
|
||||
turns: number = MAX_TURNS,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await fireUserPromptSubmitHook(messageBus, request);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
yield {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
},
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
// Add additional context from hooks to the request
|
||||
const additionalContext = hookOutput?.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
const requestArray = Array.isArray(request) ? request : [request];
|
||||
request = [...requestArray, { text: additionalContext }];
|
||||
}
|
||||
}
|
||||
|
||||
if (!options?.isContinuation) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.lastPromptId = prompt_id;
|
||||
|
|
@ -536,6 +571,50 @@ export class GeminiClient {
|
|||
return turn;
|
||||
}
|
||||
}
|
||||
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
|
||||
// This must be done before any early returns to ensure hooks are always triggered
|
||||
if (hooksEnabled && messageBus && !turn.pendingToolCalls.length) {
|
||||
// Get response text from the chat history
|
||||
const history = this.getHistory();
|
||||
const lastModelMessage = history
|
||||
.filter((msg) => msg.role === 'model')
|
||||
.pop();
|
||||
const responseText =
|
||||
lastModelMessage?.parts
|
||||
?.filter((p): p is { text: string } => 'text' in p)
|
||||
.map((p) => p.text)
|
||||
.join('') || '[no response text]';
|
||||
|
||||
const hookOutput = await fireStopHook(messageBus, request, responseText);
|
||||
|
||||
// For AfterAgent hooks, blocking/stop execution should force continuation (like Stop Hook)
|
||||
// This enables Ralph Loop functionality where the hook can:
|
||||
// 1. Return {"decision": "block", "reason": "<prompt>"} to continue with a new prompt
|
||||
// 2. Optionally include "systemMessage" to display a status message
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
// Emit system message if provided (e.g., "🔄 Ralph iteration 5")
|
||||
if (hookOutput.systemMessage) {
|
||||
yield {
|
||||
type: GeminiEventType.HookSystemMessage,
|
||||
value: hookOutput.systemMessage,
|
||||
};
|
||||
}
|
||||
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
return yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
{ isContinuation: true },
|
||||
boundedTurns - 1,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
if (this.config.getSkipNextSpeakerCheck()) {
|
||||
return turn;
|
||||
|
|
@ -557,9 +636,9 @@ export class GeminiClient {
|
|||
);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
// This recursive call's events will be yielded out, but the final
|
||||
// turn object will be from the top-level call.
|
||||
yield* this.sendMessageStream(
|
||||
// This recursive call's events will be yielded out, and the final
|
||||
// turn object from the recursive call will be returned.
|
||||
return yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
|
|
@ -568,6 +647,7 @@ export class GeminiClient {
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
return turn;
|
||||
}
|
||||
|
||||
|
|
|
|||
107
packages/core/src/core/clientHookTriggers.ts
Normal file
107
packages/core/src/core/clientHookTriggers.ts
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { PartListUnion } from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type HookExecutionRequest,
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { createHookOutput, type DefaultHookOutput } from '../hooks/types.js';
|
||||
import { partToString } from '../utils/partUtils.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const debugLogger = createDebugLogger('HOOK_TRIGGERS');
|
||||
|
||||
/**
|
||||
* Fires the UserPromptSubmit hook and returns the hook output.
|
||||
* This should be called before processing a user prompt.
|
||||
*
|
||||
* The caller can use the returned DefaultHookOutput methods:
|
||||
* - isBlockingDecision() / shouldStopExecution() to check if blocked
|
||||
* - getEffectiveReason() to get the blocking reason
|
||||
* - getAdditionalContext() to get additional context to add
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param request The user's request (prompt)
|
||||
* @returns The hook output, or undefined if no hook was executed or on error
|
||||
*/
|
||||
export async function fireUserPromptSubmitHook(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const promptText = partToString(request);
|
||||
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'UserPromptSubmit',
|
||||
input: {
|
||||
prompt: promptText,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
return response.output
|
||||
? createHookOutput('UserPromptSubmit', response.output)
|
||||
: undefined;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`UserPromptSubmit hook failed: ${error}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fires the Stop hook and returns the hook output.
|
||||
* This should be called after the agent has generated a response.
|
||||
*
|
||||
* The caller can use the returned DefaultHookOutput methods:
|
||||
* - isBlockingDecision() / shouldStopExecution() to check if continuation is requested
|
||||
* - getEffectiveReason() to get the continuation reason
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param request The original user's request (prompt)
|
||||
* @param responseText The agent's response text
|
||||
* @returns The hook output, or undefined if no hook was executed or on error
|
||||
*/
|
||||
export async function fireStopHook(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
responseText: string,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const promptText = partToString(request);
|
||||
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'Stop',
|
||||
input: {
|
||||
prompt: promptText,
|
||||
prompt_response: responseText,
|
||||
stop_hook_active: false,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
return response.output
|
||||
? createHookOutput('Stop', response.output)
|
||||
: undefined;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`Stop hook failed: ${error}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
|
@ -64,6 +64,7 @@ export enum GeminiEventType {
|
|||
LoopDetected = 'loop_detected',
|
||||
Citation = 'citation',
|
||||
Retry = 'retry',
|
||||
HookSystemMessage = 'hook_system_message',
|
||||
}
|
||||
|
||||
export type ServerGeminiRetryEvent = {
|
||||
|
|
@ -200,6 +201,11 @@ export type ServerGeminiCitationEvent = {
|
|||
value: string;
|
||||
};
|
||||
|
||||
export type ServerGeminiHookSystemMessageEvent = {
|
||||
type: GeminiEventType.HookSystemMessage;
|
||||
value: string;
|
||||
};
|
||||
|
||||
// The original union type, now composed of the individual types
|
||||
export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiChatCompressedEvent
|
||||
|
|
@ -207,6 +213,7 @@ export type ServerGeminiStreamEvent =
|
|||
| ServerGeminiContentEvent
|
||||
| ServerGeminiErrorEvent
|
||||
| ServerGeminiFinishedEvent
|
||||
| ServerGeminiHookSystemMessageEvent
|
||||
| ServerGeminiLoopDetectedEvent
|
||||
| ServerGeminiMaxSessionTurnsEvent
|
||||
| ServerGeminiThoughtEvent
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ export interface Extension {
|
|||
commands?: string[];
|
||||
skills?: SkillConfig[];
|
||||
agents?: SubagentConfig[];
|
||||
hooks?: Record<string, unknown[]>;
|
||||
}
|
||||
|
||||
export interface ExtensionConfig {
|
||||
|
|
|
|||
227
packages/core/src/hooks/hookAggregator.ts
Normal file
227
packages/core/src/hooks/hookAggregator.ts
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
HookEventName,
|
||||
DefaultHookOutput,
|
||||
PreToolUseHookOutput,
|
||||
StopHookOutput,
|
||||
} from './types.js';
|
||||
import type { HookOutput, HookExecutionResult } from './types.js';
|
||||
|
||||
/**
|
||||
* Aggregated result from multiple hook executions
|
||||
*/
|
||||
export interface AggregatedHookResult {
|
||||
success: boolean;
|
||||
allOutputs: HookOutput[];
|
||||
errors: Error[];
|
||||
totalDuration: number;
|
||||
finalOutput?: HookOutput;
|
||||
}
|
||||
|
||||
/**
|
||||
* HookAggregator merges multiple hook outputs using event-specific rules.
|
||||
*
|
||||
* Different events have different merging strategies:
|
||||
* - PreToolUse/PostToolUse: OR logic for decisions, concatenation for messages
|
||||
*/
|
||||
export class HookAggregator {
|
||||
/**
|
||||
* Aggregate results from multiple hook executions
|
||||
*/
|
||||
aggregateResults(
|
||||
results: HookExecutionResult[],
|
||||
eventName: HookEventName,
|
||||
): AggregatedHookResult {
|
||||
const allOutputs: HookOutput[] = [];
|
||||
const errors: Error[] = [];
|
||||
let totalDuration = 0;
|
||||
|
||||
for (const result of results) {
|
||||
totalDuration += result.duration;
|
||||
|
||||
if (!result.success && result.error) {
|
||||
errors.push(result.error);
|
||||
}
|
||||
|
||||
if (result.output) {
|
||||
allOutputs.push(result.output);
|
||||
}
|
||||
}
|
||||
|
||||
const success = errors.length === 0;
|
||||
const finalOutput = this.mergeOutputs(allOutputs, eventName);
|
||||
|
||||
return {
|
||||
success,
|
||||
allOutputs,
|
||||
errors,
|
||||
totalDuration,
|
||||
finalOutput,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge multiple hook outputs based on event type
|
||||
*/
|
||||
private mergeOutputs(
|
||||
outputs: HookOutput[],
|
||||
eventName: HookEventName,
|
||||
): HookOutput | undefined {
|
||||
if (outputs.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (outputs.length === 1) {
|
||||
return this.createSpecificHookOutput(outputs[0], eventName);
|
||||
}
|
||||
|
||||
let merged: HookOutput;
|
||||
|
||||
switch (eventName) {
|
||||
case HookEventName.PreToolUse:
|
||||
case HookEventName.PostToolUse:
|
||||
merged = this.mergeWithOrLogic(outputs);
|
||||
break;
|
||||
|
||||
default:
|
||||
merged = this.mergeSimple(outputs);
|
||||
}
|
||||
|
||||
return this.createSpecificHookOutput(merged, eventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge outputs using OR logic for decisions and concatenation for messages.
|
||||
*
|
||||
* Rules:
|
||||
* - Any "block" or "deny" decision results in blocking (most restrictive wins)
|
||||
* - Reasons are concatenated with newlines
|
||||
* - continue=false takes precedence over continue=true
|
||||
* - Additional context is concatenated
|
||||
*/
|
||||
private mergeWithOrLogic(outputs: HookOutput[]): HookOutput {
|
||||
const merged: HookOutput = {};
|
||||
const reasons: string[] = [];
|
||||
const additionalContexts: string[] = [];
|
||||
let hasBlock = false;
|
||||
let hasContinueFalse = false;
|
||||
let stopReason: string | undefined;
|
||||
|
||||
for (const output of outputs) {
|
||||
// Check for blocking decisions
|
||||
if (output.decision === 'block' || output.decision === 'deny') {
|
||||
hasBlock = true;
|
||||
}
|
||||
|
||||
// Collect reasons
|
||||
if (output.reason) {
|
||||
reasons.push(output.reason);
|
||||
}
|
||||
|
||||
// Check continue flag
|
||||
if (output.continue === false) {
|
||||
hasContinueFalse = true;
|
||||
if (output.stopReason) {
|
||||
stopReason = output.stopReason;
|
||||
}
|
||||
}
|
||||
|
||||
// Extract additional context
|
||||
this.extractAdditionalContext(output, additionalContexts);
|
||||
|
||||
// Copy other fields (later values win for simple fields)
|
||||
if (output.suppressOutput !== undefined) {
|
||||
merged.suppressOutput = output.suppressOutput;
|
||||
}
|
||||
if (output.systemMessage !== undefined) {
|
||||
merged.systemMessage = output.systemMessage;
|
||||
}
|
||||
}
|
||||
|
||||
// Set merged decision
|
||||
if (hasBlock) {
|
||||
merged.decision = 'block';
|
||||
} else if (outputs.some((o) => o.decision === 'allow')) {
|
||||
merged.decision = 'allow';
|
||||
}
|
||||
|
||||
// Set merged reason
|
||||
if (reasons.length > 0) {
|
||||
merged.reason = reasons.join('\n');
|
||||
}
|
||||
|
||||
// Set continue flag
|
||||
if (hasContinueFalse) {
|
||||
merged.continue = false;
|
||||
if (stopReason) {
|
||||
merged.stopReason = stopReason;
|
||||
}
|
||||
}
|
||||
|
||||
// Set additional context if any
|
||||
if (additionalContexts.length > 0) {
|
||||
merged.hookSpecificOutput = {
|
||||
...merged.hookSpecificOutput,
|
||||
additionalContext: additionalContexts.join('\n'),
|
||||
};
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple merge for events without special logic
|
||||
*/
|
||||
private mergeSimple(outputs: HookOutput[]): HookOutput {
|
||||
let merged: HookOutput = {};
|
||||
|
||||
for (const output of outputs) {
|
||||
merged = { ...merged, ...output };
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the appropriate specific hook output class based on event type
|
||||
*/
|
||||
private createSpecificHookOutput(
|
||||
output: HookOutput,
|
||||
eventName: HookEventName,
|
||||
): DefaultHookOutput {
|
||||
switch (eventName) {
|
||||
case HookEventName.PreToolUse:
|
||||
return new PreToolUseHookOutput(output);
|
||||
case HookEventName.Stop:
|
||||
return new StopHookOutput(output);
|
||||
default:
|
||||
return new DefaultHookOutput(output);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract additional context from hook-specific outputs
|
||||
*/
|
||||
private extractAdditionalContext(
|
||||
output: HookOutput,
|
||||
contexts: string[],
|
||||
): void {
|
||||
const specific = output.hookSpecificOutput;
|
||||
if (!specific) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract additionalContext from various hook types
|
||||
if (
|
||||
'additionalContext' in specific &&
|
||||
typeof specific['additionalContext'] === 'string'
|
||||
) {
|
||||
contexts.push(specific['additionalContext']);
|
||||
}
|
||||
}
|
||||
}
|
||||
401
packages/core/src/hooks/hookEventHandler.ts
Normal file
401
packages/core/src/hooks/hookEventHandler.ts
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { HookPlanner, HookEventContext } from './hookPlanner.js';
|
||||
import type { HookRunner } from './hookRunner.js';
|
||||
import type { HookAggregator, AggregatedHookResult } from './hookAggregator.js';
|
||||
import { HookEventName } from './types.js';
|
||||
import type {
|
||||
HookConfig,
|
||||
HookInput,
|
||||
HookExecutionResult,
|
||||
PreToolUseInput,
|
||||
PostToolUseInput,
|
||||
UserPromptSubmitInput,
|
||||
NotificationInput,
|
||||
StopInput,
|
||||
SessionStartInput,
|
||||
SessionEndInput,
|
||||
PreCompactInput,
|
||||
NotificationType,
|
||||
SessionStartSource,
|
||||
SessionEndReason,
|
||||
PreCompactTrigger,
|
||||
McpToolContext,
|
||||
} from './types.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const debugLogger = createDebugLogger('TRUSTED_HOOKS');
|
||||
|
||||
/**
|
||||
* Hook event bus that coordinates hook execution across the system
|
||||
*/
|
||||
export class HookEventHandler {
|
||||
private readonly config: Config;
|
||||
private readonly hookPlanner: HookPlanner;
|
||||
private readonly hookRunner: HookRunner;
|
||||
private readonly hookAggregator: HookAggregator;
|
||||
|
||||
/**
|
||||
* Track reported failures to suppress duplicate warnings during streaming.
|
||||
* Uses a WeakMap with the original request object as a key to ensure
|
||||
* failures are only reported once per logical model interaction.
|
||||
*/
|
||||
private readonly reportedFailures = new WeakMap<object, Set<string>>();
|
||||
|
||||
constructor(
|
||||
config: Config,
|
||||
hookPlanner: HookPlanner,
|
||||
hookRunner: HookRunner,
|
||||
hookAggregator: HookAggregator,
|
||||
) {
|
||||
this.config = config;
|
||||
this.hookPlanner = hookPlanner;
|
||||
this.hookRunner = hookRunner;
|
||||
this.hookAggregator = hookAggregator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a PreToolUse event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async firePreToolUseEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: PreToolUseInput = {
|
||||
...this.createBaseInput(HookEventName.PreToolUse),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
return this.executeHooks(HookEventName.PreToolUse, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a PostToolUse event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async firePostToolUseEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
toolResponse: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: PostToolUseInput = {
|
||||
...this.createBaseInput(HookEventName.PostToolUse),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
return this.executeHooks(HookEventName.PostToolUse, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a UserPromptSubmit event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireUserPromptSubmitEvent(
|
||||
prompt: string,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: UserPromptSubmitInput = {
|
||||
...this.createBaseInput(HookEventName.UserPromptSubmit),
|
||||
prompt,
|
||||
};
|
||||
|
||||
return this.executeHooks(HookEventName.UserPromptSubmit, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a Notification event
|
||||
*/
|
||||
async fireNotificationEvent(
|
||||
type: NotificationType,
|
||||
message: string,
|
||||
details: Record<string, unknown>,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: NotificationInput = {
|
||||
...this.createBaseInput(HookEventName.Notification),
|
||||
notification_type: type,
|
||||
message,
|
||||
details,
|
||||
};
|
||||
|
||||
return this.executeHooks(HookEventName.Notification, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a Stop event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireStopEvent(
|
||||
prompt: string,
|
||||
promptResponse: string,
|
||||
stopHookActive: boolean = false,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: StopInput = {
|
||||
...this.createBaseInput(HookEventName.Stop),
|
||||
prompt,
|
||||
prompt_response: promptResponse,
|
||||
stop_hook_active: stopHookActive,
|
||||
};
|
||||
|
||||
return this.executeHooks(HookEventName.Stop, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a SessionStart event
|
||||
*/
|
||||
async fireSessionStartEvent(
|
||||
source: SessionStartSource,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: SessionStartInput = {
|
||||
...this.createBaseInput(HookEventName.SessionStart),
|
||||
source,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { trigger: source };
|
||||
return this.executeHooks(HookEventName.SessionStart, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a SessionEnd event
|
||||
*/
|
||||
async fireSessionEndEvent(
|
||||
reason: SessionEndReason,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: SessionEndInput = {
|
||||
...this.createBaseInput(HookEventName.SessionEnd),
|
||||
reason,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { trigger: reason };
|
||||
return this.executeHooks(HookEventName.SessionEnd, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a PreCompact event
|
||||
*/
|
||||
async firePreCompactEvent(
|
||||
trigger: PreCompactTrigger,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: PreCompactInput = {
|
||||
...this.createBaseInput(HookEventName.PreCompact),
|
||||
trigger,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { trigger };
|
||||
return this.executeHooks(HookEventName.PreCompact, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute hooks for a specific event (direct execution without MessageBus)
|
||||
* Used as fallback when MessageBus is not available
|
||||
*/
|
||||
private async executeHooks(
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
context?: HookEventContext,
|
||||
requestContext?: object,
|
||||
): Promise<AggregatedHookResult> {
|
||||
try {
|
||||
// Create execution plan
|
||||
const plan = this.hookPlanner.createExecutionPlan(eventName, context);
|
||||
|
||||
if (!plan || plan.hookConfigs.length === 0) {
|
||||
return {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
};
|
||||
}
|
||||
|
||||
const onHookStart = (_config: HookConfig, _index: number) => {
|
||||
// Hook start event (telemetry removed)
|
||||
};
|
||||
|
||||
const onHookEnd = (_config: HookConfig, _result: HookExecutionResult) => {
|
||||
// Hook end event (telemetry removed)
|
||||
};
|
||||
|
||||
// Execute hooks according to the plan's strategy
|
||||
const results = plan.sequential
|
||||
? await this.hookRunner.executeHooksSequential(
|
||||
plan.hookConfigs,
|
||||
eventName,
|
||||
input,
|
||||
onHookStart,
|
||||
onHookEnd,
|
||||
)
|
||||
: await this.hookRunner.executeHooksParallel(
|
||||
plan.hookConfigs,
|
||||
eventName,
|
||||
input,
|
||||
onHookStart,
|
||||
onHookEnd,
|
||||
);
|
||||
|
||||
// Aggregate results
|
||||
const aggregated = this.hookAggregator.aggregateResults(
|
||||
results,
|
||||
eventName,
|
||||
);
|
||||
|
||||
// Process common hook output fields centrally
|
||||
this.processCommonHookOutputFields(aggregated);
|
||||
|
||||
// Log hook execution
|
||||
this.logHookExecution(
|
||||
eventName,
|
||||
input,
|
||||
results,
|
||||
aggregated,
|
||||
requestContext,
|
||||
);
|
||||
|
||||
return aggregated;
|
||||
} catch (error) {
|
||||
debugLogger.error(`Hook event bus error for ${eventName}: ${error}`);
|
||||
|
||||
return {
|
||||
success: false,
|
||||
allOutputs: [],
|
||||
errors: [error instanceof Error ? error : new Error(String(error))],
|
||||
totalDuration: 0,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create base hook input with common fields
|
||||
*/
|
||||
private createBaseInput(eventName: HookEventName): HookInput {
|
||||
// Get the transcript path from the Config
|
||||
const transcriptPath = this.config.getTranscriptPath();
|
||||
|
||||
return {
|
||||
session_id: this.config.getSessionId(),
|
||||
transcript_path: transcriptPath,
|
||||
cwd: this.config.getWorkingDir(),
|
||||
hook_event_name: eventName,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Log hook execution for observability
|
||||
*/
|
||||
private logHookExecution(
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
results: HookExecutionResult[],
|
||||
aggregated: AggregatedHookResult,
|
||||
requestContext?: object,
|
||||
): void {
|
||||
const failedHooks = results.filter((r) => !r.success);
|
||||
const successCount = results.length - failedHooks.length;
|
||||
const errorCount = failedHooks.length;
|
||||
|
||||
if (errorCount > 0) {
|
||||
const failedNames = failedHooks
|
||||
.map((r) => this.getHookNameFromResult(r))
|
||||
.join(', ');
|
||||
|
||||
let shouldEmit = true;
|
||||
if (requestContext) {
|
||||
let reportedSet = this.reportedFailures.get(requestContext);
|
||||
if (!reportedSet) {
|
||||
reportedSet = new Set<string>();
|
||||
this.reportedFailures.set(requestContext, reportedSet);
|
||||
}
|
||||
|
||||
const failureKey = `${eventName}:${failedNames}`;
|
||||
if (reportedSet.has(failureKey)) {
|
||||
shouldEmit = false;
|
||||
} else {
|
||||
reportedSet.add(failureKey);
|
||||
}
|
||||
}
|
||||
|
||||
debugLogger.warn(
|
||||
`Hook execution for ${eventName}: ${successCount} succeeded, ${errorCount} failed (${failedNames}), ` +
|
||||
`total duration: ${aggregated.totalDuration}ms`,
|
||||
);
|
||||
|
||||
if (shouldEmit) {
|
||||
debugLogger.warn(
|
||||
`Hook(s) [${failedNames}] failed for event ${eventName}. Check debug logs for more details.`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
debugLogger.debug(
|
||||
`Hook execution for ${eventName}: ${successCount} hooks executed successfully, ` +
|
||||
`total duration: ${aggregated.totalDuration}ms`,
|
||||
);
|
||||
}
|
||||
|
||||
// Log individual errors
|
||||
for (const error of aggregated.errors) {
|
||||
debugLogger.warn(`Hook execution error: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process common hook output fields centrally
|
||||
*/
|
||||
private processCommonHookOutputFields(
|
||||
aggregated: AggregatedHookResult,
|
||||
): void {
|
||||
if (!aggregated.finalOutput) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle systemMessage - show to user in transcript mode (not to agent)
|
||||
const systemMessage = aggregated.finalOutput.systemMessage;
|
||||
if (systemMessage && !aggregated.finalOutput.suppressOutput) {
|
||||
debugLogger.warn(`Hook system message: ${systemMessage}`);
|
||||
}
|
||||
|
||||
// Handle suppressOutput - already handled by not logging above when true
|
||||
|
||||
// Handle continue=false - this should stop the entire agent execution
|
||||
if (aggregated.finalOutput.continue === false) {
|
||||
const stopReason =
|
||||
aggregated.finalOutput.stopReason ||
|
||||
aggregated.finalOutput.reason ||
|
||||
'No reason provided';
|
||||
debugLogger.debug(`Hook requested to stop execution: ${stopReason}`);
|
||||
|
||||
// Note: The actual stopping of execution must be handled by integration points
|
||||
// as they need to interpret this signal in the context of their specific workflow
|
||||
// This is just logging the request centrally
|
||||
}
|
||||
|
||||
// Other common fields like decision/reason are handled by specific hook output classes
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook name from config for display or telemetry
|
||||
*/
|
||||
private getHookName(config: HookConfig): string {
|
||||
return config.name || config.command || 'unknown-command';
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook name from execution result for telemetry
|
||||
*/
|
||||
private getHookNameFromResult(result: HookExecutionResult): string {
|
||||
return this.getHookName(result.hookConfig);
|
||||
}
|
||||
}
|
||||
140
packages/core/src/hooks/hookPlanner.ts
Normal file
140
packages/core/src/hooks/hookPlanner.ts
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { HookRegistry, HookRegistryEntry } from './hookRegistry.js';
|
||||
import type { HookExecutionPlan } from './types.js';
|
||||
import { getHookKey, type HookEventName } from './types.js';
|
||||
|
||||
/**
|
||||
* Hook planner that selects matching hooks and creates execution plans
|
||||
*/
|
||||
export class HookPlanner {
|
||||
private readonly hookRegistry: HookRegistry;
|
||||
|
||||
constructor(hookRegistry: HookRegistry) {
|
||||
this.hookRegistry = hookRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create execution plan for a hook event
|
||||
*/
|
||||
createExecutionPlan(
|
||||
eventName: HookEventName,
|
||||
context?: HookEventContext,
|
||||
): HookExecutionPlan | null {
|
||||
const hookEntries = this.hookRegistry.getHooksForEvent(eventName);
|
||||
|
||||
if (hookEntries.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Filter hooks by matcher
|
||||
const matchingEntries = hookEntries.filter((entry) =>
|
||||
this.matchesContext(entry, context),
|
||||
);
|
||||
|
||||
if (matchingEntries.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Deduplicate identical hooks
|
||||
const deduplicatedEntries = this.deduplicateHooks(matchingEntries);
|
||||
|
||||
// Extract hook configs
|
||||
const hookConfigs = deduplicatedEntries.map((entry) => entry.config);
|
||||
|
||||
// Determine execution strategy - if ANY hook definition has sequential=true, run all sequentially
|
||||
const sequential = deduplicatedEntries.some(
|
||||
(entry) => entry.sequential === true,
|
||||
);
|
||||
|
||||
const plan: HookExecutionPlan = {
|
||||
eventName,
|
||||
hookConfigs,
|
||||
sequential,
|
||||
};
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a hook entry matches the given context
|
||||
*/
|
||||
private matchesContext(
|
||||
entry: HookRegistryEntry,
|
||||
context?: HookEventContext,
|
||||
): boolean {
|
||||
if (!entry.matcher || !context) {
|
||||
return true; // No matcher means match all
|
||||
}
|
||||
|
||||
const matcher = entry.matcher.trim();
|
||||
|
||||
if (matcher === '' || matcher === '*') {
|
||||
return true; // Empty string or wildcard matches all
|
||||
}
|
||||
|
||||
// For tool events, match against tool name
|
||||
if (context.toolName) {
|
||||
return this.matchesToolName(matcher, context.toolName);
|
||||
}
|
||||
|
||||
// For other events, match against trigger/source
|
||||
if (context.trigger) {
|
||||
return this.matchesTrigger(matcher, context.trigger);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Match tool name against matcher pattern
|
||||
*/
|
||||
private matchesToolName(matcher: string, toolName: string): boolean {
|
||||
try {
|
||||
// Attempt to treat the matcher as a regular expression.
|
||||
const regex = new RegExp(matcher);
|
||||
return regex.test(toolName);
|
||||
} catch {
|
||||
// If it's not a valid regex, treat it as a literal string for an exact match.
|
||||
return matcher === toolName;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Match trigger/source against matcher pattern
|
||||
*/
|
||||
private matchesTrigger(matcher: string, trigger: string): boolean {
|
||||
return matcher === trigger;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicate identical hook configurations
|
||||
*/
|
||||
private deduplicateHooks(entries: HookRegistryEntry[]): HookRegistryEntry[] {
|
||||
const seen = new Set<string>();
|
||||
const deduplicated: HookRegistryEntry[] = [];
|
||||
|
||||
for (const entry of entries) {
|
||||
const key = getHookKey(entry.config);
|
||||
|
||||
if (!seen.has(key)) {
|
||||
seen.add(key);
|
||||
deduplicated.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
return deduplicated;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Context information for hook event matching
|
||||
*/
|
||||
export interface HookEventContext {
|
||||
toolName?: string;
|
||||
trigger?: string;
|
||||
}
|
||||
337
packages/core/src/hooks/hookRegistry.ts
Normal file
337
packages/core/src/hooks/hookRegistry.ts
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { HookDefinition, HookConfig } from './types.js';
|
||||
import {
|
||||
HookEventName,
|
||||
HooksConfigSource,
|
||||
HOOKS_CONFIG_FIELDS,
|
||||
} from './types.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
import { TrustedHooksManager } from './trustedHooks.js';
|
||||
|
||||
const debugLogger = createDebugLogger('HOOK_REGISTRY');
|
||||
|
||||
/**
|
||||
* Extension with hooks support
|
||||
*/
|
||||
export interface ExtensionWithHooks {
|
||||
isActive: boolean;
|
||||
hooks?: { [K in HookEventName]?: HookDefinition[] };
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration interface for HookRegistry
|
||||
* This abstracts the Config dependency to make the registry more flexible
|
||||
*/
|
||||
export interface HookRegistryConfig {
|
||||
getProjectRoot(): string;
|
||||
isTrustedFolder(): boolean;
|
||||
getHooks(): { [K in HookEventName]?: HookDefinition[] } | undefined;
|
||||
getProjectHooks(): { [K in HookEventName]?: HookDefinition[] } | undefined;
|
||||
getDisabledHooks(): string[];
|
||||
getExtensions(): ExtensionWithHooks[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Feedback emitter interface for warning/info messages
|
||||
*/
|
||||
export interface FeedbackEmitter {
|
||||
emitFeedback(type: 'warning' | 'info' | 'error', message: string): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook registry entry with source information
|
||||
*/
|
||||
export interface HookRegistryEntry {
|
||||
config: HookConfig;
|
||||
source: HooksConfigSource;
|
||||
eventName: HookEventName;
|
||||
matcher?: string;
|
||||
sequential?: boolean;
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook registry that loads and validates hook definitions from multiple sources
|
||||
*/
|
||||
export class HookRegistry {
|
||||
private readonly config: HookRegistryConfig;
|
||||
private readonly feedbackEmitter?: FeedbackEmitter;
|
||||
private entries: HookRegistryEntry[] = [];
|
||||
|
||||
constructor(config: HookRegistryConfig, feedbackEmitter?: FeedbackEmitter) {
|
||||
this.config = config;
|
||||
this.feedbackEmitter = feedbackEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the registry by processing hooks from config
|
||||
*/
|
||||
async initialize(): Promise<void> {
|
||||
this.entries = [];
|
||||
this.processHooksFromConfig();
|
||||
|
||||
debugLogger.debug(
|
||||
`Hook registry initialized with ${this.entries.length} hook entries`,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all hook entries for a specific event
|
||||
*/
|
||||
getHooksForEvent(eventName: HookEventName): HookRegistryEntry[] {
|
||||
return this.entries
|
||||
.filter((entry) => entry.eventName === eventName && entry.enabled)
|
||||
.sort(
|
||||
(a, b) =>
|
||||
this.getSourcePriority(a.source) - this.getSourcePriority(b.source),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered hooks
|
||||
*/
|
||||
getAllHooks(): HookRegistryEntry[] {
|
||||
return [...this.entries];
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable or disable a specific hook
|
||||
*/
|
||||
setHookEnabled(hookName: string, enabled: boolean): void {
|
||||
const updated = this.entries.filter((entry) => {
|
||||
const name = this.getHookName(entry);
|
||||
if (name === hookName) {
|
||||
entry.enabled = enabled;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
if (updated.length > 0) {
|
||||
debugLogger.info(
|
||||
`${enabled ? 'Enabled' : 'Disabled'} ${updated.length} hook(s) matching "${hookName}"`,
|
||||
);
|
||||
} else {
|
||||
debugLogger.warn(`No hooks found matching "${hookName}"`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook name for identification and display purposes
|
||||
*/
|
||||
private getHookName(
|
||||
entry: HookRegistryEntry | { config: HookConfig },
|
||||
): string {
|
||||
return entry.config.name || entry.config.command || 'unknown-command';
|
||||
}
|
||||
|
||||
/**
|
||||
* Check for untrusted project hooks and warn the user
|
||||
*/
|
||||
private checkProjectHooksTrust(): void {
|
||||
const projectHooks = this.config.getProjectHooks();
|
||||
if (!projectHooks) return;
|
||||
|
||||
try {
|
||||
const trustedHooksManager = new TrustedHooksManager();
|
||||
const untrusted = trustedHooksManager.getUntrustedHooks(
|
||||
this.config.getProjectRoot(),
|
||||
projectHooks,
|
||||
);
|
||||
|
||||
if (untrusted.length > 0) {
|
||||
const message = `WARNING: The following project-level hooks have been detected in this workspace:
|
||||
${untrusted.map((h: string) => ` - ${h}`).join('\n')}
|
||||
|
||||
These hooks will be executed. If you did not configure these hooks or do not trust this project,
|
||||
please review the project settings (.qwen/settings.json) and remove them.`;
|
||||
this.feedbackEmitter?.emitFeedback('warning', message);
|
||||
|
||||
// Trust them so we don't warn again
|
||||
trustedHooksManager.trustHooks(
|
||||
this.config.getProjectRoot(),
|
||||
projectHooks,
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
debugLogger.warn('Failed to check project hooks trust');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process hooks from the config that was already loaded by the CLI
|
||||
*/
|
||||
private processHooksFromConfig(): void {
|
||||
if (this.config.isTrustedFolder()) {
|
||||
this.checkProjectHooksTrust();
|
||||
}
|
||||
|
||||
// Get hooks from the main config (this comes from the merged settings)
|
||||
const configHooks = this.config.getHooks();
|
||||
if (configHooks) {
|
||||
if (this.config.isTrustedFolder()) {
|
||||
this.processHooksConfiguration(configHooks, HooksConfigSource.Project);
|
||||
} else {
|
||||
debugLogger.warn(
|
||||
'Project hooks disabled because the folder is not trusted.',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Get hooks from extensions
|
||||
const extensions = this.config.getExtensions() || [];
|
||||
for (const extension of extensions) {
|
||||
if (extension.isActive && extension.hooks) {
|
||||
this.processHooksConfiguration(
|
||||
extension.hooks,
|
||||
HooksConfigSource.Extensions,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process hooks configuration and add entries
|
||||
*/
|
||||
private processHooksConfiguration(
|
||||
hooksConfig: { [K in HookEventName]?: HookDefinition[] },
|
||||
source: HooksConfigSource,
|
||||
): void {
|
||||
for (const [eventName, definitions] of Object.entries(hooksConfig)) {
|
||||
if (HOOKS_CONFIG_FIELDS.includes(eventName)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!this.isValidEventName(eventName)) {
|
||||
this.feedbackEmitter?.emitFeedback(
|
||||
'warning',
|
||||
`Invalid hook event name: "${eventName}" from ${source} config. Skipping.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const typedEventName = eventName;
|
||||
|
||||
if (!Array.isArray(definitions)) {
|
||||
debugLogger.warn(
|
||||
`Hook definitions for event "${eventName}" from source "${source}" is not an array. Skipping.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const definition of definitions) {
|
||||
this.processHookDefinition(definition, typedEventName, source);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a single hook definition
|
||||
*/
|
||||
private processHookDefinition(
|
||||
definition: HookDefinition,
|
||||
eventName: HookEventName,
|
||||
source: HooksConfigSource,
|
||||
): void {
|
||||
if (
|
||||
!definition ||
|
||||
typeof definition !== 'object' ||
|
||||
!Array.isArray(definition.hooks)
|
||||
) {
|
||||
debugLogger.warn(
|
||||
`Discarding invalid hook definition for ${eventName} from ${source}:`,
|
||||
definition,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get disabled hooks list from settings
|
||||
const disabledHooks = this.config.getDisabledHooks();
|
||||
|
||||
for (const hookConfig of definition.hooks) {
|
||||
if (
|
||||
hookConfig &&
|
||||
typeof hookConfig === 'object' &&
|
||||
this.validateHookConfig(hookConfig, eventName, source)
|
||||
) {
|
||||
// Check if this hook is in the disabled list
|
||||
const hookName = this.getHookName({ config: hookConfig });
|
||||
const isDisabled = disabledHooks.includes(hookName);
|
||||
|
||||
// Add source to hook config
|
||||
hookConfig.source = source;
|
||||
|
||||
this.entries.push({
|
||||
config: hookConfig,
|
||||
source,
|
||||
eventName,
|
||||
matcher: definition.matcher,
|
||||
sequential: definition.sequential,
|
||||
enabled: !isDisabled,
|
||||
});
|
||||
} else {
|
||||
// Invalid hooks are logged and discarded here, they won't reach HookRunner
|
||||
debugLogger.warn(
|
||||
`Discarding invalid hook configuration for ${eventName} from ${source}:`,
|
||||
hookConfig,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a hook configuration
|
||||
*/
|
||||
private validateHookConfig(
|
||||
config: HookConfig,
|
||||
eventName: HookEventName,
|
||||
source: HooksConfigSource,
|
||||
): boolean {
|
||||
if (!config.type || !['command', 'plugin'].includes(config.type)) {
|
||||
debugLogger.warn(
|
||||
`Invalid hook ${eventName} from ${source} type: ${config.type}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (config.type === 'command' && !config.command) {
|
||||
debugLogger.warn(
|
||||
`Command hook ${eventName} from ${source} missing command field`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an event name is valid
|
||||
*/
|
||||
private isValidEventName(eventName: string): eventName is HookEventName {
|
||||
const validEventNames: string[] = Object.values(HookEventName);
|
||||
return validEventNames.includes(eventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get source priority (lower number = higher priority)
|
||||
*/
|
||||
private getSourcePriority(source: HooksConfigSource): number {
|
||||
switch (source) {
|
||||
case HooksConfigSource.Project:
|
||||
return 1;
|
||||
case HooksConfigSource.User:
|
||||
return 2;
|
||||
case HooksConfigSource.System:
|
||||
return 3;
|
||||
case HooksConfigSource.Extensions:
|
||||
return 4;
|
||||
default:
|
||||
return 999;
|
||||
}
|
||||
}
|
||||
}
|
||||
451
packages/core/src/hooks/hookRunner.ts
Normal file
451
packages/core/src/hooks/hookRunner.ts
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { spawn } from 'node:child_process';
|
||||
import { HookEventName, HooksConfigSource } from './types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
HookConfig,
|
||||
HookInput,
|
||||
HookOutput,
|
||||
HookExecutionResult,
|
||||
PreToolUseInput,
|
||||
UserPromptSubmitInput,
|
||||
} from './types.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
import {
|
||||
escapeShellArg,
|
||||
getShellConfiguration,
|
||||
type ShellType,
|
||||
} from '../utils/shell-utils.js';
|
||||
|
||||
const debugLogger = createDebugLogger('TRUSTED_HOOKS');
|
||||
|
||||
/**
|
||||
* Default timeout for hook execution (60 seconds)
|
||||
*/
|
||||
const DEFAULT_HOOK_TIMEOUT = 60000;
|
||||
|
||||
/**
|
||||
* Exit code constants for hook execution
|
||||
*/
|
||||
const EXIT_CODE_SUCCESS = 0;
|
||||
const EXIT_CODE_NON_BLOCKING_ERROR = 1;
|
||||
|
||||
/**
|
||||
* Hook runner that executes command hooks
|
||||
*/
|
||||
export class HookRunner {
|
||||
private readonly config: Config;
|
||||
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a single hook
|
||||
*/
|
||||
async executeHook(
|
||||
hookConfig: HookConfig,
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
): Promise<HookExecutionResult> {
|
||||
const startTime = Date.now();
|
||||
|
||||
// Secondary security check: Ensure project hooks are not executed in untrusted folders
|
||||
if (
|
||||
hookConfig.source === HooksConfigSource.Project &&
|
||||
!this.config.isTrustedFolder()
|
||||
) {
|
||||
const errorMessage =
|
||||
'Security: Blocked execution of project hook in untrusted folder';
|
||||
debugLogger.warn(errorMessage);
|
||||
return {
|
||||
hookConfig,
|
||||
eventName,
|
||||
success: false,
|
||||
error: new Error(errorMessage),
|
||||
duration: 0,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.executeCommandHook(
|
||||
hookConfig,
|
||||
eventName,
|
||||
input,
|
||||
startTime,
|
||||
);
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime;
|
||||
const hookId = hookConfig.name || hookConfig.command || 'unknown';
|
||||
const errorMessage = `Hook execution failed for event '${eventName}' (hook: ${hookId}): ${error}`;
|
||||
debugLogger.warn(`Hook execution error (non-fatal): ${errorMessage}`);
|
||||
|
||||
return {
|
||||
hookConfig,
|
||||
eventName,
|
||||
success: false,
|
||||
error: error instanceof Error ? error : new Error(errorMessage),
|
||||
duration,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute multiple hooks in parallel
|
||||
*/
|
||||
async executeHooksParallel(
|
||||
hookConfigs: HookConfig[],
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
onHookStart?: (config: HookConfig, index: number) => void,
|
||||
onHookEnd?: (config: HookConfig, result: HookExecutionResult) => void,
|
||||
): Promise<HookExecutionResult[]> {
|
||||
const promises = hookConfigs.map(async (config, index) => {
|
||||
onHookStart?.(config, index);
|
||||
const result = await this.executeHook(config, eventName, input);
|
||||
onHookEnd?.(config, result);
|
||||
return result;
|
||||
});
|
||||
|
||||
return Promise.all(promises);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute multiple hooks sequentially
|
||||
*/
|
||||
async executeHooksSequential(
|
||||
hookConfigs: HookConfig[],
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
onHookStart?: (config: HookConfig, index: number) => void,
|
||||
onHookEnd?: (config: HookConfig, result: HookExecutionResult) => void,
|
||||
): Promise<HookExecutionResult[]> {
|
||||
const results: HookExecutionResult[] = [];
|
||||
let currentInput = input;
|
||||
|
||||
for (let i = 0; i < hookConfigs.length; i++) {
|
||||
const config = hookConfigs[i];
|
||||
onHookStart?.(config, i);
|
||||
const result = await this.executeHook(config, eventName, currentInput);
|
||||
onHookEnd?.(config, result);
|
||||
results.push(result);
|
||||
|
||||
// If the hook succeeded and has output, use it to modify the input for the next hook
|
||||
if (result.success && result.output) {
|
||||
currentInput = this.applyHookOutputToInput(
|
||||
currentInput,
|
||||
result.output,
|
||||
eventName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply hook output to modify input for the next hook in sequential execution
|
||||
*/
|
||||
private applyHookOutputToInput(
|
||||
originalInput: HookInput,
|
||||
hookOutput: HookOutput,
|
||||
eventName: HookEventName,
|
||||
): HookInput {
|
||||
// Create a copy of the original input
|
||||
const modifiedInput = { ...originalInput };
|
||||
|
||||
// Apply modifications based on hook output and event type
|
||||
if (hookOutput.hookSpecificOutput) {
|
||||
switch (eventName) {
|
||||
case HookEventName.UserPromptSubmit:
|
||||
if ('additionalContext' in hookOutput.hookSpecificOutput) {
|
||||
// For UserPromptSubmit, we could modify the prompt with additional context
|
||||
const additionalContext =
|
||||
hookOutput.hookSpecificOutput['additionalContext'];
|
||||
if (
|
||||
typeof additionalContext === 'string' &&
|
||||
'prompt' in modifiedInput
|
||||
) {
|
||||
(modifiedInput as UserPromptSubmitInput).prompt +=
|
||||
'\n\n' + additionalContext;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case HookEventName.PreToolUse:
|
||||
if ('tool_input' in hookOutput.hookSpecificOutput) {
|
||||
const newToolInput = hookOutput.hookSpecificOutput[
|
||||
'tool_input'
|
||||
] as Record<string, unknown>;
|
||||
if (newToolInput && 'tool_input' in modifiedInput) {
|
||||
(modifiedInput as PreToolUseInput).tool_input = {
|
||||
...(modifiedInput as PreToolUseInput).tool_input,
|
||||
...newToolInput,
|
||||
};
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
// For other events, no special input modification is needed
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return modifiedInput;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a command hook
|
||||
*/
|
||||
private async executeCommandHook(
|
||||
hookConfig: HookConfig,
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
startTime: number,
|
||||
): Promise<HookExecutionResult> {
|
||||
const timeout = hookConfig.timeout ?? DEFAULT_HOOK_TIMEOUT;
|
||||
|
||||
return new Promise((resolve) => {
|
||||
if (!hookConfig.command) {
|
||||
const errorMessage = 'Command hook missing command';
|
||||
debugLogger.warn(
|
||||
`Hook configuration error (non-fatal): ${errorMessage}`,
|
||||
);
|
||||
resolve({
|
||||
hookConfig,
|
||||
eventName,
|
||||
success: false,
|
||||
error: new Error(errorMessage),
|
||||
duration: Date.now() - startTime,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let stdout = '';
|
||||
let stderr = '';
|
||||
let timedOut = false;
|
||||
|
||||
const shellConfig = getShellConfiguration();
|
||||
const command = this.expandCommand(
|
||||
hookConfig.command,
|
||||
input,
|
||||
shellConfig.shell,
|
||||
);
|
||||
|
||||
// Set up environment variables
|
||||
// Extract hook-specific fields from input to expose as environment variables
|
||||
const hookEnvVars: Record<string, string> = {};
|
||||
if ('prompt' in input && typeof input.prompt === 'string') {
|
||||
hookEnvVars['PROMPT'] = input.prompt;
|
||||
}
|
||||
if (
|
||||
'prompt_response' in input &&
|
||||
typeof input.prompt_response === 'string'
|
||||
) {
|
||||
hookEnvVars['PROMPT_RESPONSE'] = input.prompt_response;
|
||||
}
|
||||
if ('tool_name' in input && typeof input.tool_name === 'string') {
|
||||
hookEnvVars['TOOL_NAME'] = input.tool_name;
|
||||
}
|
||||
if ('session_id' in input && typeof input.session_id === 'string') {
|
||||
hookEnvVars['SESSION_ID'] = input.session_id;
|
||||
}
|
||||
if (
|
||||
'transcript_path' in input &&
|
||||
typeof input.transcript_path === 'string'
|
||||
) {
|
||||
hookEnvVars['TRANSCRIPT_PATH'] = input.transcript_path;
|
||||
}
|
||||
if (
|
||||
'stop_hook_active' in input &&
|
||||
typeof input.stop_hook_active === 'boolean'
|
||||
) {
|
||||
hookEnvVars['STOP_HOOK_ACTIVE'] = input.stop_hook_active
|
||||
? 'true'
|
||||
: 'false';
|
||||
}
|
||||
|
||||
const env = {
|
||||
...process.env,
|
||||
GEMINI_PROJECT_DIR: input.cwd,
|
||||
CLAUDE_PROJECT_DIR: input.cwd, // For compatibility
|
||||
QWEN_PROJECT_DIR: input.cwd, // For Qwen Code compatibility
|
||||
...hookEnvVars,
|
||||
...hookConfig.env,
|
||||
};
|
||||
|
||||
const child = spawn(
|
||||
shellConfig.executable,
|
||||
[...shellConfig.argsPrefix, command],
|
||||
{
|
||||
env,
|
||||
cwd: input.cwd,
|
||||
stdio: ['pipe', 'pipe', 'pipe'],
|
||||
shell: false,
|
||||
},
|
||||
);
|
||||
|
||||
// Set up timeout
|
||||
const timeoutHandle = setTimeout(() => {
|
||||
timedOut = true;
|
||||
child.kill('SIGTERM');
|
||||
|
||||
// Force kill after 5 seconds
|
||||
setTimeout(() => {
|
||||
if (!child.killed) {
|
||||
child.kill('SIGKILL');
|
||||
}
|
||||
}, 5000);
|
||||
}, timeout);
|
||||
|
||||
// Send input to stdin
|
||||
if (child.stdin) {
|
||||
child.stdin.on('error', (err: NodeJS.ErrnoException) => {
|
||||
// Ignore EPIPE errors which happen when the child process closes stdin early
|
||||
if (err.code !== 'EPIPE') {
|
||||
debugLogger.debug(`Hook stdin error: ${err}`);
|
||||
}
|
||||
});
|
||||
|
||||
// Wrap write operations in try-catch to handle synchronous EPIPE errors
|
||||
// that occur when the child process exits before we finish writing
|
||||
try {
|
||||
child.stdin.write(JSON.stringify(input));
|
||||
child.stdin.end();
|
||||
} catch (err) {
|
||||
// Ignore EPIPE errors which happen when the child process closes stdin early
|
||||
if (err instanceof Error && 'code' in err && err.code !== 'EPIPE') {
|
||||
debugLogger.debug(`Hook stdin write error: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect stdout
|
||||
child.stdout?.on('data', (data: Buffer) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
// Collect stderr
|
||||
child.stderr?.on('data', (data: Buffer) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
// Handle process exit
|
||||
child.on('close', (exitCode) => {
|
||||
clearTimeout(timeoutHandle);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
if (timedOut) {
|
||||
resolve({
|
||||
hookConfig,
|
||||
eventName,
|
||||
success: false,
|
||||
error: new Error(`Hook timed out after ${timeout}ms`),
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse output
|
||||
let output: HookOutput | undefined;
|
||||
|
||||
const textToParse = stdout.trim() || stderr.trim();
|
||||
if (textToParse) {
|
||||
try {
|
||||
let parsed = JSON.parse(textToParse);
|
||||
if (typeof parsed === 'string') {
|
||||
parsed = JSON.parse(parsed);
|
||||
}
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
output = parsed as HookOutput;
|
||||
}
|
||||
} catch {
|
||||
// Not JSON, convert plain text to structured output
|
||||
output = this.convertPlainTextToHookOutput(
|
||||
textToParse,
|
||||
exitCode || EXIT_CODE_SUCCESS,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
resolve({
|
||||
hookConfig,
|
||||
eventName,
|
||||
success: exitCode === EXIT_CODE_SUCCESS,
|
||||
output,
|
||||
stdout,
|
||||
stderr,
|
||||
exitCode: exitCode || EXIT_CODE_SUCCESS,
|
||||
duration,
|
||||
});
|
||||
});
|
||||
|
||||
// Handle process errors
|
||||
child.on('error', (error) => {
|
||||
clearTimeout(timeoutHandle);
|
||||
const duration = Date.now() - startTime;
|
||||
|
||||
resolve({
|
||||
hookConfig,
|
||||
eventName,
|
||||
success: false,
|
||||
error,
|
||||
stdout,
|
||||
stderr,
|
||||
duration,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand command with environment variables and input context
|
||||
*/
|
||||
private expandCommand(
|
||||
command: string,
|
||||
input: HookInput,
|
||||
shellType: ShellType,
|
||||
): string {
|
||||
debugLogger.debug(`Expanding hook command: ${command} (cwd: ${input.cwd})`);
|
||||
const escapedCwd = escapeShellArg(input.cwd, shellType);
|
||||
return command
|
||||
.replace(/\$GEMINI_PROJECT_DIR/g, () => escapedCwd)
|
||||
.replace(/\$CLAUDE_PROJECT_DIR/g, () => escapedCwd); // For compatibility
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert plain text output to structured HookOutput
|
||||
*/
|
||||
private convertPlainTextToHookOutput(
|
||||
text: string,
|
||||
exitCode: number,
|
||||
): HookOutput {
|
||||
if (exitCode === EXIT_CODE_SUCCESS) {
|
||||
// Success - treat as system message or additional context
|
||||
return {
|
||||
decision: 'allow',
|
||||
systemMessage: text,
|
||||
};
|
||||
} else if (exitCode === EXIT_CODE_NON_BLOCKING_ERROR) {
|
||||
// Non-blocking error (EXIT_CODE_NON_BLOCKING_ERROR = 1)
|
||||
return {
|
||||
decision: 'allow',
|
||||
systemMessage: `Warning: ${text}`,
|
||||
};
|
||||
} else {
|
||||
// All other non-zero exit codes (including 2) are blocking
|
||||
return {
|
||||
decision: 'deny',
|
||||
reason: text,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
270
packages/core/src/hooks/hookSystem.ts
Normal file
270
packages/core/src/hooks/hookSystem.ts
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { HookRegistry } from './hookRegistry.js';
|
||||
import { HookRunner } from './hookRunner.js';
|
||||
import { HookAggregator } from './hookAggregator.js';
|
||||
import { HookPlanner } from './hookPlanner.js';
|
||||
import { HookEventHandler } from './hookEventHandler.js';
|
||||
import type { HookRegistryEntry } from './hookRegistry.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
import type {
|
||||
SessionStartSource,
|
||||
SessionEndReason,
|
||||
PreCompactTrigger,
|
||||
DefaultHookOutput,
|
||||
McpToolContext,
|
||||
} from './types.js';
|
||||
import { NotificationType, createHookOutput } from './types.js';
|
||||
import type { AggregatedHookResult } from './hookAggregator.js';
|
||||
import type { ToolCallConfirmationDetails } from '../tools/tools.js';
|
||||
|
||||
const debugLogger = createDebugLogger('TRUSTED_HOOKS');
|
||||
|
||||
/**
|
||||
* Main hook system that coordinates all hook-related functionality
|
||||
*/
|
||||
|
||||
/**
|
||||
* Converts ToolCallConfirmationDetails to a serializable format for hooks.
|
||||
* Excludes function properties (onConfirm, ideConfirmation) that can't be serialized.
|
||||
*/
|
||||
function toSerializableDetails(
|
||||
details: ToolCallConfirmationDetails,
|
||||
): Record<string, unknown> {
|
||||
const base: Record<string, unknown> = {
|
||||
type: details.type,
|
||||
title: details.title,
|
||||
};
|
||||
|
||||
switch (details.type) {
|
||||
case 'edit':
|
||||
return {
|
||||
...base,
|
||||
fileName: details.fileName,
|
||||
filePath: details.filePath,
|
||||
fileDiff: details.fileDiff,
|
||||
originalContent: details.originalContent,
|
||||
newContent: details.newContent,
|
||||
isModifying: details.isModifying,
|
||||
};
|
||||
case 'exec':
|
||||
return {
|
||||
...base,
|
||||
command: details.command,
|
||||
rootCommand: details.rootCommand,
|
||||
};
|
||||
case 'mcp':
|
||||
return {
|
||||
...base,
|
||||
serverName: details.serverName,
|
||||
toolName: details.toolName,
|
||||
toolDisplayName: details.toolDisplayName,
|
||||
};
|
||||
case 'info':
|
||||
return {
|
||||
...base,
|
||||
prompt: details.prompt,
|
||||
urls: details.urls,
|
||||
};
|
||||
default:
|
||||
return base;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the message to display in the notification hook for tool confirmation.
|
||||
*/
|
||||
function getNotificationMessage(
|
||||
confirmationDetails: ToolCallConfirmationDetails,
|
||||
): string {
|
||||
switch (confirmationDetails.type) {
|
||||
case 'edit':
|
||||
return `Tool ${confirmationDetails.title} requires editing`;
|
||||
case 'exec':
|
||||
return `Tool ${confirmationDetails.title} requires execution`;
|
||||
case 'mcp':
|
||||
return `Tool ${confirmationDetails.title} requires MCP`;
|
||||
case 'info':
|
||||
return `Tool ${confirmationDetails.title} requires information`;
|
||||
default:
|
||||
return `Tool requires confirmation`;
|
||||
}
|
||||
}
|
||||
|
||||
export class HookSystem {
|
||||
private readonly hookRegistry: HookRegistry;
|
||||
private readonly hookRunner: HookRunner;
|
||||
private readonly hookAggregator: HookAggregator;
|
||||
private readonly hookPlanner: HookPlanner;
|
||||
private readonly hookEventHandler: HookEventHandler;
|
||||
|
||||
constructor(config: Config) {
|
||||
// Initialize components
|
||||
this.hookRegistry = new HookRegistry(config);
|
||||
this.hookRunner = new HookRunner(config);
|
||||
this.hookAggregator = new HookAggregator();
|
||||
this.hookPlanner = new HookPlanner(this.hookRegistry);
|
||||
this.hookEventHandler = new HookEventHandler(
|
||||
config,
|
||||
this.hookPlanner,
|
||||
this.hookRunner,
|
||||
this.hookAggregator,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the hook system
|
||||
*/
|
||||
async initialize(): Promise<void> {
|
||||
await this.hookRegistry.initialize();
|
||||
debugLogger.debug('Hook system initialized successfully');
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the hook event bus for firing events
|
||||
*/
|
||||
getEventHandler(): HookEventHandler {
|
||||
return this.hookEventHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook registry for management operations
|
||||
*/
|
||||
getRegistry(): HookRegistry {
|
||||
return this.hookRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable or disable a hook
|
||||
*/
|
||||
setHookEnabled(hookName: string, enabled: boolean): void {
|
||||
this.hookRegistry.setHookEnabled(hookName, enabled);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered hooks for display/management
|
||||
*/
|
||||
getAllHooks(): HookRegistryEntry[] {
|
||||
return this.hookRegistry.getAllHooks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire hook events directly
|
||||
*/
|
||||
async fireSessionStartEvent(
|
||||
source: SessionStartSource,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
const result = await this.hookEventHandler.fireSessionStartEvent(source);
|
||||
return result.finalOutput
|
||||
? createHookOutput('SessionStart', result.finalOutput)
|
||||
: undefined;
|
||||
}
|
||||
|
||||
async fireSessionEndEvent(
|
||||
reason: SessionEndReason,
|
||||
): Promise<AggregatedHookResult | undefined> {
|
||||
return this.hookEventHandler.fireSessionEndEvent(reason);
|
||||
}
|
||||
|
||||
async firePreCompactEvent(
|
||||
trigger: PreCompactTrigger,
|
||||
): Promise<AggregatedHookResult | undefined> {
|
||||
return this.hookEventHandler.firePreCompactEvent(trigger);
|
||||
}
|
||||
|
||||
async fireUserPromptSubmitEvent(
|
||||
prompt: string,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
const result =
|
||||
await this.hookEventHandler.fireUserPromptSubmitEvent(prompt);
|
||||
return result.finalOutput
|
||||
? createHookOutput('UserPromptSubmit', result.finalOutput)
|
||||
: undefined;
|
||||
}
|
||||
|
||||
async fireStopEvent(
|
||||
prompt: string,
|
||||
response: string,
|
||||
stopHookActive: boolean = false,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
const result = await this.hookEventHandler.fireStopEvent(
|
||||
prompt,
|
||||
response,
|
||||
stopHookActive,
|
||||
);
|
||||
return result.finalOutput
|
||||
? createHookOutput('Stop', result.finalOutput)
|
||||
: undefined;
|
||||
}
|
||||
|
||||
async firePreToolUseEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const result = await this.hookEventHandler.firePreToolUseEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
);
|
||||
return result.finalOutput
|
||||
? createHookOutput('PreToolUse', result.finalOutput)
|
||||
: undefined;
|
||||
} catch (error) {
|
||||
debugLogger.debug(`PreToolUseEvent failed for ${toolName}:`, error);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
async firePostToolUseEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
toolResponse: {
|
||||
llmContent: unknown;
|
||||
returnDisplay: unknown;
|
||||
error: unknown;
|
||||
},
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const result = await this.hookEventHandler.firePostToolUseEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
toolResponse as Record<string, unknown>,
|
||||
mcpContext,
|
||||
);
|
||||
return result.finalOutput
|
||||
? createHookOutput('PostToolUse', result.finalOutput)
|
||||
: undefined;
|
||||
} catch (error) {
|
||||
debugLogger.debug(`PostToolUseEvent failed for ${toolName}:`, error);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
async fireToolNotificationEvent(
|
||||
confirmationDetails: ToolCallConfirmationDetails,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const message = getNotificationMessage(confirmationDetails);
|
||||
const serializedDetails = toSerializableDetails(confirmationDetails);
|
||||
|
||||
await this.hookEventHandler.fireNotificationEvent(
|
||||
NotificationType.ToolPermission,
|
||||
message,
|
||||
serializedDetails,
|
||||
);
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
`NotificationEvent failed for ${confirmationDetails.title}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
22
packages/core/src/hooks/index.ts
Normal file
22
packages/core/src/hooks/index.ts
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
// Export types
|
||||
export * from './types.js';
|
||||
|
||||
// Export core components
|
||||
export { HookSystem } from './hookSystem.js';
|
||||
export { HookRegistry } from './hookRegistry.js';
|
||||
export { HookRunner } from './hookRunner.js';
|
||||
export { HookAggregator } from './hookAggregator.js';
|
||||
export { HookPlanner } from './hookPlanner.js';
|
||||
export { HookEventHandler } from './hookEventHandler.js';
|
||||
|
||||
// Export interfaces and enums
|
||||
export type { HookRegistryEntry } from './hookRegistry.js';
|
||||
export { HooksConfigSource as ConfigSource } from './types.js';
|
||||
export type { AggregatedHookResult } from './hookAggregator.js';
|
||||
export type { HookEventContext } from './hookPlanner.js';
|
||||
118
packages/core/src/hooks/trustedHooks.ts
Normal file
118
packages/core/src/hooks/trustedHooks.ts
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import {
|
||||
getHookKey,
|
||||
type HookDefinition,
|
||||
type HookEventName,
|
||||
} from './types.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const debugLogger = createDebugLogger('TRUSTED_HOOKS');
|
||||
|
||||
interface TrustedHooksConfig {
|
||||
[projectPath: string]: string[]; // Array of trusted hook keys (name:command)
|
||||
}
|
||||
|
||||
export class TrustedHooksManager {
|
||||
private configPath: string;
|
||||
private trustedHooks: TrustedHooksConfig = {};
|
||||
|
||||
constructor() {
|
||||
this.configPath = path.join(
|
||||
Storage.getGlobalQwenDir(),
|
||||
'trusted_hooks.json',
|
||||
);
|
||||
this.load();
|
||||
}
|
||||
|
||||
private load(): void {
|
||||
try {
|
||||
if (fs.existsSync(this.configPath)) {
|
||||
const content = fs.readFileSync(this.configPath, 'utf-8');
|
||||
this.trustedHooks = JSON.parse(content);
|
||||
}
|
||||
} catch (error) {
|
||||
debugLogger.warn('Failed to load trusted hooks config', error);
|
||||
this.trustedHooks = {};
|
||||
}
|
||||
}
|
||||
|
||||
private save(): void {
|
||||
try {
|
||||
const dir = path.dirname(this.configPath);
|
||||
if (!fs.existsSync(dir)) {
|
||||
fs.mkdirSync(dir, { recursive: true });
|
||||
}
|
||||
fs.writeFileSync(
|
||||
this.configPath,
|
||||
JSON.stringify(this.trustedHooks, null, 2),
|
||||
);
|
||||
} catch (error) {
|
||||
debugLogger.warn('Failed to save trusted hooks config', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get untrusted hooks for a project
|
||||
* @param projectPath Absolute path to the project root
|
||||
* @param hooks The hooks configuration to check
|
||||
* @returns List of untrusted hook commands/names
|
||||
*/
|
||||
getUntrustedHooks(
|
||||
projectPath: string,
|
||||
hooks: { [K in HookEventName]?: HookDefinition[] },
|
||||
): string[] {
|
||||
const trustedKeys = new Set(this.trustedHooks[projectPath] || []);
|
||||
const untrusted: string[] = [];
|
||||
|
||||
for (const eventName of Object.keys(hooks)) {
|
||||
const definitions = hooks[eventName as HookEventName];
|
||||
if (!Array.isArray(definitions)) continue;
|
||||
|
||||
for (const def of definitions) {
|
||||
if (!def || !Array.isArray(def.hooks)) continue;
|
||||
for (const hook of def.hooks) {
|
||||
const key = getHookKey(hook);
|
||||
if (!trustedKeys.has(key)) {
|
||||
// Return friendly name or command
|
||||
untrusted.push(hook.name || hook.command || 'unknown-hook');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(new Set(untrusted)); // Deduplicate
|
||||
}
|
||||
|
||||
/**
|
||||
* Trust all provided hooks for a project
|
||||
*/
|
||||
trustHooks(
|
||||
projectPath: string,
|
||||
hooks: { [K in HookEventName]?: HookDefinition[] },
|
||||
): void {
|
||||
const currentTrusted = new Set(this.trustedHooks[projectPath] || []);
|
||||
|
||||
for (const eventName of Object.keys(hooks)) {
|
||||
const definitions = hooks[eventName as HookEventName];
|
||||
if (!Array.isArray(definitions)) continue;
|
||||
|
||||
for (const def of definitions) {
|
||||
if (!def || !Array.isArray(def.hooks)) continue;
|
||||
for (const hook of def.hooks) {
|
||||
currentTrusted.add(getHookKey(hook));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.trustedHooks[projectPath] = Array.from(currentTrusted);
|
||||
this.save();
|
||||
}
|
||||
}
|
||||
461
packages/core/src/hooks/types.ts
Normal file
461
packages/core/src/hooks/types.ts
Normal file
|
|
@ -0,0 +1,461 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
ToolConfig as GenAIToolConfig,
|
||||
ToolListUnion,
|
||||
} from '@google/genai';
|
||||
export enum HooksConfigSource {
|
||||
Project = 'project',
|
||||
User = 'user',
|
||||
System = 'system',
|
||||
Extensions = 'extensions',
|
||||
}
|
||||
|
||||
/**
|
||||
* Event names for the hook system
|
||||
*/
|
||||
export enum HookEventName {
|
||||
PreToolUse = 'PreToolUse',
|
||||
PostToolUse = 'PostToolUse',
|
||||
UserPromptSubmit = 'UserPromptSubmit',
|
||||
Notification = 'Notification',
|
||||
Stop = 'Stop',
|
||||
SessionStart = 'SessionStart',
|
||||
SessionEnd = 'SessionEnd',
|
||||
PreCompact = 'PreCompact',
|
||||
SubagentStop = 'SubagentStop',
|
||||
PermissionRequest = 'PermissionRequest',
|
||||
}
|
||||
|
||||
/**
|
||||
* Fields in the hooks configuration that are not hook event names
|
||||
*/
|
||||
export const HOOKS_CONFIG_FIELDS = ['enabled', 'disabled', 'notifications'];
|
||||
|
||||
/**
|
||||
* Hook configuration entry
|
||||
*/
|
||||
export interface CommandHookConfig {
|
||||
type: HookType.Command;
|
||||
command: string;
|
||||
name?: string;
|
||||
description?: string;
|
||||
timeout?: number;
|
||||
source?: HooksConfigSource;
|
||||
env?: Record<string, string>;
|
||||
}
|
||||
|
||||
export type HookConfig = CommandHookConfig;
|
||||
|
||||
/**
|
||||
* Hook definition with matcher
|
||||
*/
|
||||
export interface HookDefinition {
|
||||
matcher?: string;
|
||||
sequential?: boolean;
|
||||
hooks: HookConfig[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook implementation types
|
||||
*/
|
||||
export enum HookType {
|
||||
Command = 'command',
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a unique key for a hook configuration
|
||||
*/
|
||||
export function getHookKey(hook: HookConfig): string {
|
||||
const name = hook.name || '';
|
||||
const command = hook.command || '';
|
||||
return `${name}:${command}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decision types for hook outputs
|
||||
*/
|
||||
export type HookDecision =
|
||||
| 'ask'
|
||||
| 'block'
|
||||
| 'deny'
|
||||
| 'approve'
|
||||
| 'allow'
|
||||
| undefined;
|
||||
|
||||
/**
|
||||
* Base hook input - common fields for all events
|
||||
*/
|
||||
export interface HookInput {
|
||||
session_id: string;
|
||||
transcript_path: string;
|
||||
cwd: string;
|
||||
hook_event_name: string;
|
||||
timestamp: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Base hook output - common fields for all events
|
||||
*/
|
||||
export interface HookOutput {
|
||||
continue?: boolean;
|
||||
stopReason?: string;
|
||||
suppressOutput?: boolean;
|
||||
systemMessage?: string;
|
||||
decision?: HookDecision;
|
||||
reason?: string;
|
||||
hookSpecificOutput?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory function to create the appropriate hook output class based on event name
|
||||
* Returns DefaultHookOutput for all events since it contains all necessary methods
|
||||
*/
|
||||
export function createHookOutput(
|
||||
eventName: string,
|
||||
data: Partial<HookOutput>,
|
||||
): DefaultHookOutput {
|
||||
switch (eventName) {
|
||||
case 'PreToolUse':
|
||||
return new PreToolUseHookOutput(data);
|
||||
case 'Stop':
|
||||
return new StopHookOutput(data);
|
||||
default:
|
||||
return new DefaultHookOutput(data);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Default implementation of HookOutput with utility methods
|
||||
*/
|
||||
export class DefaultHookOutput implements HookOutput {
|
||||
continue?: boolean;
|
||||
stopReason?: string;
|
||||
suppressOutput?: boolean;
|
||||
systemMessage?: string;
|
||||
decision?: HookDecision;
|
||||
reason?: string;
|
||||
hookSpecificOutput?: Record<string, unknown>;
|
||||
|
||||
constructor(data: Partial<HookOutput> = {}) {
|
||||
this.continue = data.continue;
|
||||
this.stopReason = data.stopReason;
|
||||
this.suppressOutput = data.suppressOutput;
|
||||
this.systemMessage = data.systemMessage;
|
||||
this.decision = data.decision;
|
||||
this.reason = data.reason;
|
||||
this.hookSpecificOutput = data.hookSpecificOutput;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this output represents a blocking decision
|
||||
*/
|
||||
isBlockingDecision(): boolean {
|
||||
return this.decision === 'block' || this.decision === 'deny';
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this output requests to stop execution
|
||||
*/
|
||||
shouldStopExecution(): boolean {
|
||||
return this.continue === false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the effective reason for blocking or stopping
|
||||
*/
|
||||
getEffectiveReason(): string {
|
||||
return this.stopReason || this.reason || 'No reason provided';
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply tool config modifications (specific method for BeforeToolSelection hooks)
|
||||
*/
|
||||
applyToolConfigModifications(target: {
|
||||
toolConfig?: GenAIToolConfig;
|
||||
tools?: ToolListUnion;
|
||||
}): {
|
||||
toolConfig?: GenAIToolConfig;
|
||||
tools?: ToolListUnion;
|
||||
} {
|
||||
// Base implementation - overridden by BeforeToolSelectionHookOutput
|
||||
return target;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get sanitized additional context for adding to responses.
|
||||
*/
|
||||
getAdditionalContext(): string | undefined {
|
||||
if (
|
||||
this.hookSpecificOutput &&
|
||||
'additionalContext' in this.hookSpecificOutput
|
||||
) {
|
||||
const context = this.hookSpecificOutput['additionalContext'];
|
||||
if (typeof context !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Sanitize by escaping < and > to prevent tag injection
|
||||
return context.replace(/</g, '<').replace(/>/g, '>');
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if execution should be blocked and return error info
|
||||
*/
|
||||
getBlockingError(): { blocked: boolean; reason: string } {
|
||||
if (this.isBlockingDecision()) {
|
||||
return {
|
||||
blocked: true,
|
||||
reason: this.getEffectiveReason(),
|
||||
};
|
||||
}
|
||||
return { blocked: false, reason: '' };
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if context clearing was requested by hook.
|
||||
*/
|
||||
shouldClearContext(): boolean {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Specific hook output class for BeforeTool events.
|
||||
*/
|
||||
export class PreToolUseHookOutput extends DefaultHookOutput {
|
||||
/**
|
||||
* Get modified tool input if provided by hook
|
||||
*/
|
||||
getModifiedToolInput(): Record<string, unknown> | undefined {
|
||||
if (this.hookSpecificOutput && 'tool_input' in this.hookSpecificOutput) {
|
||||
const input = this.hookSpecificOutput['tool_input'];
|
||||
if (
|
||||
typeof input === 'object' &&
|
||||
input !== null &&
|
||||
!Array.isArray(input)
|
||||
) {
|
||||
return input as Record<string, unknown>;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
export class StopHookOutput extends DefaultHookOutput {
|
||||
override stopReason?: string;
|
||||
|
||||
constructor(data: Partial<StopOutput> = {}) {
|
||||
super(data);
|
||||
this.stopReason = data.stopReason;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the stop reason if provided
|
||||
*/
|
||||
getStopReason(): string | undefined {
|
||||
return this.stopReason;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if context clearing was requested by hook
|
||||
*/
|
||||
override shouldClearContext(): boolean {
|
||||
if (this.hookSpecificOutput && 'clearContext' in this.hookSpecificOutput) {
|
||||
return this.hookSpecificOutput['clearContext'] === true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Context for MCP tool executions.
|
||||
* Contains non-sensitive connection information about the MCP server
|
||||
* identity. Since server_name is user controlled and arbitrary, we
|
||||
* also include connection information (e.g., command or url) to
|
||||
* help identify the MCP server.
|
||||
*
|
||||
* NOTE: In the future, consider defining a shared sanitized interface
|
||||
* from MCPServerConfig to avoid duplication and ensure consistency.
|
||||
*/
|
||||
export interface McpToolContext {
|
||||
server_name: string;
|
||||
tool_name: string; // Original tool name from the MCP server
|
||||
|
||||
// Connection info (mutually exclusive based on transport type)
|
||||
command?: string; // For stdio transport
|
||||
args?: string[]; // For stdio transport
|
||||
cwd?: string; // For stdio transport
|
||||
|
||||
url?: string; // For SSE/HTTP transport
|
||||
|
||||
tcp?: string; // For WebSocket transport
|
||||
}
|
||||
|
||||
export interface PreToolUseInput extends HookInput {
|
||||
tool_name: string;
|
||||
tool_input: Record<string, unknown>;
|
||||
mcp_context?: McpToolContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* BeforeTool hook output
|
||||
*/
|
||||
export interface BeforeToolOutput extends HookOutput {
|
||||
hookSpecificOutput?: {
|
||||
hookEventName: 'BeforeTool';
|
||||
tool_input?: Record<string, unknown>;
|
||||
};
|
||||
}
|
||||
export interface PostToolUseInput extends HookInput {
|
||||
tool_name: string;
|
||||
tool_input: Record<string, unknown>;
|
||||
tool_response: Record<string, unknown>;
|
||||
mcp_context?: McpToolContext;
|
||||
}
|
||||
export interface PostToolUseOutput extends HookOutput {
|
||||
hookEventName: 'PostToolUse';
|
||||
}
|
||||
/**
|
||||
* BeforeAgent hook input
|
||||
*/
|
||||
export interface UserPromptSubmitInput extends HookInput {
|
||||
prompt: string;
|
||||
}
|
||||
export interface UserPromptSubmitOutput extends HookOutput {
|
||||
additionalContext?: string;
|
||||
}
|
||||
/**
|
||||
* Notification types
|
||||
*/
|
||||
export enum NotificationType {
|
||||
ToolPermission = 'ToolPermission',
|
||||
}
|
||||
|
||||
/**
|
||||
* Notification hook input
|
||||
*/
|
||||
export interface NotificationInput extends HookInput {
|
||||
notification_type: NotificationType;
|
||||
message: string;
|
||||
details: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Notification hook output
|
||||
*/
|
||||
export interface NotificationOutput {
|
||||
suppressOutput?: boolean;
|
||||
systemMessage?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* AfterAgent hook input
|
||||
*/
|
||||
export interface StopInput extends HookInput {
|
||||
prompt: string;
|
||||
prompt_response: string;
|
||||
stop_hook_active: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop hook output
|
||||
*/
|
||||
export interface StopOutput extends HookOutput {
|
||||
stopReason?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* SessionStart source types
|
||||
*/
|
||||
export enum SessionStartSource {
|
||||
Startup = 'startup',
|
||||
Resume = 'resume',
|
||||
Clear = 'clear',
|
||||
}
|
||||
|
||||
/**
|
||||
* SessionStart hook input
|
||||
*/
|
||||
export interface SessionStartInput extends HookInput {
|
||||
source: SessionStartSource;
|
||||
}
|
||||
|
||||
/**
|
||||
* SessionStart hook output
|
||||
*/
|
||||
export interface SessionStartOutput extends HookOutput {
|
||||
hookSpecificOutput?: {
|
||||
hookEventName: 'SessionStart';
|
||||
additionalContext?: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* SessionEnd reason types
|
||||
*/
|
||||
export enum SessionEndReason {
|
||||
Exit = 'exit',
|
||||
Clear = 'clear',
|
||||
Logout = 'logout',
|
||||
PromptInputExit = 'prompt_input_exit',
|
||||
Other = 'other',
|
||||
}
|
||||
|
||||
/**
|
||||
* SessionEnd hook input
|
||||
*/
|
||||
export interface SessionEndInput extends HookInput {
|
||||
reason: SessionEndReason;
|
||||
}
|
||||
|
||||
/**
|
||||
* PreCompress trigger types
|
||||
*/
|
||||
export enum PreCompactTrigger {
|
||||
Manual = 'manual',
|
||||
Auto = 'auto',
|
||||
}
|
||||
|
||||
/**
|
||||
* PreCompress hook input
|
||||
*/
|
||||
export interface PreCompactInput extends HookInput {
|
||||
trigger: PreCompactTrigger;
|
||||
}
|
||||
|
||||
/**
|
||||
* PreCompress hook output
|
||||
*/
|
||||
export interface PreCompressOutput {
|
||||
suppressOutput?: boolean;
|
||||
systemMessage?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook execution result
|
||||
*/
|
||||
export interface HookExecutionResult {
|
||||
hookConfig: HookConfig;
|
||||
eventName: HookEventName;
|
||||
success: boolean;
|
||||
output?: HookOutput;
|
||||
stdout?: string;
|
||||
stderr?: string;
|
||||
exitCode?: number;
|
||||
duration: number;
|
||||
error?: Error;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook execution plan for an event
|
||||
*/
|
||||
export interface HookExecutionPlan {
|
||||
eventName: HookEventName;
|
||||
hookConfigs: HookConfig[];
|
||||
sequential: boolean;
|
||||
}
|
||||
|
|
@ -298,3 +298,8 @@ export * from './qwen/qwenOAuth2.js';
|
|||
|
||||
export { makeFakeConfig } from './test-utils/config.js';
|
||||
export * from './test-utils/index.js';
|
||||
|
||||
// Export hook types and components
|
||||
export * from './hooks/types.js';
|
||||
export { HookSystem, HookRegistry } from './hooks/index.js';
|
||||
export type { HookRegistryEntry } from './hooks/index.js';
|
||||
|
|
|
|||
541
packages/core/src/policy/policy-engine.ts
Normal file
541
packages/core/src/policy/policy-engine.ts
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
import stableStringify from 'fast-json-stable-stringify';
|
||||
import type { CheckerRunner } from '../safety/checker-runner.js';
|
||||
import { SafetyCheckDecision } from '../safety/protocol.js';
|
||||
import {
|
||||
ApprovalMode,
|
||||
PolicyDecision,
|
||||
type CheckResult,
|
||||
type HookCheckerRule,
|
||||
type PolicyEngineConfig,
|
||||
type PolicyRule,
|
||||
type SafetyCheckerRule,
|
||||
} from './types.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const debugLogger = createDebugLogger('POLICY_ENGINE');
|
||||
|
||||
/**
|
||||
* List of tool names that are considered shell commands.
|
||||
*/
|
||||
const SHELL_TOOL_NAMES = ['run_shell_command', 'shell', 'execute_command'];
|
||||
|
||||
/**
|
||||
* Check if a pattern is a wildcard pattern (contains * or ?).
|
||||
*/
|
||||
function isWildcardPattern(pattern: string): boolean {
|
||||
return pattern.includes('*') || pattern.includes('?');
|
||||
}
|
||||
|
||||
/**
|
||||
* Match a tool name against a wildcard pattern.
|
||||
*/
|
||||
function matchesWildcard(pattern: string, toolName: string): boolean {
|
||||
const regexPattern = pattern
|
||||
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
|
||||
.replace(/\*/g, '.*')
|
||||
.replace(/\?/g, '.');
|
||||
return new RegExp(`^${regexPattern}$`).test(toolName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all aliases for a tool name (for backwards compatibility).
|
||||
*/
|
||||
function getToolAliases(toolName: string): string[] {
|
||||
const aliases: string[] = [toolName];
|
||||
|
||||
// Add common aliases
|
||||
const aliasMap: Record<string, string[]> = {
|
||||
run_shell_command: ['shell', 'execute_command'],
|
||||
shell: ['run_shell_command', 'execute_command'],
|
||||
execute_command: ['run_shell_command', 'shell'],
|
||||
};
|
||||
|
||||
if (aliasMap[toolName]) {
|
||||
aliases.push(...aliasMap[toolName]);
|
||||
}
|
||||
|
||||
return aliases;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a rule matches a tool call.
|
||||
*/
|
||||
function ruleMatches(
|
||||
rule: PolicyRule | SafetyCheckerRule,
|
||||
toolCall: FunctionCall,
|
||||
stringifiedArgs: string | undefined,
|
||||
serverName: string | undefined,
|
||||
approvalMode: ApprovalMode,
|
||||
): boolean {
|
||||
// Check approval mode
|
||||
if ('modes' in rule && rule.modes && rule.modes.length > 0) {
|
||||
if (!rule.modes.includes(approvalMode)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check tool name
|
||||
if (rule.toolName) {
|
||||
const toolName = toolCall.name || '';
|
||||
|
||||
if (isWildcardPattern(rule.toolName)) {
|
||||
if (!matchesWildcard(rule.toolName, toolName)) {
|
||||
return false;
|
||||
}
|
||||
} else if (rule.toolName !== toolName) {
|
||||
// Also check with server prefix
|
||||
if (serverName && rule.toolName !== `${serverName}__${toolName}`) {
|
||||
return false;
|
||||
} else if (!serverName) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check args pattern
|
||||
if (rule.argsPattern && stringifiedArgs) {
|
||||
if (!rule.argsPattern.test(stringifiedArgs)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Policy engine for managing tool execution permissions.
|
||||
*/
|
||||
export class PolicyEngine {
|
||||
private rules: PolicyRule[] = [];
|
||||
private checkers: SafetyCheckerRule[] = [];
|
||||
private hookCheckers: HookCheckerRule[] = [];
|
||||
private readonly defaultDecision: PolicyDecision;
|
||||
private readonly nonInteractive: boolean;
|
||||
private readonly approvalMode: ApprovalMode;
|
||||
private readonly checkerRunner?: CheckerRunner;
|
||||
|
||||
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
|
||||
this.rules = [...(config.rules ?? [])];
|
||||
this.checkers = [...(config.checkers ?? [])];
|
||||
this.hookCheckers = [...(config.hookCheckers ?? [])];
|
||||
this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER;
|
||||
this.nonInteractive = config.nonInteractive ?? false;
|
||||
this.approvalMode = config.approvalMode ?? ApprovalMode.DEFAULT;
|
||||
this.checkerRunner = checkerRunner;
|
||||
|
||||
// Sort rules by priority (higher first)
|
||||
this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
this.checkers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check shell command for additional security considerations.
|
||||
*/
|
||||
private async checkShellCommand(
|
||||
toolName: string,
|
||||
command: string | undefined,
|
||||
ruleDecision: PolicyDecision,
|
||||
serverName: string | undefined,
|
||||
shellDirPath: string | undefined,
|
||||
allowRedirection?: boolean,
|
||||
rule?: PolicyRule,
|
||||
): Promise<CheckResult> {
|
||||
let aggregateDecision = ruleDecision;
|
||||
let responsibleRule: PolicyRule | undefined;
|
||||
|
||||
// Check for command redirection
|
||||
if (command && !allowRedirection) {
|
||||
const redirectionPatterns = [
|
||||
/[|&;`$()]/,
|
||||
/>\s*/,
|
||||
/<\s*/,
|
||||
/\$\(/,
|
||||
/`[^`]*`/,
|
||||
];
|
||||
|
||||
for (const pattern of redirectionPatterns) {
|
||||
if (pattern.test(command)) {
|
||||
if (ruleDecision === PolicyDecision.ALLOW) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.checkShellCommand] Downgrading ALLOW to ASK_USER due to redirection pattern: ${pattern}`,
|
||||
);
|
||||
aggregateDecision = PolicyDecision.ASK_USER;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
decision: this.applyNonInteractiveMode(aggregateDecision),
|
||||
// If we stayed at ALLOW, we return the original rule (if any).
|
||||
// If we downgraded, we return the responsible rule (or undefined if implicit).
|
||||
rule: aggregateDecision === ruleDecision ? rule : responsibleRule,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a tool call is allowed based on the configured policies.
|
||||
* Returns the decision and the matching rule (if any).
|
||||
*/
|
||||
async check(
|
||||
toolCall: FunctionCall,
|
||||
serverName: string | undefined,
|
||||
): Promise<CheckResult> {
|
||||
let stringifiedArgs: string | undefined;
|
||||
// Compute stringified args once before the loop
|
||||
if (
|
||||
toolCall.args &&
|
||||
(this.rules.some((rule) => rule.argsPattern) ||
|
||||
this.checkers.some((checker) => checker.argsPattern))
|
||||
) {
|
||||
stringifiedArgs = stableStringify(toolCall.args);
|
||||
}
|
||||
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] toolCall.name: ${toolCall.name}, stringifiedArgs: ${stringifiedArgs}`,
|
||||
);
|
||||
|
||||
// Check for shell commands upfront to handle splitting
|
||||
let isShellCommand = false;
|
||||
let command: string | undefined;
|
||||
let shellDirPath: string | undefined;
|
||||
|
||||
const toolName = toolCall.name;
|
||||
|
||||
if (toolName && SHELL_TOOL_NAMES.includes(toolName)) {
|
||||
isShellCommand = true;
|
||||
|
||||
const args = toolCall.args as { command?: string; dir_path?: string };
|
||||
command = args?.command;
|
||||
shellDirPath = args?.dir_path;
|
||||
}
|
||||
|
||||
// Find the first matching rule (already sorted by priority)
|
||||
let matchedRule: PolicyRule | undefined;
|
||||
let decision: PolicyDecision | undefined;
|
||||
|
||||
// For tools with a server name, we want to try matching both the
|
||||
// original name and the fully qualified name (server__tool).
|
||||
// We also want to check legacy aliases for the tool name.
|
||||
const toolNamesToTry = toolCall.name ? getToolAliases(toolCall.name) : [];
|
||||
|
||||
const toolCallsToTry: FunctionCall[] = [];
|
||||
for (const name of toolNamesToTry) {
|
||||
toolCallsToTry.push({ ...toolCall, name });
|
||||
if (serverName && !name.includes('__')) {
|
||||
toolCallsToTry.push({
|
||||
...toolCall,
|
||||
name: `${serverName}__${name}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const rule of this.rules) {
|
||||
const match = toolCallsToTry.some((tc) =>
|
||||
ruleMatches(rule, tc, stringifiedArgs, serverName, this.approvalMode),
|
||||
);
|
||||
|
||||
if (match) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
|
||||
);
|
||||
|
||||
if (isShellCommand && toolName) {
|
||||
const shellResult = await this.checkShellCommand(
|
||||
toolName,
|
||||
command,
|
||||
rule.decision,
|
||||
serverName,
|
||||
shellDirPath,
|
||||
rule.allowRedirection,
|
||||
rule,
|
||||
);
|
||||
decision = shellResult.decision;
|
||||
if (shellResult.rule) {
|
||||
matchedRule = shellResult.rule;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
decision = this.applyNonInteractiveMode(rule.decision);
|
||||
matchedRule = rule;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default if no rule matched
|
||||
if (decision === undefined) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`,
|
||||
);
|
||||
if (toolName && SHELL_TOOL_NAMES.includes(toolName)) {
|
||||
const shellResult = await this.checkShellCommand(
|
||||
toolName,
|
||||
command,
|
||||
this.defaultDecision,
|
||||
serverName,
|
||||
shellDirPath,
|
||||
);
|
||||
decision = shellResult.decision;
|
||||
matchedRule = shellResult.rule;
|
||||
} else {
|
||||
decision = this.applyNonInteractiveMode(this.defaultDecision);
|
||||
}
|
||||
}
|
||||
|
||||
// Safety checks
|
||||
if (decision !== PolicyDecision.DENY && this.checkerRunner) {
|
||||
for (const checkerRule of this.checkers) {
|
||||
if (
|
||||
ruleMatches(
|
||||
checkerRule,
|
||||
toolCall,
|
||||
stringifiedArgs,
|
||||
serverName,
|
||||
this.approvalMode,
|
||||
)
|
||||
) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`,
|
||||
);
|
||||
try {
|
||||
const result = await this.checkerRunner.runChecker(
|
||||
toolCall,
|
||||
checkerRule.checker,
|
||||
);
|
||||
if (result.decision === SafetyCheckDecision.DENY) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Safety checker '${checkerRule.checker.name}' denied execution: ${result.reason}`,
|
||||
);
|
||||
return {
|
||||
decision: PolicyDecision.DENY,
|
||||
rule: matchedRule,
|
||||
};
|
||||
} else if (result.decision === SafetyCheckDecision.ASK_USER) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Safety checker requested ASK_USER: ${result.reason}`,
|
||||
);
|
||||
decision = PolicyDecision.ASK_USER;
|
||||
}
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Safety checker '${checkerRule.checker.name}' threw an error:`,
|
||||
error,
|
||||
);
|
||||
return {
|
||||
decision: PolicyDecision.DENY,
|
||||
rule: matchedRule,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
decision: this.applyNonInteractiveMode(decision),
|
||||
rule: matchedRule,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new rule to the policy engine.
|
||||
*/
|
||||
addRule(rule: PolicyRule): void {
|
||||
this.rules.push(rule);
|
||||
// Re-sort rules by priority
|
||||
this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
}
|
||||
|
||||
addChecker(checker: SafetyCheckerRule): void {
|
||||
this.checkers.push(checker);
|
||||
this.checkers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove rules matching a specific tier (priority band).
|
||||
*/
|
||||
removeRulesByTier(tier: number): void {
|
||||
this.rules = this.rules.filter(
|
||||
(rule) => Math.floor(rule.priority ?? 0) !== tier,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove checkers matching a specific tier (priority band).
|
||||
*/
|
||||
removeCheckersByTier(tier: number): void {
|
||||
this.checkers = this.checkers.filter(
|
||||
(checker) => Math.floor(checker.priority ?? 0) !== tier,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove rules for a specific tool.
|
||||
* If source is provided, only rules matching that source are removed.
|
||||
*/
|
||||
removeRulesForTool(toolName: string, source?: string): void {
|
||||
this.rules = this.rules.filter(
|
||||
(rule) =>
|
||||
rule.toolName !== toolName ||
|
||||
(source !== undefined && rule.source !== source),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all current rules.
|
||||
*/
|
||||
getRules(): readonly PolicyRule[] {
|
||||
return this.rules;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a rule for a specific tool already exists.
|
||||
* If ignoreDynamic is true, it only returns true if a rule exists that was NOT added by AgentRegistry.
|
||||
*/
|
||||
hasRuleForTool(toolName: string, ignoreDynamic = false): boolean {
|
||||
return this.rules.some(
|
||||
(rule) =>
|
||||
rule.toolName === toolName &&
|
||||
(!ignoreDynamic || rule.source !== 'AgentRegistry (Dynamic)'),
|
||||
);
|
||||
}
|
||||
|
||||
getCheckers(): readonly SafetyCheckerRule[] {
|
||||
return this.checkers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new hook checker to the policy engine.
|
||||
*/
|
||||
addHookChecker(checker: HookCheckerRule): void {
|
||||
this.hookCheckers.push(checker);
|
||||
this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all current hook checkers.
|
||||
*/
|
||||
getHookCheckers(): readonly HookCheckerRule[] {
|
||||
return this.hookCheckers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a hook execution is allowed based on the configured policies.
|
||||
* Returns the decision for the hook execution request.
|
||||
*/
|
||||
async checkHook(hookRequest: {
|
||||
eventName: string;
|
||||
input: Record<string, unknown>;
|
||||
}): Promise<PolicyDecision> {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.checkHook] eventName: ${hookRequest.eventName}`,
|
||||
);
|
||||
|
||||
// For now, allow all hooks by default
|
||||
// In the future, this can be extended to check hook-specific policies
|
||||
return this.applyNonInteractiveMode(PolicyDecision.ALLOW);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get tools that are effectively denied by the current rules.
|
||||
* This takes into account:
|
||||
* 1. Global rules (no argsPattern)
|
||||
* 2. Priority order (higher priority wins)
|
||||
* 3. Non-interactive mode (ASK_USER becomes DENY)
|
||||
*/
|
||||
getExcludedTools(): Set<string> {
|
||||
const excludedTools = new Set<string>();
|
||||
const processedTools = new Set<string>();
|
||||
let globalVerdict: PolicyDecision | undefined;
|
||||
|
||||
for (const rule of this.rules) {
|
||||
if (rule.argsPattern) {
|
||||
if (rule.toolName && rule.decision !== PolicyDecision.DENY) {
|
||||
processedTools.add(rule.toolName);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if rule applies to current approval mode
|
||||
if (rule.modes && rule.modes.length > 0) {
|
||||
if (!rule.modes.includes(this.approvalMode)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Global Rules
|
||||
if (!rule.toolName) {
|
||||
if (globalVerdict === undefined) {
|
||||
globalVerdict = rule.decision;
|
||||
if (globalVerdict !== PolicyDecision.DENY) {
|
||||
// Global ALLOW/ASK found.
|
||||
// Since rules are sorted by priority, this overrides any lower-priority rules.
|
||||
// We can stop processing because nothing else will be excluded.
|
||||
break;
|
||||
}
|
||||
// If Global DENY, we continue to find specific tools to add to excluded set
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolName = rule.toolName;
|
||||
|
||||
// Check if already processed (exact match)
|
||||
if (processedTools.has(toolName)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if covered by a processed wildcard
|
||||
let coveredByWildcard = false;
|
||||
for (const processed of processedTools) {
|
||||
if (
|
||||
isWildcardPattern(processed) &&
|
||||
matchesWildcard(processed, toolName)
|
||||
) {
|
||||
// It's covered by a higher-priority wildcard rule.
|
||||
// If that wildcard rule resulted in exclusion, this tool should also be excluded.
|
||||
if (excludedTools.has(processed)) {
|
||||
excludedTools.add(toolName);
|
||||
}
|
||||
coveredByWildcard = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (coveredByWildcard) {
|
||||
continue;
|
||||
}
|
||||
|
||||
processedTools.add(toolName);
|
||||
|
||||
// Determine decision
|
||||
let decision: PolicyDecision;
|
||||
if (globalVerdict !== undefined) {
|
||||
decision = globalVerdict;
|
||||
} else {
|
||||
decision = rule.decision;
|
||||
}
|
||||
|
||||
if (decision === PolicyDecision.DENY) {
|
||||
excludedTools.add(toolName);
|
||||
}
|
||||
}
|
||||
return excludedTools;
|
||||
}
|
||||
|
||||
private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision {
|
||||
// In non-interactive mode, ASK_USER becomes DENY
|
||||
if (this.nonInteractive && decision === PolicyDecision.ASK_USER) {
|
||||
return PolicyDecision.DENY;
|
||||
}
|
||||
return decision;
|
||||
}
|
||||
}
|
||||
293
packages/core/src/policy/types.ts
Normal file
293
packages/core/src/policy/types.ts
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { SafetyCheckInput } from '../safety/protocol.js';
|
||||
|
||||
export enum PolicyDecision {
|
||||
ALLOW = 'allow',
|
||||
DENY = 'deny',
|
||||
ASK_USER = 'ask_user',
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid sources for hook execution
|
||||
*/
|
||||
export type HookSource = 'project' | 'user' | 'system' | 'extension';
|
||||
|
||||
/**
|
||||
* Array of valid hook source values for runtime validation
|
||||
*/
|
||||
const VALID_HOOK_SOURCES: HookSource[] = [
|
||||
'project',
|
||||
'user',
|
||||
'system',
|
||||
'extension',
|
||||
];
|
||||
|
||||
/**
|
||||
* Safely extract and validate hook source from input
|
||||
* Returns 'project' as default if the value is invalid or missing
|
||||
*/
|
||||
export function getHookSource(input: Record<string, unknown>): HookSource {
|
||||
const source = input['hook_source'];
|
||||
if (
|
||||
typeof source === 'string' &&
|
||||
VALID_HOOK_SOURCES.includes(source as HookSource)
|
||||
) {
|
||||
return source as HookSource;
|
||||
}
|
||||
return 'project';
|
||||
}
|
||||
|
||||
export enum ApprovalMode {
|
||||
DEFAULT = 'default',
|
||||
AUTO_EDIT = 'autoEdit',
|
||||
YOLO = 'yolo',
|
||||
PLAN = 'plan',
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for the built-in allowed-path checker.
|
||||
*/
|
||||
export interface AllowedPathConfig {
|
||||
/**
|
||||
* Explicitly include argument keys to be checked as paths.
|
||||
*/
|
||||
included_args?: string[];
|
||||
|
||||
/**
|
||||
* Explicitly exclude argument keys from being checked as paths.
|
||||
*/
|
||||
excluded_args?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Base interface for external checkers.
|
||||
*/
|
||||
export interface ExternalCheckerConfig {
|
||||
type: 'external';
|
||||
name: string;
|
||||
config?: unknown;
|
||||
required_context?: Array<keyof SafetyCheckInput['context']>;
|
||||
}
|
||||
|
||||
export enum InProcessCheckerType {
|
||||
ALLOWED_PATH = 'allowed-path',
|
||||
}
|
||||
|
||||
/**
|
||||
* Base interface for in-process checkers.
|
||||
*/
|
||||
export interface InProcessCheckerConfig {
|
||||
type: 'in-process';
|
||||
name: InProcessCheckerType;
|
||||
config?: AllowedPathConfig;
|
||||
required_context?: Array<keyof SafetyCheckInput['context']>;
|
||||
}
|
||||
|
||||
/**
|
||||
* A discriminated union for all safety checker configurations.
|
||||
*/
|
||||
export type SafetyCheckerConfig =
|
||||
| ExternalCheckerConfig
|
||||
| InProcessCheckerConfig;
|
||||
|
||||
export interface PolicyRule {
|
||||
/**
|
||||
* A unique name for the policy rule, useful for identification and debugging.
|
||||
*/
|
||||
name?: string;
|
||||
|
||||
/**
|
||||
* The name of the tool this rule applies to.
|
||||
* If undefined, the rule applies to all tools.
|
||||
*/
|
||||
toolName?: string;
|
||||
|
||||
/**
|
||||
* Pattern to match against tool arguments.
|
||||
* Can be used for more fine-grained control.
|
||||
*/
|
||||
argsPattern?: RegExp;
|
||||
|
||||
/**
|
||||
* The decision to make when this rule matches.
|
||||
*/
|
||||
decision: PolicyDecision;
|
||||
|
||||
/**
|
||||
* Priority of this rule. Higher numbers take precedence.
|
||||
* Default is 0.
|
||||
*/
|
||||
priority?: number;
|
||||
|
||||
/**
|
||||
* Approval modes this rule applies to.
|
||||
* If undefined or empty, it applies to all modes.
|
||||
*/
|
||||
modes?: ApprovalMode[];
|
||||
|
||||
/**
|
||||
* If true, allows command redirection even if the policy engine would normally
|
||||
* downgrade ALLOW to ASK_USER for redirected commands.
|
||||
* Only applies when decision is ALLOW.
|
||||
*/
|
||||
allowRedirection?: boolean;
|
||||
|
||||
/**
|
||||
* Effect of the rule's source.
|
||||
* e.g. "my-policies.toml", "Settings (MCP Trusted)", etc.
|
||||
*/
|
||||
source?: string;
|
||||
|
||||
/**
|
||||
* Optional message to display when this rule results in a DENY decision.
|
||||
* This message will be returned to the model/user.
|
||||
*/
|
||||
denyMessage?: string;
|
||||
}
|
||||
|
||||
export interface SafetyCheckerRule {
|
||||
/**
|
||||
* The name of the tool this rule applies to.
|
||||
* If undefined, the rule applies to all tools.
|
||||
*/
|
||||
toolName?: string;
|
||||
|
||||
/**
|
||||
* Pattern to match against tool arguments.
|
||||
* Can be used for more fine-grained control.
|
||||
*/
|
||||
argsPattern?: RegExp;
|
||||
|
||||
/**
|
||||
* Priority of this checker. Higher numbers run first.
|
||||
* Default is 0.
|
||||
*/
|
||||
priority?: number;
|
||||
|
||||
/**
|
||||
* Specifies an external or built-in safety checker to execute for
|
||||
* additional validation of a tool call.
|
||||
*/
|
||||
checker: SafetyCheckerConfig;
|
||||
|
||||
/**
|
||||
* Approval modes this rule applies to.
|
||||
* If undefined or empty, it applies to all modes.
|
||||
*/
|
||||
modes?: ApprovalMode[];
|
||||
|
||||
/**
|
||||
* Source of the rule.
|
||||
* e.g. "my-policies.toml", "Workspace: project.toml", etc.
|
||||
*/
|
||||
source?: string;
|
||||
}
|
||||
|
||||
export interface HookExecutionContext {
|
||||
eventName: string;
|
||||
hookSource?: HookSource;
|
||||
trustedFolder?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Rule for applying safety checkers to hook executions.
|
||||
* Similar to SafetyCheckerRule but with hook-specific matching criteria.
|
||||
*/
|
||||
export interface HookCheckerRule {
|
||||
/**
|
||||
* The name of the hook event this rule applies to.
|
||||
* If undefined, the rule applies to all hook events.
|
||||
*/
|
||||
eventName?: string;
|
||||
|
||||
/**
|
||||
* The source of hooks this rule applies to.
|
||||
* If undefined, the rule applies to all hook sources.
|
||||
*/
|
||||
hookSource?: HookSource;
|
||||
|
||||
/**
|
||||
* Priority of this checker. Higher numbers run first.
|
||||
* Default is 0.
|
||||
*/
|
||||
priority?: number;
|
||||
|
||||
/**
|
||||
* Specifies an external or built-in safety checker to execute for
|
||||
* additional validation of a hook execution.
|
||||
*/
|
||||
checker: SafetyCheckerConfig;
|
||||
}
|
||||
|
||||
export interface PolicyEngineConfig {
|
||||
/**
|
||||
* List of policy rules to apply.
|
||||
*/
|
||||
rules?: PolicyRule[];
|
||||
|
||||
/**
|
||||
* List of safety checkers to apply to tool calls.
|
||||
*/
|
||||
checkers?: SafetyCheckerRule[];
|
||||
|
||||
/**
|
||||
* List of safety checkers to apply to hook executions.
|
||||
*/
|
||||
hookCheckers?: HookCheckerRule[];
|
||||
|
||||
/**
|
||||
* Default decision when no rules match.
|
||||
* Defaults to ASK_USER.
|
||||
*/
|
||||
defaultDecision?: PolicyDecision;
|
||||
|
||||
/**
|
||||
* Whether to allow tools in non-interactive mode.
|
||||
* When true, ASK_USER decisions become DENY.
|
||||
*/
|
||||
nonInteractive?: boolean;
|
||||
|
||||
/**
|
||||
* Whether to allow hooks to execute.
|
||||
* When false, all hooks are denied.
|
||||
* Defaults to true.
|
||||
*/
|
||||
allowHooks?: boolean;
|
||||
|
||||
/**
|
||||
* Current approval mode.
|
||||
* Used to filter rules that have specific 'modes' defined.
|
||||
*/
|
||||
approvalMode?: ApprovalMode;
|
||||
}
|
||||
|
||||
export interface PolicySettings {
|
||||
mcp?: {
|
||||
excluded?: string[];
|
||||
allowed?: string[];
|
||||
};
|
||||
tools?: {
|
||||
exclude?: string[];
|
||||
allowed?: string[];
|
||||
};
|
||||
mcpServers?: Record<string, { trust?: boolean }>;
|
||||
// User provided policies that will replace the USER level policies in ~/.gemini/policies
|
||||
policyPaths?: string[];
|
||||
workspacePoliciesDir?: string;
|
||||
}
|
||||
|
||||
export interface CheckResult {
|
||||
decision: PolicyDecision;
|
||||
rule?: PolicyRule;
|
||||
}
|
||||
|
||||
/**
|
||||
* Priority for subagent tools (registered dynamically).
|
||||
* Effective priority matching Tier 1 (Default) read-only tools.
|
||||
*/
|
||||
export const PRIORITY_SUBAGENT_TOOL = 1.05;
|
||||
155
packages/core/src/safety/built-in.ts
Normal file
155
packages/core/src/safety/built-in.ts
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as path from 'node:path';
|
||||
import * as fs from 'node:fs';
|
||||
import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js';
|
||||
import { SafetyCheckDecision } from './protocol.js';
|
||||
import type { AllowedPathConfig } from '../policy/types.js';
|
||||
|
||||
/**
|
||||
* Interface for all in-process safety checkers.
|
||||
*/
|
||||
export interface InProcessChecker {
|
||||
check(input: SafetyCheckInput): Promise<SafetyCheckResult>;
|
||||
}
|
||||
|
||||
/**
|
||||
* An in-process checker to validate file paths.
|
||||
*/
|
||||
export class AllowedPathChecker implements InProcessChecker {
|
||||
async check(input: SafetyCheckInput): Promise<SafetyCheckResult> {
|
||||
const { toolCall, context } = input;
|
||||
|
||||
const config = input.config as AllowedPathConfig | undefined;
|
||||
|
||||
// Build list of allowed directories
|
||||
const allowedDirs = [
|
||||
context.environment.cwd,
|
||||
...context.environment.workspaces,
|
||||
];
|
||||
|
||||
// Find all arguments that look like paths
|
||||
const includedArgs = config?.included_args ?? [];
|
||||
const excludedArgs = config?.excluded_args ?? [];
|
||||
|
||||
const pathsToCheck = this.collectPathsToCheck(
|
||||
toolCall.args,
|
||||
includedArgs,
|
||||
excludedArgs,
|
||||
);
|
||||
|
||||
// Check each path
|
||||
for (const { path: p, argName } of pathsToCheck) {
|
||||
const resolvedPath = this.safelyResolvePath(p, context.environment.cwd);
|
||||
|
||||
if (!resolvedPath) {
|
||||
// If path cannot be resolved, deny it
|
||||
return {
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Cannot resolve path "${p}" in argument "${argName}"`,
|
||||
};
|
||||
}
|
||||
|
||||
const isAllowed = allowedDirs.some((dir) => {
|
||||
// Also resolve allowed directories to handle symlinks
|
||||
const resolvedDir = this.safelyResolvePath(
|
||||
dir,
|
||||
context.environment.cwd,
|
||||
);
|
||||
if (!resolvedDir) return false;
|
||||
return this.isPathAllowed(resolvedPath, resolvedDir);
|
||||
});
|
||||
|
||||
if (!isAllowed) {
|
||||
return {
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Path "${p}" in argument "${argName}" is outside of the allowed workspace directories.`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return { decision: SafetyCheckDecision.ALLOW };
|
||||
}
|
||||
|
||||
private safelyResolvePath(inputPath: string, cwd: string): string | null {
|
||||
try {
|
||||
const resolved = path.resolve(cwd, inputPath);
|
||||
|
||||
// Walk up the directory tree until we find a path that exists
|
||||
let current = resolved;
|
||||
// Stop at root (dirname(root) === root on many systems, or it becomes empty/'.' depending on implementation)
|
||||
while (current && current !== path.dirname(current)) {
|
||||
if (fs.existsSync(current)) {
|
||||
const canonical = fs.realpathSync(current);
|
||||
// Re-construct the full path from this canonical base
|
||||
const relative = path.relative(current, resolved);
|
||||
// path.join handles empty relative paths correctly (returns canonical)
|
||||
return path.join(canonical, relative);
|
||||
}
|
||||
current = path.dirname(current);
|
||||
}
|
||||
|
||||
// Fallback if nothing exists (unlikely if root exists)
|
||||
return resolved;
|
||||
} catch (_error) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private isPathAllowed(targetPath: string, allowedDir: string): boolean {
|
||||
const relative = path.relative(allowedDir, targetPath);
|
||||
return (
|
||||
relative === '' ||
|
||||
(!relative.startsWith('..') && !path.isAbsolute(relative))
|
||||
);
|
||||
}
|
||||
|
||||
private collectPathsToCheck(
|
||||
args: unknown,
|
||||
includedArgs: string[],
|
||||
excludedArgs: string[],
|
||||
prefix = '',
|
||||
): Array<{ path: string; argName: string }> {
|
||||
const paths: Array<{ path: string; argName: string }> = [];
|
||||
|
||||
if (typeof args !== 'object' || args === null) {
|
||||
return paths;
|
||||
}
|
||||
|
||||
for (const [key, value] of Object.entries(args)) {
|
||||
const fullKey = prefix ? `${prefix}.${key}` : key;
|
||||
|
||||
if (excludedArgs.includes(fullKey)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (typeof value === 'string') {
|
||||
if (
|
||||
includedArgs.includes(fullKey) ||
|
||||
key.includes('path') ||
|
||||
key.includes('directory') ||
|
||||
key.includes('file') ||
|
||||
key === 'source' ||
|
||||
key === 'destination'
|
||||
) {
|
||||
paths.push({ path: value, argName: fullKey });
|
||||
}
|
||||
} else if (typeof value === 'object') {
|
||||
paths.push(
|
||||
...this.collectPathsToCheck(
|
||||
value,
|
||||
includedArgs,
|
||||
excludedArgs,
|
||||
fullKey,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return paths;
|
||||
}
|
||||
}
|
||||
305
packages/core/src/safety/checker-runner.ts
Normal file
305
packages/core/src/safety/checker-runner.ts
Normal file
|
|
@ -0,0 +1,305 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { spawn } from 'node:child_process';
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
import type {
|
||||
SafetyCheckerConfig,
|
||||
InProcessCheckerConfig,
|
||||
ExternalCheckerConfig,
|
||||
} from '../policy/types.js';
|
||||
import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js';
|
||||
import { SafetyCheckDecision } from './protocol.js';
|
||||
import type { CheckerRegistry } from './registry.js';
|
||||
import type { ContextBuilder } from './context-builder.js';
|
||||
import { z } from 'zod';
|
||||
|
||||
const SafetyCheckResultSchema: z.ZodType<SafetyCheckResult> =
|
||||
z.discriminatedUnion('decision', [
|
||||
z.object({
|
||||
decision: z.literal(SafetyCheckDecision.ALLOW),
|
||||
reason: z.string().optional(),
|
||||
}),
|
||||
z.object({
|
||||
decision: z.literal(SafetyCheckDecision.DENY),
|
||||
reason: z.string().min(1),
|
||||
}),
|
||||
z.object({
|
||||
decision: z.literal(SafetyCheckDecision.ASK_USER),
|
||||
reason: z.string().min(1),
|
||||
}),
|
||||
]);
|
||||
|
||||
/**
|
||||
* Configuration for the checker runner.
|
||||
*/
|
||||
export interface CheckerRunnerConfig {
|
||||
/**
|
||||
* Maximum time (in milliseconds) to wait for a checker to complete.
|
||||
* Default: 5000 (5 seconds)
|
||||
*/
|
||||
timeout?: number;
|
||||
|
||||
/**
|
||||
* Path to the directory containing external checkers.
|
||||
*/
|
||||
checkersPath: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Service for executing safety checker processes.
|
||||
*/
|
||||
export class CheckerRunner {
|
||||
private static readonly DEFAULT_TIMEOUT = 5000; // 5 seconds
|
||||
|
||||
private readonly registry: CheckerRegistry;
|
||||
private readonly contextBuilder: ContextBuilder;
|
||||
private readonly timeout: number;
|
||||
|
||||
constructor(
|
||||
contextBuilder: ContextBuilder,
|
||||
registry: CheckerRegistry,
|
||||
config: CheckerRunnerConfig,
|
||||
) {
|
||||
this.contextBuilder = contextBuilder;
|
||||
this.registry = registry;
|
||||
this.timeout = config.timeout ?? CheckerRunner.DEFAULT_TIMEOUT;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs a safety checker and returns the result.
|
||||
*/
|
||||
async runChecker(
|
||||
toolCall: FunctionCall,
|
||||
checkerConfig: SafetyCheckerConfig,
|
||||
): Promise<SafetyCheckResult> {
|
||||
if (checkerConfig.type === 'in-process') {
|
||||
return this.runInProcessChecker(toolCall, checkerConfig);
|
||||
}
|
||||
return this.runExternalChecker(toolCall, checkerConfig);
|
||||
}
|
||||
|
||||
private async runInProcessChecker(
|
||||
toolCall: FunctionCall,
|
||||
checkerConfig: InProcessCheckerConfig,
|
||||
): Promise<SafetyCheckResult> {
|
||||
try {
|
||||
const checker = this.registry.resolveInProcess(checkerConfig.name);
|
||||
const context = checkerConfig.required_context
|
||||
? this.contextBuilder.buildMinimalContext(
|
||||
checkerConfig.required_context,
|
||||
)
|
||||
: this.contextBuilder.buildFullContext();
|
||||
|
||||
const input: SafetyCheckInput = {
|
||||
protocolVersion: '1.0.0',
|
||||
toolCall,
|
||||
context,
|
||||
config: checkerConfig.config,
|
||||
};
|
||||
|
||||
// In-process checkers can be async, but we'll also apply a timeout
|
||||
// for safety, in case of infinite loops or unexpected delays.
|
||||
return await this.executeWithTimeout(checker.check(input));
|
||||
} catch (error) {
|
||||
return {
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to run in-process checker "${checkerConfig.name}": ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async runExternalChecker(
|
||||
toolCall: FunctionCall,
|
||||
checkerConfig: ExternalCheckerConfig,
|
||||
): Promise<SafetyCheckResult> {
|
||||
try {
|
||||
// Resolve the checker executable path
|
||||
const checkerPath = this.registry.resolveExternal(checkerConfig.name);
|
||||
|
||||
// Build the appropriate context
|
||||
const context = checkerConfig.required_context
|
||||
? this.contextBuilder.buildMinimalContext(
|
||||
checkerConfig.required_context,
|
||||
)
|
||||
: this.contextBuilder.buildFullContext();
|
||||
|
||||
// Create the input payload
|
||||
const input: SafetyCheckInput = {
|
||||
protocolVersion: '1.0.0',
|
||||
toolCall,
|
||||
context,
|
||||
config: checkerConfig.config,
|
||||
};
|
||||
|
||||
// Run the checker process
|
||||
return await this.executeCheckerProcess(
|
||||
checkerPath,
|
||||
input,
|
||||
checkerConfig.name,
|
||||
);
|
||||
} catch (error) {
|
||||
// If anything goes wrong, deny the operation
|
||||
return {
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to run safety checker "${checkerConfig.name}": ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes an external checker process and handles its lifecycle.
|
||||
*/
|
||||
private executeCheckerProcess(
|
||||
checkerPath: string,
|
||||
input: SafetyCheckInput,
|
||||
checkerName: string,
|
||||
): Promise<SafetyCheckResult> {
|
||||
return new Promise((resolve) => {
|
||||
const child = spawn(checkerPath, [], {
|
||||
stdio: ['pipe', 'pipe', 'pipe'],
|
||||
});
|
||||
|
||||
let stdout = '';
|
||||
let stderr = '';
|
||||
let timeoutHandle: NodeJS.Timeout | null = null;
|
||||
let killed = false;
|
||||
|
||||
let exited = false;
|
||||
|
||||
// Set up timeout
|
||||
timeoutHandle = setTimeout(() => {
|
||||
killed = true;
|
||||
child.kill('SIGTERM');
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Safety checker "${checkerName}" timed out after ${this.timeout}ms`,
|
||||
});
|
||||
|
||||
// Fallback: if process doesn't exit after 5s, force kill
|
||||
setTimeout(() => {
|
||||
if (!exited) {
|
||||
child.kill('SIGKILL');
|
||||
}
|
||||
}, 5000).unref();
|
||||
}, this.timeout);
|
||||
|
||||
// Collect output
|
||||
if (child.stdout) {
|
||||
child.stdout.on('data', (data: Buffer) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
}
|
||||
|
||||
if (child.stderr) {
|
||||
child.stderr.on('data', (data: Buffer) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
}
|
||||
|
||||
// Handle process completion
|
||||
child.on('close', (code: number | null) => {
|
||||
exited = true;
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
|
||||
// If we already killed it due to timeout, don't process the result
|
||||
if (killed) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Non-zero exit code is a failure
|
||||
if (code !== 0) {
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Safety checker "${checkerName}" exited with code ${code}${
|
||||
stderr ? `: ${stderr}` : ''
|
||||
}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to parse the output
|
||||
try {
|
||||
const rawResult = JSON.parse(stdout);
|
||||
const result = SafetyCheckResultSchema.parse(rawResult);
|
||||
|
||||
resolve(result);
|
||||
} catch (parseError) {
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to parse output from safety checker "${checkerName}": ${
|
||||
parseError instanceof Error
|
||||
? parseError.message
|
||||
: String(parseError)
|
||||
}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Handle process errors
|
||||
child.on('error', (error: Error) => {
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
|
||||
if (!killed) {
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to spawn safety checker "${checkerName}": ${error.message}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Send input to the checker
|
||||
try {
|
||||
if (child.stdin) {
|
||||
child.stdin.write(JSON.stringify(input));
|
||||
child.stdin.end();
|
||||
} else {
|
||||
throw new Error('Failed to open stdin for checker process');
|
||||
}
|
||||
} catch (writeError) {
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
|
||||
child.kill();
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to write to stdin of safety checker "${checkerName}": ${
|
||||
writeError instanceof Error
|
||||
? writeError.message
|
||||
: String(writeError)
|
||||
}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes a promise with a timeout.
|
||||
*/
|
||||
private executeWithTimeout<T>(promise: Promise<T>): Promise<T> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const timeoutHandle = setTimeout(() => {
|
||||
reject(new Error(`Checker timed out after ${this.timeout}ms`));
|
||||
}, this.timeout);
|
||||
|
||||
promise
|
||||
.then(resolve)
|
||||
.catch(reject)
|
||||
.finally(() => {
|
||||
clearTimeout(timeoutHandle);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
55
packages/core/src/safety/context-builder.ts
Normal file
55
packages/core/src/safety/context-builder.ts
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { SafetyCheckInput, ConversationTurn } from './protocol.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
/**
|
||||
* Builds context objects for safety checkers, ensuring sensitive data is filtered.
|
||||
*/
|
||||
export class ContextBuilder {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly conversationHistory: ConversationTurn[] = [],
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Builds the full context object with all available data.
|
||||
*/
|
||||
buildFullContext(): SafetyCheckInput['context'] {
|
||||
return {
|
||||
environment: {
|
||||
cwd: process.cwd(),
|
||||
|
||||
workspaces: this.config
|
||||
.getWorkspaceContext()
|
||||
.getDirectories() as string[],
|
||||
},
|
||||
history: {
|
||||
turns: this.conversationHistory,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a minimal context with only the specified keys.
|
||||
*/
|
||||
buildMinimalContext(
|
||||
requiredKeys: Array<keyof SafetyCheckInput['context']>,
|
||||
): SafetyCheckInput['context'] {
|
||||
const fullContext = this.buildFullContext();
|
||||
const minimalContext: Partial<SafetyCheckInput['context']> = {};
|
||||
|
||||
for (const key of requiredKeys) {
|
||||
if (key in fullContext) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(minimalContext as any)[key] = fullContext[key];
|
||||
}
|
||||
}
|
||||
|
||||
return minimalContext as SafetyCheckInput['context'];
|
||||
}
|
||||
}
|
||||
100
packages/core/src/safety/protocol.ts
Normal file
100
packages/core/src/safety/protocol.ts
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
|
||||
/**
|
||||
* Represents a single turn in the conversation between the user and the model.
|
||||
* This provides semantic context for why a tool call might be happening.
|
||||
*/
|
||||
export interface ConversationTurn {
|
||||
user: {
|
||||
text: string;
|
||||
};
|
||||
model: {
|
||||
text?: string;
|
||||
toolCalls?: FunctionCall[];
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* The data structure passed from the CLI to a safety checker process via stdin.
|
||||
*/
|
||||
export interface SafetyCheckInput {
|
||||
/**
|
||||
* The semantic version of the protocol (e.g., "1.0.0"). This allows
|
||||
* for introducing breaking changes in the future while maintaining
|
||||
* support for older checkers.
|
||||
*/
|
||||
protocolVersion: '1.0.0';
|
||||
|
||||
/**
|
||||
* The specific tool call that is being validated.
|
||||
*/
|
||||
toolCall: FunctionCall;
|
||||
|
||||
/**
|
||||
* A container for all contextual information from the CLI's internal state.
|
||||
* By grouping data into categories, we can easily add new context in the
|
||||
* future without creating a flat, unmanageable object.
|
||||
*/
|
||||
context: {
|
||||
/**
|
||||
* Information about the user's file system and execution environment.
|
||||
*/
|
||||
environment: {
|
||||
cwd: string;
|
||||
workspaces: string[]; // A list of user-configured workspace roots
|
||||
};
|
||||
|
||||
/**
|
||||
* The recent history of the conversation. This can be used by checkers
|
||||
* that need to understand the intent behind a tool call.
|
||||
*/
|
||||
history?: {
|
||||
turns: ConversationTurn[];
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Configuration for the safety checker.
|
||||
* This allows checkers to be parameterized (e.g. allowed paths).
|
||||
*/
|
||||
config?: unknown;
|
||||
}
|
||||
|
||||
/**
|
||||
* The possible decisions a safety checker can make.
|
||||
*/
|
||||
export enum SafetyCheckDecision {
|
||||
ALLOW = 'allow',
|
||||
DENY = 'deny',
|
||||
ASK_USER = 'ask_user',
|
||||
}
|
||||
|
||||
/**
|
||||
* The data structure returned by a safety checker process via stdout.
|
||||
*/
|
||||
export type SafetyCheckResult =
|
||||
| {
|
||||
/**
|
||||
* The decision made by the safety checker.
|
||||
*/
|
||||
decision: SafetyCheckDecision.ALLOW;
|
||||
/**
|
||||
* If not allowed, a message explaining why the tool call was blocked.
|
||||
* This will be shown to the user.
|
||||
*/
|
||||
reason?: string;
|
||||
}
|
||||
| {
|
||||
decision: SafetyCheckDecision.DENY;
|
||||
reason: string;
|
||||
}
|
||||
| {
|
||||
decision: SafetyCheckDecision.ASK_USER;
|
||||
reason: string;
|
||||
};
|
||||
83
packages/core/src/safety/registry.ts
Normal file
83
packages/core/src/safety/registry.ts
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as path from 'node:path';
|
||||
import * as fs from 'node:fs';
|
||||
import { type InProcessChecker, AllowedPathChecker } from './built-in.js';
|
||||
import { InProcessCheckerType } from '../policy/types.js';
|
||||
|
||||
/**
|
||||
* Registry for managing safety checker resolution.
|
||||
*/
|
||||
export class CheckerRegistry {
|
||||
private static readonly BUILT_IN_EXTERNAL_CHECKERS = new Map<string, string>([
|
||||
// No external built-ins for now
|
||||
]);
|
||||
|
||||
private static readonly BUILT_IN_IN_PROCESS_CHECKERS = new Map<
|
||||
string,
|
||||
InProcessChecker
|
||||
>([[InProcessCheckerType.ALLOWED_PATH, new AllowedPathChecker()]]);
|
||||
|
||||
// Regex to validate checker names (alphanumeric and hyphens only)
|
||||
private static readonly VALID_NAME_PATTERN = /^[a-z0-9-]+$/;
|
||||
|
||||
constructor(private readonly checkersPath: string) {}
|
||||
|
||||
/**
|
||||
* Resolves an external checker name to an absolute executable path.
|
||||
*/
|
||||
resolveExternal(name: string): string {
|
||||
if (!CheckerRegistry.isValidCheckerName(name)) {
|
||||
throw new Error(
|
||||
`Invalid checker name "${name}". Checker names must contain only lowercase letters, numbers, and hyphens.`,
|
||||
);
|
||||
}
|
||||
|
||||
const builtInPath = CheckerRegistry.BUILT_IN_EXTERNAL_CHECKERS.get(name);
|
||||
if (builtInPath) {
|
||||
const fullPath = path.join(this.checkersPath, builtInPath);
|
||||
if (!fs.existsSync(fullPath)) {
|
||||
throw new Error(`Built-in checker "${name}" not found at ${fullPath}`);
|
||||
}
|
||||
return fullPath;
|
||||
}
|
||||
|
||||
// TODO: Phase 5 - Add support for custom external checkers
|
||||
throw new Error(`Unknown external checker "${name}".`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves an in-process checker name to a checker instance.
|
||||
*/
|
||||
resolveInProcess(name: string): InProcessChecker {
|
||||
if (!CheckerRegistry.isValidCheckerName(name)) {
|
||||
throw new Error(`Invalid checker name "${name}".`);
|
||||
}
|
||||
|
||||
const checker = CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.get(name);
|
||||
if (checker) {
|
||||
return checker;
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Unknown in-process checker "${name}". Available: ${Array.from(
|
||||
CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.keys(),
|
||||
).join(', ')}`,
|
||||
);
|
||||
}
|
||||
|
||||
private static isValidCheckerName(name: string): boolean {
|
||||
return this.VALID_NAME_PATTERN.test(name) && !name.includes('..');
|
||||
}
|
||||
|
||||
static getBuiltInCheckers(): string[] {
|
||||
return [
|
||||
...Array.from(this.BUILT_IN_EXTERNAL_CHECKERS.keys()),
|
||||
...Array.from(this.BUILT_IN_IN_PROCESS_CHECKERS.keys()),
|
||||
];
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue