diff --git a/docs/users/configuration/settings.md b/docs/users/configuration/settings.md index 6d753dfbb..3c2c394c0 100644 --- a/docs/users/configuration/settings.md +++ b/docs/users/configuration/settings.md @@ -527,7 +527,7 @@ For authentication-related variables (like `OPENAI_*`) and the recommended `.qwe | `CODE_ASSIST_ENDPOINT` | Specifies the endpoint for the code assist server. | This is useful for development and testing. | | `QWEN_CODE_MAX_OUTPUT_TOKENS` | Overrides the default maximum output tokens per response. When not set, Qwen Code uses an adaptive strategy: starts with 8K tokens and automatically retries with 64K if the response is truncated. Set this to a specific value (e.g., `16000`) to use a fixed limit instead. | Takes precedence over the capped default (8K) but is overridden by `samplingParams.max_tokens` in settings. Disables automatic escalation when set. Example: `export QWEN_CODE_MAX_OUTPUT_TOKENS=16000` | | `TAVILY_API_KEY` | Your API key for the Tavily web search service. | Used to enable the `web_search` tool functionality. Example: `export TAVILY_API_KEY="tvly-your-api-key-here"` | -| `QWEN_CODE_PROFILE_STARTUP` | Set to `1` to enable startup performance profiling. Writes a JSON timing report to `~/.qwen/startup-perf/` with per-phase durations. | Only active inside the sandbox child process. Zero overhead when not set. Example: `export QWEN_CODE_PROFILE_STARTUP=1` | +| `QWEN_CODE_PROFILE_STARTUP` | Set to `1` to enable startup performance profiling. Writes a JSON timing report to `~/.qwen/startup-perf/` with per-phase durations. | Only active inside the sandbox child process. Zero overhead when not set. Example: `export QWEN_CODE_PROFILE_STARTUP=1` | ## Command-Line Arguments diff --git a/packages/cli/src/acp-integration/acpAgent.test.ts b/packages/cli/src/acp-integration/acpAgent.test.ts index 07473e97d..261035970 100644 --- a/packages/cli/src/acp-integration/acpAgent.test.ts +++ b/packages/cli/src/acp-integration/acpAgent.test.ts @@ -4,7 +4,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + afterAll, + type MockInstance, +} from 'vitest'; // Mock cleanup module before importing anything else const { mockRunExitCleanup } = vi.hoisted(() => ({ @@ -56,6 +65,7 @@ vi.mock('@qwen-code/qwen-code-core', () => ({ debug: vi.fn(), error: vi.fn(), warn: vi.fn(), + info: vi.fn(), }), APPROVAL_MODE_INFO: {}, APPROVAL_MODES: [], @@ -66,6 +76,14 @@ vi.mock('@qwen-code/qwen-code-core', () => ({ MCPServerConfig: {}, SessionService: vi.fn(), tokenLimit: vi.fn(), + SessionStartSource: { + Startup: 'startup', + Resume: 'resume', + }, + SessionEndReason: { + PromptInputExit: 'prompt_input_exit', + Other: 'other', + }, })); vi.mock('./authMethods.js', () => ({ buildAuthMethods: vi.fn() })); @@ -83,26 +101,39 @@ import { runAcpAgent } from './acpAgent.js'; import type { Config } from '@qwen-code/qwen-code-core'; import type { LoadedSettings } from '../config/settings.js'; import type { CliArgs } from '../config/config.js'; +import { SessionEndReason } from '@qwen-code/qwen-code-core'; describe('runAcpAgent shutdown cleanup', () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let processExitSpy: any; + let processExitSpy: MockInstance; + let processOnSpy: MockInstance; + let processOffSpy: MockInstance; + let stdinDestroySpy: MockInstance; + let stdoutDestroySpy: MockInstance; let sigTermListeners: NodeJS.SignalsListener[]; let sigIntListeners: NodeJS.SignalsListener[]; + let mockConfig: Config; - const mockConfig = {} as Config; const mockSettings = { merged: {} } as LoadedSettings; const mockArgv = {} as CliArgs; beforeEach(() => { vi.clearAllMocks(); + // Reset mockConfig after clearAllMocks + mockConfig = { + initialize: vi.fn().mockResolvedValue(undefined), + getHookSystem: vi.fn().mockReturnValue(undefined), + getDisableAllHooks: vi.fn().mockReturnValue(false), + hasHooksForEvent: vi.fn().mockReturnValue(false), + getModel: vi.fn().mockReturnValue('test-model'), + } as unknown as Config; + mockRunExitCleanup.mockResolvedValue(undefined); mockConnectionState.reset(); sigTermListeners = []; sigIntListeners = []; // Intercept signal handler registration - vi.spyOn(process, 'on').mockImplementation((( + processOnSpy = vi.spyOn(process, 'on').mockImplementation((( event: string, listener: (...args: unknown[]) => void, ) => { @@ -113,9 +144,18 @@ describe('runAcpAgent shutdown cleanup', () => { return process; }) as typeof process.on); - vi.spyOn(process, 'off').mockImplementation( - (() => process) as typeof process.off, - ); + processOffSpy = vi.spyOn(process, 'off').mockImplementation((( + event: string, + listener: (...args: unknown[]) => void, + ) => { + if (event === 'SIGTERM') { + sigTermListeners = sigTermListeners.filter((l) => l !== listener); + } + if (event === 'SIGINT') { + sigIntListeners = sigIntListeners.filter((l) => l !== listener); + } + return process; + }) as typeof process.off); // Mock process.exit to prevent actually exiting processExitSpy = vi @@ -123,23 +163,36 @@ describe('runAcpAgent shutdown cleanup', () => { .mockImplementation((() => undefined) as unknown as typeof process.exit); // Mock stdin/stdout destroy - vi.spyOn(process.stdin, 'destroy').mockImplementation(() => process.stdin); - vi.spyOn(process.stdout, 'destroy').mockImplementation( - () => process.stdout, - ); + stdinDestroySpy = vi + .spyOn(process.stdin, 'destroy') + .mockImplementation(() => process.stdin); + stdoutDestroySpy = vi + .spyOn(process.stdout, 'destroy') + .mockImplementation(() => process.stdout); }); afterEach(() => { processExitSpy.mockRestore(); - vi.restoreAllMocks(); + stdinDestroySpy.mockRestore(); + stdoutDestroySpy.mockRestore(); + vi.clearAllMocks(); + }); + + afterAll(() => { + processOnSpy.mockRestore(); + processOffSpy.mockRestore(); }); it('calls runExitCleanup and process.exit on SIGTERM', async () => { // Start runAcpAgent (it will await connection.closed) const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + // Wait for signal handlers to be registered + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + // Simulate SIGTERM from IDE - expect(sigTermListeners.length).toBeGreaterThan(0); sigTermListeners[0]('SIGTERM'); // runExitCleanup is async, wait for it @@ -159,7 +212,11 @@ describe('runAcpAgent shutdown cleanup', () => { it('calls runExitCleanup and process.exit on SIGINT', async () => { const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); - expect(sigIntListeners.length).toBeGreaterThan(0); + // Wait for signal handlers to be registered + await vi.waitFor(() => { + expect(sigIntListeners.length).toBeGreaterThan(0); + }); + sigIntListeners[0]('SIGINT'); await vi.waitFor(() => { @@ -177,6 +234,11 @@ describe('runAcpAgent shutdown cleanup', () => { it('only runs shutdown once even if multiple signals arrive', async () => { const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + // Wait for signal handlers to be registered + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + // Send SIGTERM twice sigTermListeners[0]('SIGTERM'); sigTermListeners[0]('SIGTERM'); @@ -194,6 +256,11 @@ describe('runAcpAgent shutdown cleanup', () => { const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + // Wait for signal handlers to be registered + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + sigTermListeners[0]('SIGTERM'); // process.exit should still be called via .finally() @@ -205,3 +272,211 @@ describe('runAcpAgent shutdown cleanup', () => { await agentPromise; }); }); + +describe('runAcpAgent SessionEnd hooks', () => { + let processExitSpy: MockInstance; + let processOnSpy: MockInstance; + let processOffSpy: MockInstance; + let stdinDestroySpy: MockInstance; + let stdoutDestroySpy: MockInstance; + let sigTermListeners: NodeJS.SignalsListener[]; + let sigIntListeners: NodeJS.SignalsListener[]; + let mockConfig: Config; + let mockHookSystem: { + fireSessionEndEvent: ReturnType; + fireSessionStartEvent: ReturnType; + }; + + const mockSettings = { merged: {} } as LoadedSettings; + const mockArgv = {} as CliArgs; + + beforeEach(() => { + vi.clearAllMocks(); + mockHookSystem = { + fireSessionEndEvent: vi.fn().mockResolvedValue(undefined), + fireSessionStartEvent: vi.fn().mockResolvedValue(undefined), + }; + mockConfig = { + initialize: vi.fn().mockResolvedValue(undefined), + getHookSystem: vi.fn().mockReturnValue(mockHookSystem), + getDisableAllHooks: vi.fn().mockReturnValue(false), + hasHooksForEvent: vi.fn().mockReturnValue(true), + getModel: vi.fn().mockReturnValue('test-model'), + } as unknown as Config; + + mockRunExitCleanup.mockResolvedValue(undefined); + mockConnectionState.reset(); + sigTermListeners = []; + sigIntListeners = []; + + processOnSpy = vi.spyOn(process, 'on').mockImplementation((( + event: string, + listener: (...args: unknown[]) => void, + ) => { + if (event === 'SIGTERM') + sigTermListeners.push(listener as NodeJS.SignalsListener); + if (event === 'SIGINT') + sigIntListeners.push(listener as NodeJS.SignalsListener); + return process; + }) as typeof process.on); + + processOffSpy = vi.spyOn(process, 'off').mockImplementation((( + event: string, + listener: (...args: unknown[]) => void, + ) => { + if (event === 'SIGTERM') { + sigTermListeners = sigTermListeners.filter((l) => l !== listener); + } + if (event === 'SIGINT') { + sigIntListeners = sigIntListeners.filter((l) => l !== listener); + } + return process; + }) as typeof process.off); + + processExitSpy = vi + .spyOn(process, 'exit') + .mockImplementation((() => undefined) as unknown as typeof process.exit); + + stdinDestroySpy = vi + .spyOn(process.stdin, 'destroy') + .mockImplementation(() => process.stdin); + stdoutDestroySpy = vi + .spyOn(process.stdout, 'destroy') + .mockImplementation(() => process.stdout); + }); + + afterEach(() => { + processExitSpy.mockRestore(); + stdinDestroySpy.mockRestore(); + stdoutDestroySpy.mockRestore(); + vi.clearAllMocks(); + }); + + afterAll(() => { + processOnSpy.mockRestore(); + processOffSpy.mockRestore(); + }); + + it('fires SessionEnd hook with Other reason on SIGTERM', async () => { + const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + + sigTermListeners[0]('SIGTERM'); + + await vi.waitFor(() => { + expect(mockHookSystem.fireSessionEndEvent).toHaveBeenCalledWith( + SessionEndReason.Other, + ); + }); + + mockConnectionState.resolve(); + await agentPromise; + }); + + it('fires SessionEnd hook with Other reason on SIGINT', async () => { + const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + + await vi.waitFor(() => { + expect(sigIntListeners.length).toBeGreaterThan(0); + }); + + sigIntListeners[0]('SIGINT'); + + await vi.waitFor(() => { + expect(mockHookSystem.fireSessionEndEvent).toHaveBeenCalledWith( + SessionEndReason.Other, + ); + }); + + mockConnectionState.resolve(); + await agentPromise; + }); + + it('fires SessionEnd hook with PromptInputExit on connection.closed', async () => { + const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + + // Resolve connection to simulate IDE disconnect + mockConnectionState.resolve(); + + await vi.waitFor(() => { + expect(mockHookSystem.fireSessionEndEvent).toHaveBeenCalledWith( + SessionEndReason.PromptInputExit, + ); + }); + + await agentPromise; + }); + + it('does not fire SessionEnd hook when hooks are disabled', async () => { + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(true); + + const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + + sigTermListeners[0]('SIGTERM'); + + await vi.waitFor(() => { + expect(mockRunExitCleanup).toHaveBeenCalled(); + }); + + // SessionEnd hook should NOT be called + expect(mockHookSystem.fireSessionEndEvent).not.toHaveBeenCalled(); + + mockConnectionState.resolve(); + await agentPromise; + }); + + it('does not fire SessionEnd hook when event not registered', async () => { + mockConfig.hasHooksForEvent = vi.fn().mockReturnValue(false); + + const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + + sigTermListeners[0]('SIGTERM'); + + await vi.waitFor(() => { + expect(mockRunExitCleanup).toHaveBeenCalled(); + }); + + // SessionEnd hook should NOT be called + expect(mockHookSystem.fireSessionEndEvent).not.toHaveBeenCalled(); + + mockConnectionState.resolve(); + await agentPromise; + }); + + it('fires SessionEnd hook only once when SIGTERM triggers before connection.closed', async () => { + const agentPromise = runAcpAgent(mockConfig, mockSettings, mockArgv); + + await vi.waitFor(() => { + expect(sigTermListeners.length).toBeGreaterThan(0); + }); + + // Trigger SIGTERM first + sigTermListeners[0]('SIGTERM'); + + await vi.waitFor(() => { + expect(mockHookSystem.fireSessionEndEvent).toHaveBeenCalledWith( + SessionEndReason.Other, + ); + }); + + // Now resolve connection.closed - this should NOT trigger another SessionEnd + mockConnectionState.resolve(); + + // Wait for the agent to complete + await agentPromise; + + // SessionEnd should have been called exactly once + expect(mockHookSystem.fireSessionEndEvent).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/cli/src/acp-integration/acpAgent.ts b/packages/cli/src/acp-integration/acpAgent.ts index 4e290b2e1..a1e81eae4 100644 --- a/packages/cli/src/acp-integration/acpAgent.ts +++ b/packages/cli/src/acp-integration/acpAgent.ts @@ -18,6 +18,9 @@ import { type Config, type ConversationRecord, type DeviceAuthorizationData, + SessionStartSource, + SessionEndReason, + type PermissionMode, } from '@qwen-code/qwen-code-core'; import { AgentSideConnection, @@ -74,6 +77,10 @@ export async function runAcpAgent( settings: LoadedSettings, argv: CliArgs, ) { + // Initialize config to set up hookSystem (required for SessionStart/SessionEnd hooks) + // This is needed because gemini.tsx calls runAcpAgent without calling config.initialize() + await config.initialize(); + const stdout = Writable.toWeb(process.stdout) as WritableStream; const stdin = Readable.toWeb(process.stdin) as ReadableStream; @@ -94,10 +101,34 @@ export async function runAcpAgent( // (e.g., stdin raw mode restoration) override the default exit behavior, // causing the ACP process to ignore termination signals. let shuttingDown = false; - const shutdownHandler = () => { + let sessionEndFired = false; + + // Helper to fire SessionEnd hook once, preventing double-fire from both + // shutdown handler path and connection.closed path. + const fireSessionEndOnce = async (reason: SessionEndReason) => { + if (sessionEndFired) return; + sessionEndFired = true; + const hookSystem = config.getHookSystem?.(); + const hooksEnabled = !config.getDisableAllHooks?.(); + if (hooksEnabled && hookSystem && config.hasHooksForEvent?.('SessionEnd')) { + try { + await hookSystem.fireSessionEndEvent(reason); + } catch (err) { + debugLogger.warn( + `SessionEnd hook failed: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } + }; + + const shutdownHandler = async () => { if (shuttingDown) return; shuttingDown = true; debugLogger.debug('[ACP] Shutdown signal received, closing streams'); + + // Fire SessionEnd hook for all active sessions (aligned with core path) + await fireSessionEndOnce(SessionEndReason.Other); + try { process.stdin.destroy(); } catch { @@ -123,6 +154,8 @@ export async function runAcpAgent( process.on('SIGINT', shutdownHandler); await connection.closed; + // Connection closed by IDE - fire SessionEnd hook (aligned with core path) + await fireSessionEndOnce(SessionEndReason.PromptInputExit); process.off('SIGTERM', shutdownHandler); process.off('SIGINT', shutdownHandler); @@ -518,6 +551,24 @@ class QwenAgent implements Agent { ); this.sessions.set(sessionId, session); + // Fire SessionStart hook (aligned with core path) + const hookSystem = config.getHookSystem(); + const hooksEnabled = !config.getDisableAllHooks(); + if (hooksEnabled && hookSystem && config.hasHooksForEvent('SessionStart')) { + const source = conversation + ? SessionStartSource.Resume + : SessionStartSource.Startup; + const model = config.getModel(); + const permissionMode = String(config.getApprovalMode()) as PermissionMode; + try { + await hookSystem.fireSessionStartEvent(source, model, permissionMode); + } catch (err) { + debugLogger.warn( + `SessionStart hook failed: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } + setTimeout(async () => { await session.sendAvailableCommandsUpdate(); }, 0); diff --git a/packages/cli/src/acp-integration/session/Session.test.ts b/packages/cli/src/acp-integration/session/Session.test.ts index 4eb8093ad..1c5e1c7d4 100644 --- a/packages/cli/src/acp-integration/session/Session.test.ts +++ b/packages/cli/src/acp-integration/session/Session.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import * as fs from 'node:fs/promises'; import * as os from 'node:os'; import * as path from 'node:path'; @@ -24,6 +24,22 @@ vi.mock('../../nonInteractiveCliCommands.js', () => ({ handleSlashCommand: vi.fn(), })); +// Helper to create empty async generator (avoids memory leak from inline generators) +function createEmptyStream() { + return (async function* () {})(); +} + +// Helper to create async generator with chunks (avoids memory leak) +function createStreamWithChunks( + chunks: Array<{ type: unknown; value: unknown }>, +) { + return (async function* () { + for (const chunk of chunks) { + yield chunk; + } + })(); +} + describe('Session', () => { let mockChat: GeminiChat; let mockConfig: Config; @@ -49,6 +65,7 @@ describe('Session', () => { mockChat = { sendMessageStream: vi.fn(), addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), } as unknown as GeminiChat; mockToolRegistry = { getTool: vi.fn() }; @@ -103,6 +120,22 @@ describe('Session', () => { ); }); + afterEach(() => { + // Reset global runtime base dir state to prevent state leakage between tests + core.Storage.setRuntimeBaseDir(null); + // Clear session reference to allow garbage collection + session = undefined as unknown as Session; + mockChat = undefined as unknown as GeminiChat; + mockConfig = undefined as unknown as Config; + mockClient = undefined as unknown as AgentSideConnection; + mockSettings = undefined as unknown as LoadedSettings; + mockToolRegistry = undefined as unknown as { + getTool: ReturnType; + }; + vi.restoreAllMocks(); + vi.clearAllTimers(); + }); + describe('setMode', () => { it.each([ ['plan', ApprovalMode.PLAN], @@ -208,20 +241,20 @@ describe('Session', () => { const fileName = 'README.md'; const filePath = path.join(tempDir, fileName); + const readManyFilesSpy = vi + .spyOn(core, 'readManyFiles') + .mockResolvedValue({ + contentParts: 'file content', + files: [], + }); + try { await fs.writeFile(filePath, '# Test\n', 'utf8'); - const readManyFilesSpy = vi - .spyOn(core, 'readManyFiles') - .mockResolvedValue({ - contentParts: 'file content', - files: [], - }); - mockConfig.getTargetDir = vi.fn().mockReturnValue(tempDir); mockChat.sendMessageStream = vi .fn() - .mockResolvedValue((async function* () {})()); + .mockResolvedValue(createEmptyStream()); const promptRequest: PromptRequest = { sessionId: 'test-session-id', @@ -242,6 +275,7 @@ describe('Session', () => { signal: expect.any(AbortSignal), }); } finally { + readManyFilesSpy.mockRestore(); await fs.rm(tempDir, { recursive: true, force: true }); } }); @@ -261,22 +295,26 @@ describe('Session', () => { 'runWithRuntimeBaseDir', ); - mockChat.sendMessageStream = vi - .fn() - .mockResolvedValue((async function* () {})()); + try { + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); - const promptRequest: PromptRequest = { - sessionId: 'test-session-id', - prompt: [{ type: 'text', text: 'hello' }], - }; + const promptRequest: PromptRequest = { + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }; - await session.prompt(promptRequest); + await session.prompt(promptRequest); - expect(runWithRuntimeBaseDirSpy).toHaveBeenCalledWith( - runtimeDir, - process.cwd(), - expect.any(Function), - ); + expect(runWithRuntimeBaseDirSpy).toHaveBeenCalledWith( + runtimeDir, + process.cwd(), + expect.any(Function), + ); + } finally { + runWithRuntimeBaseDirSpy.mockRestore(); + } }); it('hides allow-always options when confirmation already forbids them', async () => { @@ -311,8 +349,8 @@ describe('Session', () => { .mockReturnValue(ApprovalMode.DEFAULT); mockConfig.getPermissionManager = vi.fn().mockReturnValue(null); mockChat.sendMessageStream = vi.fn().mockResolvedValue( - (async function* () { - yield { + createStreamWithChunks([ + { type: core.StreamEventType.CHUNK, value: { functionCalls: [ @@ -323,8 +361,8 @@ describe('Session', () => { }, ], }, - }; - })(), + }, + ]), ); await session.prompt({ @@ -380,8 +418,8 @@ describe('Session', () => { mockConfig.getApprovalMode = vi.fn().mockReturnValue(ApprovalMode.PLAN); mockConfig.getPermissionManager = vi.fn().mockReturnValue(null); mockChat.sendMessageStream = vi.fn().mockResolvedValue( - (async function* () { - yield { + createStreamWithChunks([ + { type: core.StreamEventType.CHUNK, value: { functionCalls: [ @@ -395,8 +433,8 @@ describe('Session', () => { }, ], }, - }; - })(), + }, + ]), ); await session.prompt({ @@ -442,8 +480,8 @@ describe('Session', () => { isToolEnabled: vi.fn().mockResolvedValue(false), }); mockChat.sendMessageStream = vi.fn().mockResolvedValue( - (async function* () { - yield { + createStreamWithChunks([ + { type: core.StreamEventType.CHUNK, value: { functionCalls: [ @@ -454,8 +492,8 @@ describe('Session', () => { }, ], }, - }; - })(), + }, + ]), ); await session.prompt({ @@ -510,8 +548,8 @@ describe('Session', () => { mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); mockConfig.getMessageBus = vi.fn().mockReturnValue({}); mockChat.sendMessageStream = vi.fn().mockResolvedValue( - (async function* () { - yield { + createStreamWithChunks([ + { type: core.StreamEventType.CHUNK, value: { functionCalls: [ @@ -522,8 +560,8 @@ describe('Session', () => { }, ], }, - }; - })(), + }, + ]), ); try { @@ -542,5 +580,482 @@ describe('Session', () => { expect(invocation.params).toEqual({ path: '/tmp/updated.txt' }); expect(executeSpy).toHaveBeenCalled(); }); + + describe('hooks', () => { + describe('UserPromptSubmit hook', () => { + it('fires UserPromptSubmit hook before sending prompt', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi.fn().mockReturnValue(true); + + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'response' }] } }], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(messageBus.request).toHaveBeenCalledWith( + expect.objectContaining({ + eventName: 'UserPromptSubmit', + input: { prompt: 'hello' }, + }), + expect.anything(), + ); + }); + + it('blocks prompt when UserPromptSubmit hook returns blocking decision', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: { decision: 'block', reason: 'Blocked by hook' }, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi.fn().mockReturnValue(true); + + mockChat.sendMessageStream = vi.fn(); + + const result = await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'blocked prompt' }], + }); + + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + expect(result.stopReason).toBe('end_turn'); + }); + }); + + describe('Stop hook', () => { + it('fires Stop hook after model response completes', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi.fn().mockReturnValue(true); + mockChat.getHistory = vi + .fn() + .mockReturnValue([ + { role: 'model', parts: [{ text: 'response text' }] }, + ]); + + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'response' }] } }], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(messageBus.request).toHaveBeenCalledWith( + expect.objectContaining({ + eventName: 'Stop', + input: expect.objectContaining({ + stop_hook_active: true, + last_assistant_message: 'response text', + }), + }), + expect.anything(), + ); + }); + }); + + describe('PreToolUse hook', () => { + it('fires PreToolUse hook before tool execution', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.getApprovalMode = vi + .fn() + .mockReturnValue(ApprovalMode.YOLO); + + const executeSpy = vi.fn().mockResolvedValue({ + llmContent: 'result', + returnDisplay: 'done', + }); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read the file' }], + }); + + expect(messageBus.request).toHaveBeenCalledWith( + expect.objectContaining({ + eventName: 'PreToolUse', + input: expect.objectContaining({ + tool_name: 'read_file', + tool_input: { path: '/tmp/test.txt' }, + }), + }), + expect.anything(), + ); + }); + + it('blocks tool execution when PreToolUse hook returns blocking decision', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: { decision: 'deny', reason: 'Tool blocked by hook' }, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.getApprovalMode = vi + .fn() + .mockReturnValue(ApprovalMode.YOLO); + + const executeSpy = vi.fn(); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read the file' }], + }); + + expect(executeSpy).not.toHaveBeenCalled(); + }); + }); + + describe('PostToolUse hook', () => { + it('fires PostToolUse hook after successful tool execution', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.getApprovalMode = vi + .fn() + .mockReturnValue(ApprovalMode.YOLO); + + const executeSpy = vi.fn().mockResolvedValue({ + llmContent: 'file contents', + returnDisplay: 'success', + }); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read the file' }], + }); + + expect(messageBus.request).toHaveBeenCalledWith( + expect.objectContaining({ + eventName: 'PostToolUse', + input: expect.objectContaining({ + tool_name: 'read_file', + tool_response: expect.objectContaining({ + llmContent: 'file contents', + returnDisplay: 'success', + }), + }), + }), + expect.anything(), + ); + }); + + it('stops execution when PostToolUse hook returns shouldStop', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: { shouldStop: true, reason: 'Stopping per hook request' }, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.getApprovalMode = vi + .fn() + .mockReturnValue(ApprovalMode.YOLO); + + const executeSpy = vi.fn().mockResolvedValue({ + llmContent: 'file contents', + returnDisplay: 'success', + }); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + + // Only one call expected since shouldStop prevents continuation + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read the file' }], + }); + + // Tool should have been executed + expect(executeSpy).toHaveBeenCalled(); + // PostToolUse hook should have been called + expect(messageBus.request).toHaveBeenCalledWith( + expect.objectContaining({ + eventName: 'PostToolUse', + }), + expect.anything(), + ); + }); + }); + + describe('PostToolUseFailure hook', () => { + it('fires PostToolUseFailure hook when tool execution fails', async () => { + const messageBus = { + request: vi.fn().mockResolvedValue({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.getApprovalMode = vi + .fn() + .mockReturnValue(ApprovalMode.YOLO); + + const executeSpy = vi + .fn() + .mockRejectedValue(new Error('Tool failed')); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + mockChat.sendMessageStream = vi.fn().mockResolvedValue( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read the file' }], + }); + + expect(messageBus.request).toHaveBeenCalledWith( + expect.objectContaining({ + eventName: 'PostToolUseFailure', + input: expect.objectContaining({ + tool_name: 'read_file', + error: 'Tool failed', + }), + }), + expect.anything(), + ); + }); + }); + + describe('StopFailure hook', () => { + it('fires StopFailure hook when API error occurs during sendMessageStream', async () => { + const mockFireStopFailureEvent = vi.fn().mockResolvedValue({ + success: true, + }); + mockConfig.getHookSystem = vi.fn().mockReturnValue({ + fireStopFailureEvent: mockFireStopFailureEvent, + }); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi.fn().mockReturnValue(true); + + // Simulate API error (rate limit) + const apiError = new Error('Rate limit exceeded') as Error & { + status: number; + }; + apiError.status = 429; + + mockChat.sendMessageStream = vi.fn().mockImplementation(async () => { + throw apiError; + }); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).rejects.toThrow(); + + // StopFailure hook should be called with rate_limit error type + expect(mockFireStopFailureEvent).toHaveBeenCalledWith( + 'rate_limit', + 'Rate limit exceeded', + ); + }); + + it('does not fire StopFailure hook when hooks are disabled', async () => { + const mockFireStopFailureEvent = vi.fn(); + mockConfig.getHookSystem = vi.fn().mockReturnValue({ + fireStopFailureEvent: mockFireStopFailureEvent, + }); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(true); + + const apiError = new Error('Rate limit exceeded') as Error & { + status: number; + }; + apiError.status = 429; + + mockChat.sendMessageStream = vi.fn().mockImplementation(async () => { + throw apiError; + }); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).rejects.toThrow(); + + expect(mockFireStopFailureEvent).not.toHaveBeenCalled(); + }); + }); + }); }); }); diff --git a/packages/cli/src/acp-integration/session/Session.ts b/packages/cli/src/acp-integration/session/Session.ts index 6b2afa3e0..1f25aff9e 100644 --- a/packages/cli/src/acp-integration/session/Session.ts +++ b/packages/cli/src/acp-integration/session/Session.ts @@ -17,6 +17,10 @@ import type { ToolResult, ChatRecord, AgentEventEmitter, + StopHookOutput, + HookExecutionRequest, + HookExecutionResponse, + MessageBus, } from '@qwen-code/qwen-code-core'; import { AuthType, @@ -40,9 +44,15 @@ import { evaluatePermissionRules, fireNotificationHook, firePermissionRequestHook, + firePreToolUseHook, + firePostToolUseHook, + firePostToolUseFailureHook, injectPermissionRulesIfMissing, NotificationType, persistPermissionOutcome, + createHookOutput, + generateToolUseId, + MessageBusType, } from '@qwen-code/qwen-code-core'; import { RequestError } from '@agentclientprotocol/sdk'; @@ -72,6 +82,7 @@ import { } from '../../nonInteractiveCliCommands.js'; import { isSlashCommand } from '../../ui/utils/commandUtils.js'; import { parseAcpModelOption } from '../../utils/acpModelUtils.js'; +import { classifyApiError } from '../../ui/hooks/useGeminiStream.js'; // Import modular session components import type { @@ -339,6 +350,52 @@ export class Session implements SessionContext { parts = await this.#resolvePrompt(params.prompt, pendingSend.signal); } + // Fire UserPromptSubmit hook through MessageBus (aligned with core path in client.ts) + const hooksEnabled = !this.config.getDisableAllHooks?.(); + const messageBus = this.config.getMessageBus?.(); + if ( + hooksEnabled && + messageBus && + this.config.hasHooksForEvent?.('UserPromptSubmit') + ) { + const response = await messageBus.request< + HookExecutionRequest, + HookExecutionResponse + >( + { + type: MessageBusType.HOOK_EXECUTION_REQUEST, + eventName: 'UserPromptSubmit', + input: { + prompt: promptText, + }, + signal: pendingSend.signal, + }, + MessageBusType.HOOK_EXECUTION_RESPONSE, + ); + const hookOutput = response.output + ? createHookOutput('UserPromptSubmit', response.output) + : undefined; + + if ( + hookOutput?.isBlockingDecision() || + hookOutput?.shouldStopExecution() + ) { + // Hook blocked the prompt - send notification to UI and return + const blockReason = + hookOutput?.getEffectiveReason() || 'No reason provided'; + await this.messageEmitter.emitAgentMessage( + `🚫 **UserPromptSubmit blocked**: ${blockReason}`, + ); + return { stopReason: 'end_turn' }; + } + + // Add additional context from hooks to the request + const additionalContext = hookOutput?.getAdditionalContext(); + if (additionalContext) { + parts = [...parts, { text: additionalContext }]; + } + } + let nextMessage: Content | null = { role: 'user', parts }; while (nextMessage !== null) { @@ -403,7 +460,33 @@ export class Session implements SessionContext { } } } catch (error) { - if (getErrorStatus(error) === 429) { + // Fire StopFailure hook (fire-and-forget, replaces Stop event for API errors) + // Aligned with useGeminiStream.ts handleFinishedWithErrorEvent + const errorStatus = getErrorStatus(error); + const errorMessage = + error instanceof Error ? error.message : String(error); + const errorType = classifyApiError({ + message: errorMessage, + status: errorStatus, + }); + + const hookSystem = this.config.getHookSystem?.(); + const hooksEnabledForStopFailure = + !this.config.getDisableAllHooks?.(); + if ( + hooksEnabledForStopFailure && + hookSystem && + this.config.hasHooksForEvent?.('StopFailure') + ) { + // Fire-and-forget: don't wait for hook to complete + hookSystem + .fireStopFailureEvent(errorType, errorMessage) + .catch((err) => { + debugLogger.warn(`StopFailure hook failed: ${err}`); + }); + } + + if (errorStatus === 429) { throw new RequestError( 429, 'Rate limit exceeded. Try again later.', @@ -442,15 +525,266 @@ export class Session implements SessionContext { nextMessage = { role: 'user', parts: toolResponseParts }; } } + // Wait for any pending rewrite before returning if (this.messageRewriter) { await this.messageRewriter.waitForPendingRewrites(); } - return { stopReason: 'end_turn' }; + + // Fire Stop hook loop (aligned with core path in client.ts) + // This is triggered after model response completes with no pending tool calls + return this.#handleStopHookLoop( + chat, + pendingSend, + promptId, + hooksEnabled, + messageBus, + ); }, ); } + /** + * Handles the Stop hook iteration loop. + * This method processes Stop hooks after a model response completes with no pending tool calls. + * If a Stop hook requests continuation, it sends a follow-up message and loops back. + * Maximum iterations (100) prevent infinite loops. + * + * @param chat - The GeminiChat instance + * @param pendingSend - The abort controller for the current prompt + * @param promptId - The prompt ID for tracking + * @param hooksEnabled - Whether hooks are enabled + * @param messageBus - The MessageBus for hook communication (may be undefined) + * @returns The stop reason ('end_turn' or 'cancelled') + */ + async #handleStopHookLoop( + chat: GeminiChat, + pendingSend: AbortController, + promptId: string, + hooksEnabled: boolean, + messageBus: MessageBus | undefined, + ): Promise<{ stopReason: 'end_turn' | 'cancelled' }> { + const MAX_STOP_HOOK_ITERATIONS = 100; + let stopHookIterationCount = 0; + let stopHookReasons: string[] = []; + + while (stopHookIterationCount < MAX_STOP_HOOK_ITERATIONS) { + if ( + !hooksEnabled || + !messageBus || + pendingSend.signal.aborted || + !this.config.hasHooksForEvent?.('Stop') + ) { + return { stopReason: 'end_turn' }; + } + + // Get response text from the chat history + const history = chat.getHistory(); + const lastModelMessage = history + .filter((msg) => msg.role === 'model') + .pop(); + const responseText = + lastModelMessage?.parts + ?.filter((p): p is { text: string } => 'text' in p) + .map((p) => p.text) + .join('') || '[no response text]'; + + const response = await messageBus.request< + HookExecutionRequest, + HookExecutionResponse + >( + { + type: MessageBusType.HOOK_EXECUTION_REQUEST, + eventName: 'Stop', + input: { + stop_hook_active: true, + last_assistant_message: responseText, + }, + signal: pendingSend.signal, + }, + MessageBusType.HOOK_EXECUTION_RESPONSE, + ); + + // Check if aborted after hook execution + if (pendingSend.signal.aborted) { + return { stopReason: 'cancelled' }; + } + + const hookOutput = response.output + ? createHookOutput('Stop', response.output) + : undefined; + + const stopOutput = hookOutput as StopHookOutput | undefined; + + // Emit system message if provided by hook + if (stopOutput?.systemMessage) { + await this.messageEmitter.emitAgentMessage(stopOutput.systemMessage); + } + + // For Stop hooks, blocking/stop execution should force continuation + if ( + stopOutput?.isBlockingDecision() || + stopOutput?.shouldStopExecution() + ) { + const continueReason = stopOutput.getEffectiveReason(); + + // Track Stop hook iterations + stopHookIterationCount++; + stopHookReasons = [...stopHookReasons, continueReason]; + + // Emit StopHookLoop event for iterations after the first one + if (stopHookIterationCount > 1) { + await this.messageEmitter.emitStopHookLoop( + stopHookIterationCount, + stopHookReasons, + response.stopHookCount ?? 1, + ); + } + + // Continue the conversation with the hook's reason + const continueParts: Part[] = [{ text: continueReason }]; + let nextMessage: Content | null = { + role: 'user', + parts: continueParts, + }; + + // Process the follow-up message and any tool calls that result + while (nextMessage !== null) { + if (pendingSend.signal.aborted) { + return { stopReason: 'cancelled' }; + } + + const functionCalls: FunctionCall[] = []; + let usageMetadata: GenerateContentResponseUsageMetadata | null = null; + const streamStartTime = Date.now(); + + try { + const continueResponseStream = await chat.sendMessageStream( + this.config.getModel(), + { + message: nextMessage?.parts ?? [], + config: { + abortSignal: pendingSend.signal, + }, + }, + promptId + '_stop_hook_' + stopHookIterationCount, + ); + nextMessage = null; + + for await (const resp of continueResponseStream) { + if (pendingSend.signal.aborted) { + return { stopReason: 'cancelled' }; + } + + if ( + resp.type === StreamEventType.CHUNK && + resp.value.candidates && + resp.value.candidates.length > 0 + ) { + const candidate = resp.value.candidates[0]; + for (const part of candidate.content?.parts ?? []) { + if (!part.text) continue; + this.messageEmitter.emitMessage( + part.text, + 'assistant', + part.thought, + ); + } + } + + if ( + resp.type === StreamEventType.CHUNK && + resp.value.usageMetadata + ) { + usageMetadata = resp.value.usageMetadata; + } + + if ( + resp.type === StreamEventType.CHUNK && + resp.value.functionCalls + ) { + functionCalls.push(...resp.value.functionCalls); + } + } + } catch (error) { + // Fire StopFailure hook (fire-and-forget) + const errorStatus = getErrorStatus(error); + const errorMessage = + error instanceof Error ? error.message : String(error); + const errorType = classifyApiError({ + message: errorMessage, + status: errorStatus, + }); + + const hookSystem = this.config.getHookSystem?.(); + const hooksEnabledForStopFailure = + !this.config.getDisableAllHooks?.(); + if ( + hooksEnabledForStopFailure && + hookSystem && + this.config.hasHooksForEvent?.('StopFailure') + ) { + hookSystem + .fireStopFailureEvent(errorType, errorMessage) + .catch((err) => { + debugLogger.warn(`StopFailure hook failed: ${err}`); + }); + } + + if (errorStatus === 429) { + throw new RequestError( + 429, + 'Rate limit exceeded. Try again later.', + ); + } + + throw error; + } + + if (usageMetadata) { + const durationMs = Date.now() - streamStartTime; + await this.messageEmitter.emitUsageMetadata( + usageMetadata, + '', + durationMs, + ); + } + + // Process tool calls from the follow-up message + if (functionCalls.length > 0) { + const toolResponseParts: Part[] = []; + + for (const fc of functionCalls) { + const toolResponse = await this.runTool( + pendingSend.signal, + promptId, + fc, + ); + toolResponseParts.push(...toolResponse); + } + + nextMessage = { role: 'user', parts: toolResponseParts }; + } + } + + // Loop continues to check Stop hook again after processing the follow-up + continue; + } + + // Stop hook allowed stopping, exit the loop + break; + } + + // If we exceeded max iterations, log a warning but still end gracefully + if (stopHookIterationCount >= MAX_STOP_HOOK_ITERATIONS) { + debugLogger.warn( + `Stop hook loop reached maximum iterations (${MAX_STOP_HOOK_ITERATIONS}), forcing stop`, + ); + } + + return { stopReason: 'end_turn' }; + } + async sendUpdate(update: SessionUpdate): Promise { const params: SessionNotification = { sessionId: this.sessionId, @@ -841,6 +1175,12 @@ export class Session implements SessionContext { // Track cleanup functions for sub-agent event listeners let subAgentCleanupFunctions: Array<() => void> = []; + // Generate tool_use_id for hook tracking (aligned with core path) + const toolUseId = generateToolUseId(); + + // Get approval mode for hook context (defined outside try for catch block access) + const approvalMode = this.config.getApprovalMode(); + try { const invocation = tool.build(args); @@ -905,7 +1245,6 @@ export class Session implements SessionContext { const needsConfirmation = finalPermission === 'ask'; // ---- L5: ApprovalMode overrides ---- - const approvalMode = this.config.getApprovalMode(); const isPlanMode = approvalMode === ApprovalMode.PLAN; if (finalPermission === 'deny') { @@ -1107,6 +1446,41 @@ export class Session implements SessionContext { await this.toolCallEmitter.emitStart(startParams); } + // Fire PreToolUse hook (aligned with core path in coreToolScheduler.ts) + const hooksEnabledForTool = !this.config.getDisableAllHooks?.(); + const messageBusForTool = this.config.getMessageBus?.(); + const permissionMode = String(approvalMode); + + if (hooksEnabledForTool && messageBusForTool) { + const preHookResult = await firePreToolUseHook( + messageBusForTool, + fc.name, + args, + toolUseId, + permissionMode, + abortSignal, + ); + + if (!preHookResult.shouldProceed) { + // Hook blocked the tool execution - send notification to UI + const blockReason = + preHookResult.blockReason || 'Blocked by PreToolUse hook'; + await this.messageEmitter.emitAgentMessage( + `🚫 **PreToolUse blocked**: ${fc.name} - ${blockReason}`, + ); + return earlyErrorResponse(new Error(blockReason), fc.name); + } + + // Add additional context from PreToolUse hook if provided + // Note: This context would need to be passed to the tool invocation + // For now, we just log it as the tool execution proceeds + if (preHookResult.additionalContext) { + debugLogger.debug( + `PreToolUse hook additional context for ${fc.name}: ${preHookResult.additionalContext}`, + ); + } + } + const toolResult: ToolResult = await invocation.execute(abortSignal); // Clean up event listeners @@ -1119,6 +1493,61 @@ export class Session implements SessionContext { toolResult.llmContent, ); + // Fire PostToolUse hook on successful execution (aligned with core path) + if (hooksEnabledForTool && messageBusForTool && !toolResult.error) { + // Use the same response shape as core (llmContent/returnDisplay) + const toolResponse = { + llmContent: toolResult.llmContent, + returnDisplay: toolResult.returnDisplay, + }; + const postHookResult = await firePostToolUseHook( + messageBusForTool, + fc.name, + args, + toolResponse, + toolUseId, + permissionMode, + abortSignal, + ); + + // If hook indicates to stop, return an error response + if (postHookResult.shouldStop) { + const stopMessage = + postHookResult.stopReason || + 'Execution stopped by PostToolUse hook'; + debugLogger.info( + `PostToolUse hook requested stop for ${fc.name}: ${stopMessage}`, + ); + return earlyErrorResponse(new Error(stopMessage), fc.name); + } + + // Add additional context from PostToolUse hook if provided + if (postHookResult.additionalContext) { + // Append additional context to the tool response + const contextPart = { text: postHookResult.additionalContext }; + responseParts.push(contextPart); + } + } else if (hooksEnabledForTool && messageBusForTool && toolResult.error) { + // Fire PostToolUseFailure hook when tool returns an error (aligned with core path) + const failureHookResult = await firePostToolUseFailureHook( + messageBusForTool, + toolUseId, + fc.name ?? 'unknown_tool', + args, + toolResult.error.message, + false, // not an interrupt + permissionMode, + abortSignal, + ); + + // Log additional context if provided + if (failureHookResult.additionalContext) { + debugLogger.debug( + `PostToolUseFailure hook additional context for ${fc.name}: ${failureHookResult.additionalContext}`, + ); + } + } + // Handle TodoWriteTool: extract todos and send plan update if (isTodoWriteTool) { const todos = this.planEmitter.extractTodos( @@ -1183,6 +1612,31 @@ export class Session implements SessionContext { const error = e instanceof Error ? e : new Error(String(e)); + // Fire PostToolUseFailure hook (aligned with core path in coreToolScheduler.ts) + const hooksEnabledForError = !this.config.getDisableAllHooks?.(); + const messageBusForError = this.config.getMessageBus?.(); + const isInterrupt = abortSignal.aborted; + + if (hooksEnabledForError && messageBusForError) { + const failureHookResult = await firePostToolUseFailureHook( + messageBusForError, + toolUseId, + fc.name ?? 'unknown_tool', + args, + error.message, + isInterrupt, + String(approvalMode), + abortSignal, + ); + + // Log additional context if provided + if (failureHookResult.additionalContext) { + debugLogger.debug( + `PostToolUseFailure hook additional context for ${fc.name}: ${failureHookResult.additionalContext}`, + ); + } + } + // Use ToolCallEmitter for error handling await this.toolCallEmitter.emitError( callId, diff --git a/packages/cli/src/acp-integration/session/emitters/MessageEmitter.ts b/packages/cli/src/acp-integration/session/emitters/MessageEmitter.ts index c4e0b971c..3a92c1131 100644 --- a/packages/cli/src/acp-integration/session/emitters/MessageEmitter.ts +++ b/packages/cli/src/acp-integration/session/emitters/MessageEmitter.ts @@ -17,6 +17,31 @@ import { BaseEmitter } from './BaseEmitter.js'; * normal flow, history replay, or other sources. */ export class MessageEmitter extends BaseEmitter { + /** + * Emits a StopHookLoop event when Stop hooks create a loop. + * This informs the client that Stop hooks have been executed multiple times. + * + * @param iterationCount - The current iteration count + * @param reasons - Array of reasons from each Stop hook execution + * @param stopHookCount - Number of Stop hooks that were executed + */ + async emitStopHookLoop( + iterationCount: number, + reasons: string[], + stopHookCount: number, + ): Promise { + await this.sendUpdate({ + sessionUpdate: 'agent_message_chunk', + content: { type: 'text', text: '' }, + _meta: { + stopHookLoop: { + iterationCount, + reasons, + stopHookCount, + }, + }, + }); + } /** * Emits a user message chunk. * diff --git a/packages/cli/src/utils/startupProfiler.test.ts b/packages/cli/src/utils/startupProfiler.test.ts index 0fb4bebfe..adfb4bf99 100644 --- a/packages/cli/src/utils/startupProfiler.test.ts +++ b/packages/cli/src/utils/startupProfiler.test.ts @@ -112,9 +112,7 @@ describe('startupProfiler', () => { it('should write JSON file on finalize and print path to stderr', () => { vi.mocked(fs.mkdirSync).mockReturnValue(undefined); vi.mocked(fs.writeFileSync).mockReturnValue(undefined); - const stderrSpy = vi - .spyOn(process.stderr, 'write') - .mockReturnValue(true); + const stderrSpy = vi.spyOn(process.stderr, 'write').mockReturnValue(true); initStartupProfiler(); profileCheckpoint('main_entry'); @@ -200,9 +198,7 @@ describe('startupProfiler', () => { vi.mocked(fs.mkdirSync).mockImplementation(() => { throw new Error('Permission denied'); }); - const stderrSpy = vi - .spyOn(process.stderr, 'write') - .mockReturnValue(true); + const stderrSpy = vi.spyOn(process.stderr, 'write').mockReturnValue(true); initStartupProfiler(); profileCheckpoint('test'); diff --git a/packages/cli/src/utils/startupProfiler.ts b/packages/cli/src/utils/startupProfiler.ts index 3c3cb42f5..d9884bf84 100644 --- a/packages/cli/src/utils/startupProfiler.ts +++ b/packages/cli/src/utils/startupProfiler.ts @@ -122,9 +122,7 @@ export function finalizeStartupProfile(sessionId?: string): void { fs.writeFileSync(filepath, JSON.stringify(report, null, 2), 'utf-8'); process.stderr.write(`Startup profile written to: ${filepath}\n`); } catch { - process.stderr.write( - 'Warning: Failed to write startup profile report\n', - ); + process.stderr.write('Warning: Failed to write startup profile report\n'); } } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 2708890b6..a806112cc 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -258,6 +258,17 @@ export * from './utils/yaml-parser.js'; export * from './qwen/qwenOAuth2.js'; +// ============================================================================ +// Message Bus Types +// ============================================================================ + +export { + MessageBusType, + type HookExecutionRequest, + type HookExecutionResponse, +} from './confirmation-bus/types.js'; +export { MessageBus } from './confirmation-bus/message-bus.js'; + // ============================================================================ // Testing Utilities // ============================================================================ @@ -272,10 +283,19 @@ export * from './test-utils/index.js'; export * from './hooks/types.js'; export { HookSystem, HookRegistry } from './hooks/index.js'; export type { HookRegistryEntry } from './hooks/index.js'; +export { type StopFailureErrorType } from './hooks/types.js'; -// Export hook triggers for notification hooks +// Export hook triggers for all hook events export { fireNotificationHook, firePermissionRequestHook, + firePreToolUseHook, + firePostToolUseHook, + firePostToolUseFailureHook, type NotificationHookResult, + type PermissionRequestHookResult, + type PreToolUseHookResult, + type PostToolUseHookResult, + type PostToolUseFailureHookResult, + generateToolUseId, } from './core/toolHookTriggers.js';