diff --git a/packages/channels/weixin/src/WeixinAdapter.ts b/packages/channels/weixin/src/WeixinAdapter.ts index 561a366f4..0660c9a88 100644 --- a/packages/channels/weixin/src/WeixinAdapter.ts +++ b/packages/channels/weixin/src/WeixinAdapter.ts @@ -17,14 +17,19 @@ import type { import { loadAccount, DEFAULT_BASE_URL } from './accounts.js'; import { startPollLoop, getContextToken } from './monitor.js'; import type { CdnRef, FileCdnRef } from './monitor.js'; -import { sendText } from './send.js'; +import { sendText, sendImage, detectImageMime } from './send.js'; import { downloadAndDecrypt } from './media.js'; -import { getConfig, sendTyping } from './api.js'; +import { getConfig, sendTyping, WeixinApiError } from './api.js'; import { TypingStatus } from './types.js'; /** In-memory typing ticket cache: userId -> typingTicket */ const typingTickets = new Map(); +/** Escape special regex characters in a string. */ +function escapeRegex(s: string): string { + return s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); +} + export class WeixinChannel extends ChannelBase { private abortController: AbortController | null = null; private baseUrl: string; @@ -43,6 +48,35 @@ export class WeixinChannel extends ChannelBase { } async connect(): Promise { + // Default channel instructions — always include image capability info + const imageInstructions = [ + '', + 'If you created an image file (screenshot, chart, etc.), you can send it to the user by writing:', + '[IMAGE: /absolute/path/to/file.png]', + '', + 'The marker is stripped from text and the image is uploaded automatically.', + '', + 'CRITICAL: Only use real file paths. Do NOT write [IMAGE: ...] with:', + '- Example paths like /path/to/file or /tmp/cat.png', + '- Placeholder symbols like ...', + "- Paths that don't exist on disk", + ].join('\n'); + + if (!this.config.instructions) { + this.config.instructions = [ + '## WeChat Channel', + '', + 'You are a concise coding assistant responding via WeChat.', + 'Keep responses under 500 characters. Use plain text only.', + '', + 'Users can also send you images.', + imageInstructions, + ].join('\n'); + } else if (!this.config.instructions.includes('[IMAGE:')) { + // Use a local copy to avoid mutating this.config.instructions on reconnect. + this.config.instructions = + this.config.instructions + '\n' + imageInstructions; + } const account = loadAccount(); if (!account) { throw new Error( @@ -158,13 +192,84 @@ export class WeixinChannel extends ChannelBase { async sendMessage(chatId: string, text: string): Promise { const contextToken = getContextToken(chatId) || ''; - await sendText({ - to: chatId, - text, - baseUrl: this.baseUrl, - token: this.token, - contextToken, - }); + + // Parse [IMAGE: /path/to/file.png] markers from text. + // Strip code blocks first to avoid matching example syntax inside them. + const textWithoutCode = text + .replace(/```[\s\S]*?```/g, '') + .replace(/`[^`]*`/g, ''); + + // Extract image paths from code-free text. + const imageRegex = /\[IMAGE:\s*([^\]]+)\]/gi; + const parsedImages: string[] = []; + for (const m of textWithoutCode.matchAll(imageRegex)) { + const trimmed = m[1]?.trim(); + if (trimmed) parsedImages.push(trimmed); + } + + // Only strip markers that were actually parsed (avoids silently + // removing [IMAGE:] inside code blocks from the displayed text). + let cleanedText = text; + for (const path of parsedImages) { + cleanedText = cleanedText.replace( + new RegExp(`\\[IMAGE:\\s*${escapeRegex(path)}\\]`, 'gi'), + '', + ); + } + + // Clean up double blank lines left by removed markers + cleanedText = cleanedText.replace(/\n{3,}/g, '\n\n').trim(); + + // Send text first if non-empty + if (cleanedText) { + await sendText({ + to: chatId, + text: cleanedText, + baseUrl: this.baseUrl, + token: this.token, + contextToken, + }); + } + + // Send images + if (parsedImages.length) { + const workspaceDirs = [this.config.cwd]; + + for (const imagePath of parsedImages) { + try { + await sendImage({ + to: chatId, + imagePath, + baseUrl: this.baseUrl, + token: this.token, + contextToken, + workspaceDirs, + }); + } catch (err) { + const status = err instanceof WeixinApiError ? err.status : 0; + const ret = err instanceof WeixinApiError ? err.ret : undefined; + const errcode = + err instanceof WeixinApiError ? err.errcode : undefined; + const msg = err instanceof Error ? err.message : String(err); + process.stderr.write( + `[Weixin:${this.name}] Failed to send image (status=${status} ret=${ret} errcode=${errcode}): ${msg}\n`, + ); + try { + await sendText({ + to: chatId, + text: '图片发送失败,请稍后重试', + baseUrl: this.baseUrl, + token: this.token, + contextToken, + }); + } catch (fallbackErr) { + process.stderr.write( + `[Weixin:${this.name}] Fallback text also failed: ${fallbackErr instanceof Error ? fallbackErr.message : String(fallbackErr)}\n`, + ); + } + } + } + } } disconnect(): void { @@ -202,27 +307,3 @@ export class WeixinChannel extends ChannelBase { } } } - -/** Detect image MIME type from magic bytes. */ -function detectImageMime(data: Buffer): string { - if ( - data[0] === 0x89 && - data[1] === 0x50 && - data[2] === 0x4e && - data[3] === 0x47 - ) { - return 'image/png'; - } - if (data[0] === 0x47 && data[1] === 0x49 && data[2] === 0x46) { - return 'image/gif'; - } - if ( - data[0] === 0x52 && - data[1] === 0x49 && - data[2] === 0x46 && - data[3] === 0x46 - ) { - return 'image/webp'; - } - return 'image/jpeg'; -} diff --git a/packages/channels/weixin/src/api.ts b/packages/channels/weixin/src/api.ts index 93ccbf4b6..5097b978e 100644 --- a/packages/channels/weixin/src/api.ts +++ b/packages/channels/weixin/src/api.ts @@ -12,6 +12,71 @@ import type { BaseInfo, } from './types.js'; +// ── Error handling ──────────────────────────────────────────────── + +/** Structured error from WeChat iLink Bot API. */ +export class WeixinApiError extends Error { + /** HTTP status code (0 if network/timeout error). */ + status: number; + /** API-level return code (ret field in response body). */ + ret?: number; + /** API-level error code (errcode field in response body). */ + errcode?: number; + + constructor(message: string, status: number, ret?: number, errcode?: number) { + super(message); + this.name = 'WeixinApiError'; + this.status = status; + this.ret = ret; + this.errcode = errcode; + } +} + +/** Errors that are safe to retry (transient / network). */ +function isRetryableError(err: unknown): boolean { + if (err instanceof WeixinApiError) { + // Session expired — not retryable (needs re-login) + if (err.errcode === -14) return false; + // API-level transient errors (system busy, rate limit) + if (err.errcode === -1 || err.errcode === 45011) return true; + // ret field is used by getUploadUrl and other endpoints + if (err.ret !== undefined && err.ret !== 0) return false; + // Client errors (4xx except 429) — not retryable + if (err.status >= 400 && err.status < 500) return err.status === 429; + // Server errors (5xx) or network errors (status 0) — retryable + return err.status === 0 || err.status >= 500; + } + if (err instanceof TypeError || (err as NodeJS.ErrnoException).code) { + // Network errors (fetch TypeError, ECONNRESET, ETIMEDOUT, etc.) + return true; + } + return false; +} + +/** Exponential backoff retry wrapper. */ +async function retryWithBackoff( + fn: (attempt: number) => Promise, + maxRetries = 3, + baseDelayMs = 1000, +): Promise { + let lastError: unknown; + + for (let attempt = 1; attempt <= maxRetries + 1; attempt++) { + try { + return await fn(attempt); + } catch (err: unknown) { + lastError = err; + if (attempt > maxRetries || !isRetryableError(err)) { + throw err; + } + const delay = baseDelayMs * Math.pow(2, attempt - 1); + await new Promise((r) => setTimeout(r, delay)); + } + } + + throw lastError; +} + // iLink Bot API protocol version we are compatible with. // Used both in the request body (base_info.channel_version) and in the // iLink-App-ClientVersion header (encoded as 0x00MMNNPP). @@ -74,7 +139,26 @@ async function post( signal: controller.signal, }); if (!resp.ok) { - throw new Error(`HTTP ${resp.status}: ${resp.statusText}`); + // Try to parse the API error body for ret/errcode/errmsg + let ret: number | undefined; + let errcode: number | undefined; + let errmsg: string | undefined; + try { + const errBody = (await resp.json()) as { + ret?: number; + errcode?: number; + errmsg?: string; + }; + ret = errBody.ret; + errcode = errBody.errcode; + errmsg = errBody.errmsg; + } catch { + // ignore parse errors — use status-based message + } + const message = errmsg + ? `WeChat API error (HTTP ${resp.status}, ret=${ret}, errcode=${errcode}): ${errmsg}` + : `WeChat API error (HTTP ${resp.status})`; + throw new WeixinApiError(message, resp.status, ret, errcode); } return (await resp.json()) as T; } finally { @@ -116,7 +200,25 @@ export async function sendMessage( msg: SendMessageReq['msg'], ): Promise { const body: SendMessageReq = { msg, base_info: baseInfo() }; - await post(baseUrl, '/ilink/bot/sendmessage', body, token); + + await retryWithBackoff(async (_attempt) => { + const resp = await post<{ + ret?: number; + errcode?: number; + errmsg?: string; + }>(baseUrl, '/ilink/bot/sendmessage', body, token); + if ( + (resp.ret !== undefined && resp.ret !== 0) || + (resp.errcode !== undefined && resp.errcode !== 0) + ) { + throw new WeixinApiError( + `sendMessage failed: ret=${resp.ret} errcode=${resp.errcode} ${resp.errmsg || ''}`, + 200, + resp.ret, + resp.errcode, + ); + } + }); } export async function getConfig( @@ -141,3 +243,166 @@ export async function sendTyping( const body: SendTypingReq = { ...req, base_info: baseInfo() }; return post(baseUrl, '/ilink/bot/sendtyping', body, token); } + +interface GetUploadUrlReq { + filekey: string; + media_type: number; + to_user_id: string; + rawsize: number; + rawfilemd5: string; + filesize: number; + no_need_thumb: boolean; + aeskey: string; + base_info: BaseInfo; +} + +interface GetUploadUrlResp { + ret?: number; + errcode?: number; + errmsg?: string; + upload_full_url?: string; + upload_param?: string; + thumb_upload_param?: string; +} + +/** + * Request an upload URL and CDN credentials for media. + * @param aeskeyHex 16-byte AES key as 32-char hex string (e.g. "00112233445566778899aabbccddeeff") + * @returns Either the full CDN upload URL or the upload_param string + */ +export async function getUploadUrl( + baseUrl: string, + token: string, + toUserId: string, + filekey: string, + rawsize: number, + rawfilemd5: string, + encryptedSize: number, + aeskeyHex: string, +): Promise { + const body: GetUploadUrlReq = { + filekey, + media_type: 1, + to_user_id: toUserId, + rawsize, + rawfilemd5, + filesize: encryptedSize, + no_need_thumb: true, + aeskey: aeskeyHex, + base_info: baseInfo(), + }; + + return retryWithBackoff(async (_attempt) => { + const resp = await post( + baseUrl, + '/ilink/bot/getuploadurl', + body, + token, + ); + + // Check API-level error first + if ( + (resp.ret !== undefined && resp.ret !== 0) || + (resp.errcode !== undefined && resp.errcode !== 0) + ) { + throw new WeixinApiError( + `getuploadurl failed: ret=${resp.ret} errcode=${resp.errcode ?? '(none)'} errmsg=${resp.errmsg || '(none)'}`, + 200, + resp.ret, + resp.errcode, + ); + } + + // upload_full_url: CDN upload URL with all params embedded + if (resp.upload_full_url) { + return resp.upload_full_url; + } + + // upload_param: CDN upload params only (must construct URL with filekey) + if (resp.upload_param) { + return resp.upload_param; + } + + throw new WeixinApiError( + `getuploadurl returned no URL: ret=${resp.ret} errcode=${resp.errcode ?? '(none)'} errmsg=${resp.errmsg || '(none)'}`, + 200, + resp.ret, + resp.errcode, + ); + }); +} + +/** Upload encrypted media to CDN. + * If urlOrParam is a full URL, use it directly (host must match). + * If it's just a param, construct the URL. */ +export async function uploadToCdn( + urlOrParam: string, + filekey: string, + encryptedData: Buffer, +): Promise { + const CDN_HOST = 'novac2c.cdn.weixin.qq.com'; + + let url: string; + if (urlOrParam.startsWith('https://')) { + const parsed = new URL(urlOrParam); + if (parsed.hostname !== CDN_HOST) { + throw new Error(`CDN upload URL has unexpected host: ${parsed.hostname}`); + } + url = urlOrParam; + } else if (urlOrParam.startsWith('http://')) { + throw new Error('CDN upload URL must use HTTPS'); + } else { + url = `https://${CDN_HOST}/c2c/upload?encrypted_query_param=${encodeURIComponent(urlOrParam)}&filekey=${encodeURIComponent(filekey)}`; + } + + return retryWithBackoff(async (_attempt) => { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 40000); + + try { + const resp = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/octet-stream' }, + body: encryptedData, + signal: controller.signal, + }); + if (!resp.ok) { + // Try to extract error details from CDN response + let cdnErrMsg: string | undefined; + let cdnRet: number | undefined; + let cdnErrCode: number | undefined; + try { + const errBody = (await resp.json()) as { + errmsg?: string; + ret?: number; + errcode?: number; + }; + cdnErrMsg = errBody.errmsg; + cdnRet = errBody.ret; + cdnErrCode = errBody.errcode; + } catch { + // ignore + } + throw new WeixinApiError( + cdnErrMsg + ? `CDN upload failed: HTTP ${resp.status} — ${cdnErrMsg}` + : `CDN upload failed: HTTP ${resp.status}`, + resp.status, + cdnRet, + cdnErrCode, + ); + } + // Extract x-encrypted-param from response header + const encryptParam = resp.headers.get('x-encrypted-param'); + if (!encryptParam) { + throw new WeixinApiError( + 'CDN upload succeeded but missing x-encrypted-param header', + resp.status, + ); + } + return encryptParam; + } finally { + clearTimeout(timeout); + } + }); +} diff --git a/packages/channels/weixin/src/media.test.ts b/packages/channels/weixin/src/media.test.ts index 745c01554..28b42467a 100644 --- a/packages/channels/weixin/src/media.test.ts +++ b/packages/channels/weixin/src/media.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect } from 'vitest'; import { createDecipheriv, createCipheriv } from 'node:crypto'; +import { encryptAesEcb, computeMd5 } from './media.js'; /** * Test the AES key parsing and decryption logic used in media.ts. @@ -89,3 +90,53 @@ describe('Weixin media crypto', () => { }); }); }); + +describe('encryptAesEcb', () => { + it('encrypts data deterministically', () => { + const key = Buffer.alloc(16, 0xab); + const plaintext = Buffer.from('test data for encryption'); + const ciphertext1 = encryptAesEcb(plaintext, key); + const ciphertext2 = encryptAesEcb(plaintext, key); + expect(ciphertext1).toEqual(ciphertext2); + }); + + it('encrypts then decrypts round-trip', () => { + const key = Buffer.alloc(16, 0x42); + const plaintext = Buffer.from('Hello, WeChat media upload!'); + + const ciphertext = encryptAesEcb(plaintext, key); + const decipher = createDecipheriv('aes-128-ecb', key, null); + const decrypted = Buffer.concat([ + decipher.update(ciphertext), + decipher.final(), + ]); + expect(decrypted.toString()).toBe(plaintext.toString()); + }); + + it('handles empty plaintext', () => { + const key = Buffer.alloc(16, 0x01); + const ciphertext = encryptAesEcb(Buffer.alloc(0), key); + // ECB with empty input produces empty output (no padding block needed + // when input is exactly 0 bytes — behavior varies by implementation) + // At minimum the result should be decryptable + const decipher = createDecipheriv('aes-128-ecb', key, null); + const decrypted = Buffer.concat([ + decipher.update(ciphertext), + decipher.final(), + ]); + expect(decrypted.length).toBe(0); + }); +}); + +describe('computeMd5', () => { + it('computes expected MD5', () => { + const data = Buffer.from('hello world'); + expect(computeMd5(data)).toBe('5eb63bbbe01eeed093cb22bb8f5acdc3'); + }); + + it('computes MD5 of empty buffer', () => { + expect(computeMd5(Buffer.alloc(0))).toBe( + 'd41d8cd98f00b204e9800998ecf8427e', + ); + }); +}); diff --git a/packages/channels/weixin/src/media.ts b/packages/channels/weixin/src/media.ts index 8cd7fa9eb..93dcd35fb 100644 --- a/packages/channels/weixin/src/media.ts +++ b/packages/channels/weixin/src/media.ts @@ -3,7 +3,7 @@ * Ported from cc-weixin/plugins/weixin/src/media.ts (download path only). */ -import { createDecipheriv } from 'node:crypto'; +import { createCipheriv, createDecipheriv, createHash } from 'node:crypto'; const CDN_BASE_URL = 'https://novac2c.cdn.weixin.qq.com/c2c'; @@ -22,7 +22,7 @@ function decryptAesEcb(ciphertext: Buffer, key: Buffer): Buffer { * - base64(raw 16 bytes) → images * - base64(hex string of 16 bytes) → file/voice/video */ -function parseAesKey(aesKeyBase64: string): Buffer { +export function parseAesKey(aesKeyBase64: string): Buffer { const decoded = Buffer.from(aesKeyBase64, 'base64'); if (decoded.length === 16) { return decoded; @@ -45,12 +45,30 @@ export async function downloadAndDecrypt( ): Promise { const url = buildCdnDownloadUrl(encryptQueryParam); - const resp = await fetch(url); - if (!resp.ok) { - throw new Error(`CDN download failed: HTTP ${resp.status}`); - } + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 40000); - const ciphertext = Buffer.from(await resp.arrayBuffer()); - const keyBuf = parseAesKey(aesKey); - return decryptAesEcb(ciphertext, keyBuf); + try { + const resp = await fetch(url, { signal: controller.signal }); + if (!resp.ok) { + throw new Error(`CDN download failed: HTTP ${resp.status}`); + } + + const ciphertext = Buffer.from(await resp.arrayBuffer()); + const keyBuf = parseAesKey(aesKey); + return decryptAesEcb(ciphertext, keyBuf); + } finally { + clearTimeout(timeout); + } +} + +/** AES-128-ECB encryption for CDN upload. */ +export function encryptAesEcb(plaintext: Buffer, key: Buffer): Buffer { + const cipher = createCipheriv('aes-128-ecb', key, null); + return Buffer.concat([cipher.update(plaintext), cipher.final()]); +} + +/** Compute MD5 hash of a buffer, returning hex string. */ +export function computeMd5(data: Buffer): string { + return createHash('md5').update(data).digest('hex'); } diff --git a/packages/channels/weixin/src/send.test.ts b/packages/channels/weixin/src/send.test.ts index 95152672c..3d3c275c8 100644 --- a/packages/channels/weixin/src/send.test.ts +++ b/packages/channels/weixin/src/send.test.ts @@ -1,6 +1,78 @@ -import { describe, it, expect } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import * as fs from 'node:fs'; import { markdownToPlainText } from './send.js'; +const { + mockReadFileSync, + mockStatSync, + mockRealpathSync, + mockGetUploadUrl, + mockUploadToCdn, + mockSendMessage, + mockRandomBytes, +} = vi.hoisted(() => ({ + mockReadFileSync: vi.fn(), + mockStatSync: vi.fn(), + mockRealpathSync: vi.fn((p: string) => p), + mockGetUploadUrl: vi.fn(), + mockUploadToCdn: vi.fn(), + mockSendMessage: vi.fn(), + mockRandomBytes: vi.fn((size: number) => Buffer.alloc(size, 0x42)), +})); + +// PNG magic bytes: 89 50 4E 47 0D 0A 1A 0A +const PNG_HEADER = Buffer.from([ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, +]); + +vi.mock('node:os', () => ({ + tmpdir: () => '/tmp', +})); + +vi.mock('node:fs', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + readFileSync: mockReadFileSync, + statSync: mockStatSync, + realpathSync: mockRealpathSync, + openSync: vi.fn(() => 42), + readSync: vi.fn((_fd: number, buf: Buffer) => { + PNG_HEADER.copy(buf); + return PNG_HEADER.length; + }), + closeSync: vi.fn(), + }; +}); + +vi.mock('node:path', async (importOriginal) => { + const actual = await importOriginal(); + return { ...actual }; +}); + +vi.mock('node:crypto', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + randomBytes: mockRandomBytes, + randomUUID: () => 'test-uuid', + }; +}); + +vi.mock('./api.js', () => ({ + sendMessage: mockSendMessage, + getUploadUrl: mockGetUploadUrl, + uploadToCdn: mockUploadToCdn, +})); + +// Use real encryptAesEcb / computeMd5 so tests catch padding mismatches. +const { encryptAesEcb, computeMd5 } = + await vi.importActual('./media.js'); + +const { sendImage, detectImageMime, validateImagePath } = await import( + './send.js' +); + describe('markdownToPlainText', () => { it('strips code blocks', () => { const input = '```js\nconst x = 1;\n```'; @@ -80,3 +152,218 @@ describe('markdownToPlainText', () => { expect(result).not.toContain('`'); }); }); + +describe('detectImageMime', () => { + it('detects PNG magic bytes', () => { + const buf = Buffer.from([0x89, 0x50, 0x4e, 0x47]); + expect(detectImageMime(buf)).toBe('image/png'); + }); + + it('detects GIF magic bytes', () => { + const buf = Buffer.from([0x47, 0x49, 0x46]); + expect(detectImageMime(buf)).toBe('image/gif'); + }); + + it('detects WebP magic bytes (RIFF)', () => { + const buf = Buffer.from([0x52, 0x49, 0x46, 0x46]); + expect(detectImageMime(buf)).toBe('image/webp'); + }); + + it('detects JPEG magic bytes', () => { + const buf = Buffer.from([0xff, 0xd8, 0xff]); + expect(detectImageMime(buf)).toBe('image/jpeg'); + }); + + it('throws for unrecognized magic bytes', () => { + const buf = Buffer.from([0x00, 0x00, 0x00, 0x00]); + expect(() => detectImageMime(buf)).toThrow('Unrecognized image format'); + }); +}); + +describe('validateImagePath', () => { + const workspaceDirs = ['/home/user/project']; + + beforeEach(() => { + vi.clearAllMocks(); + // Restore default mock behaviour: identity pass-through for realpath, + // regular file, small size, PNG magic in readSync. + mockRealpathSync.mockImplementation((p: string) => p); + mockStatSync.mockReturnValue({ + isFile: () => true, + size: 100, + } as unknown as ReturnType<(typeof fs)['statSync']>); + vi.mocked(fs.readSync).mockImplementation((_fd: number, buf: Buffer) => { + PNG_HEADER.copy(buf); + return PNG_HEADER.length; + }); + }); + + it('rejects disallowed extensions', () => { + expect(() => + validateImagePath('/tmp/screenshot.txt', workspaceDirs), + ).toThrow('Image extension not allowed'); + }); + + it('rejects non-existent files', () => { + mockRealpathSync.mockImplementation(() => { + throw new Error('ENOENT: no such file'); + }); + expect(() => validateImagePath('/tmp/missing.png', workspaceDirs)).toThrow( + 'Image file not found', + ); + }); + + it('rejects non-regular files (directories etc.)', () => { + mockStatSync.mockReturnValue({ + isFile: () => false, + size: 0, + } as unknown as ReturnType<(typeof fs)['statSync']>); + expect(() => validateImagePath('/tmp/some-dir.png', workspaceDirs)).toThrow( + 'Not a regular file', + ); + }); + + it('rejects files exceeding 20 MB cap', () => { + mockStatSync.mockReturnValue({ + isFile: () => true, + size: 21 * 1024 * 1024, + } as unknown as ReturnType<(typeof fs)['statSync']>); + expect(() => validateImagePath('/tmp/huge.png', workspaceDirs)).toThrow( + 'Image too large', + ); + }); + + it('rejects paths outside allowed directories', () => { + mockRealpathSync.mockImplementation((p: string) => p); + expect(() => validateImagePath('/etc/passwd.png', workspaceDirs)).toThrow( + 'Image path outside allowed directories', + ); + }); + + it('rejects image with magic bytes that do not match extension', () => { + // readSync returns JPEG magic, but file extension is .png + vi.mocked(fs.readSync).mockImplementation((_fd: number, buf: Buffer) => { + const jpegMagic = Buffer.from([0xff, 0xd8, 0xff]); + jpegMagic.copy(buf); + return jpegMagic.length; + }); + expect(() => + validateImagePath('/tmp/actually-jpeg.png', workspaceDirs), + ).toThrow('Image type mismatch'); + }); + + it('returns resolved realpath on success', () => { + mockRealpathSync.mockImplementation((p: string) => `/private${p}`); + const result = validateImagePath('/tmp/photo.png', workspaceDirs); + expect(result).toBe('/private/tmp/photo.png'); + }); +}); + +describe('sendImage', () => { + const defaultParams = { + to: 'user-123', + imagePath: '/tmp/test.png', + baseUrl: 'https://api.example.com', + token: 'token-abc', + contextToken: 'ctx-456', + workspaceDirs: ['/home/user/project'], + }; + + const fakeImageData = Buffer.concat([ + PNG_HEADER, + Buffer.from('fake-image-bytes'), + ]); + + beforeEach(() => { + vi.clearAllMocks(); + // statSync: must be a regular file under file size limit + mockStatSync.mockReturnValue({ + isFile: () => true, + size: fakeImageData.length, + } as unknown as ReturnType<(typeof import('node:fs'))['statSync']>); + // realpathSync: identity pass-through (restore default after + // validateImagePath tests may have overridden it). + mockRealpathSync.mockImplementation((p: string) => p); + // readFileSync: returns PNG-headed data for MIME check + full read + mockReadFileSync.mockReturnValue(fakeImageData); + }); + + it('completes the four-step upload and send flow', async () => { + mockGetUploadUrl.mockResolvedValue('upload-param-value'); + mockUploadToCdn.mockResolvedValue('cdn-encrypt-param'); + mockSendMessage.mockResolvedValue(undefined); + + await sendImage(defaultParams); + + // Step 1: validateImagePath uses openSync/readSync for magic-byte + // check (only 16 bytes), then sendImage calls readFileSync for + // full file read. + expect(mockReadFileSync).toHaveBeenCalledTimes(1); + expect(mockReadFileSync).toHaveBeenCalledWith('/tmp/test.png'); + + // Step 2: get upload URL called with correct params + const encryptedSize = Math.ceil((fakeImageData.length + 1) / 16) * 16; + const expectedFilekey = '42424242424242424242424242424242'; + const expectedAesKeyHex = '42424242424242424242424242424242'; + expect(mockGetUploadUrl).toHaveBeenCalledWith( + 'https://api.example.com', + 'token-abc', + 'user-123', + expectedFilekey, + fakeImageData.length, + computeMd5(fakeImageData), + encryptedSize, + expectedAesKeyHex, + ); + + // Step 3: upload to CDN (with real encryptAesEcb output) + const aesKeyBytes = Buffer.alloc(16, 0x42); + const expectedEncrypted = encryptAesEcb(fakeImageData, aesKeyBytes); + expect(mockUploadToCdn).toHaveBeenCalledWith( + 'upload-param-value', + expectedFilekey, + expectedEncrypted, + ); + + // Step 4: send message with image_item using CDN's x-encrypted-param + const expectedAesKeyBase64 = aesKeyBytes.toString('base64'); + expect(mockSendMessage).toHaveBeenCalledWith( + 'https://api.example.com', + 'token-abc', + expect.objectContaining({ + to_user_id: 'user-123', + context_token: 'ctx-456', + item_list: [ + expect.objectContaining({ + type: 2, // MessageItemType.IMAGE + image_item: expect.objectContaining({ + media: { + encrypt_query_param: 'cdn-encrypt-param', + aes_key: expectedAesKeyBase64, + encrypt_type: 1, + }, + }), + }), + ], + }), + ); + }); + + it('propagates getUploadUrl errors', async () => { + mockReadFileSync.mockReturnValue(fakeImageData); + mockGetUploadUrl.mockRejectedValue(new Error('Auth expired')); + + await expect(sendImage(defaultParams)).rejects.toThrow('Auth expired'); + expect(mockUploadToCdn).not.toHaveBeenCalled(); + expect(mockSendMessage).not.toHaveBeenCalled(); + }); + + it('propagates upload errors', async () => { + mockReadFileSync.mockReturnValue(fakeImageData); + mockGetUploadUrl.mockResolvedValue('upload-param-value'); + mockUploadToCdn.mockRejectedValue(new Error('CDN unavailable')); + + await expect(sendImage(defaultParams)).rejects.toThrow('CDN unavailable'); + expect(mockSendMessage).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/channels/weixin/src/send.ts b/packages/channels/weixin/src/send.ts index 54ca8fa52..27c8ab5fd 100644 --- a/packages/channels/weixin/src/send.ts +++ b/packages/channels/weixin/src/send.ts @@ -2,9 +2,20 @@ * Send messages to WeChat users. */ -import { randomUUID } from 'node:crypto'; -import { sendMessage } from './api.js'; +import { randomBytes, randomUUID } from 'node:crypto'; +import { + readFileSync, + statSync, + realpathSync, + openSync, + readSync, + closeSync, +} from 'node:fs'; +import { tmpdir } from 'node:os'; +import { resolve, extname } from 'node:path'; +import { sendMessage, getUploadUrl, uploadToCdn } from './api.js'; import { MessageType, MessageState, MessageItemType } from './types.js'; +import { encryptAesEcb, computeMd5 } from './media.js'; /** Convert markdown to plain text (WeChat doesn't support markdown) */ export function markdownToPlainText(text: string): string { @@ -29,6 +40,123 @@ export function markdownToPlainText(text: string): string { .trim(); } +// ── Image path validation ───────────────────────────────────────── + +const ALLOWED_EXTS = new Set(['.png', '.jpg', '.jpeg', '.gif', '.webp']); +const MAX_IMAGE_SIZE = 20 * 1024 * 1024; // 20 MB + +/** Image magic bytes → MIME type mapping. */ +export function detectImageMime(data: Buffer): string { + if ( + data[0] === 0x89 && + data[1] === 0x50 && + data[2] === 0x4e && + data[3] === 0x47 + ) { + return 'image/png'; + } + if (data[0] === 0x47 && data[1] === 0x49 && data[2] === 0x46) { + return 'image/gif'; + } + if ( + data[0] === 0x52 && + data[1] === 0x49 && + data[2] === 0x46 && + data[3] === 0x46 + ) { + return 'image/webp'; + } + if (data[0] === 0xff && data[1] === 0xd8 && data[2] === 0xff) { + return 'image/jpeg'; + } + throw new Error( + 'Unrecognized image format: magic bytes do not match any supported type', + ); +} + +/** + * Validate and resolve an image path before reading. + * + * Security: prevents AI-controlled [IMAGE: ...] markers from reading + * arbitrary files by enforcing directory allowlist, extension allowlist, + * size cap, and magic-byte verification. + * + * @param imagePath Raw path from the AI response. + * @param workspaceDirs Additional directories to allow (typically the cwd). + * @returns Resolved absolute realpath if valid. + */ +export function validateImagePath( + imagePath: string, + workspaceDirs: string[] = [], +): string { + const resolved = resolve(imagePath); + const ext = extname(resolved).toLowerCase(); + + if (!ALLOWED_EXTS.has(ext)) { + throw new Error(`Image extension not allowed: ${ext} (path: ${resolved})`); + } + + const real: string = (() => { + try { + return realpathSync(resolved); + } catch { + throw new Error(`Image file not found: ${resolved}`); + } + })(); + + const st = statSync(real); + if (!st.isFile()) { + throw new Error(`Not a regular file: ${real}`); + } + if (st.size > MAX_IMAGE_SIZE) { + throw new Error( + `Image too large: ${st.size} bytes (max ${MAX_IMAGE_SIZE})`, + ); + } + + // Build the allowlist: /tmp/ (and macOS real /private/tmp/), os.tmpdir(), + // plus workspace directories passed by the caller. Use realpathSync to + // resolve symlinks (e.g. /tmp → /private/tmp on macOS). + const ALLOWED_DIRS = [ + '/tmp/', + realpathSync('/tmp/') + '/', + tmpdir() + '/', + realpathSync(tmpdir()) + '/', + ...workspaceDirs.map((d) => realpathSync(resolve(d)) + '/'), + ]; + + if (!ALLOWED_DIRS.some((dir) => real.startsWith(dir))) { + throw new Error(`Image path outside allowed directories: ${real}`); + } + + // Verify magic bytes match the extension (read only first 16 bytes to + // avoid TOCTOU double-read — sendImage reads the full file later). + let fd: number | undefined; + try { + fd = openSync(real, 'r'); + const head = Buffer.alloc(16); + const bytesRead = readSync(fd, head, 0, 16, 0); + const mime = detectImageMime(head.slice(0, bytesRead)); + const extToExpectedMime: Record = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.webp': 'image/webp', + }; + const expected = extToExpectedMime[ext]; + if (mime !== expected) { + throw new Error( + `Image type mismatch: ext=${ext} expects ${expected} but got ${mime}`, + ); + } + } finally { + if (fd !== undefined) closeSync(fd); + } + + return real; +} + /** Send a text message */ export async function sendText(params: { to: string; @@ -50,3 +178,81 @@ export async function sendText(params: { item_list: [{ type: MessageItemType.TEXT, text_item: { text: plainText } }], }); } + +/** + * Send an image message via the four-step CDN upload flow: + * 1. Validate path + read file, compute rawsize + MD5; generate AES key + filekey + * 2. Request upload URL via getuploadurl + * 3. AES-128-ECB encrypt + POST upload to CDN; extract x-encrypted-param + * 4. Send message with image_item referencing the CDN media + */ +export async function sendImage(params: { + to: string; + imagePath: string; + baseUrl: string; + token: string; + contextToken: string; + /** Workspace directories to allow for image paths. */ + workspaceDirs?: string[]; +}): Promise { + const { to, imagePath, baseUrl, token, contextToken, workspaceDirs } = params; + + // Step 1 (security): validate and resolve the image path + const resolvedPath = validateImagePath(imagePath, workspaceDirs); + + // Step 1 (continued): read file, compute metadata + generate random identifiers + const fileBuffer = readFileSync(resolvedPath); + const rawsize = fileBuffer.length; + const rawfilemd5 = computeMd5(fileBuffer); + + // Generate random 16-byte AES key as hex string + const aesKeyBytes = randomBytes(16); + const aesKeyHex = aesKeyBytes.toString('hex'); + + // Generate random 32-char hex filekey + const filekey = randomBytes(16).toString('hex'); + + // AES-128-ECB PKCS#7 padding: encrypted size = ceil((rawsize + 1) / 16) * 16 + const encryptedSize = Math.ceil((rawsize + 1) / 16) * 16; + + // Step 2: get upload URL and CDN credentials + const uploadParam = await getUploadUrl( + baseUrl, + token, + to, + filekey, + rawsize, + rawfilemd5, + encryptedSize, + aesKeyHex, + ); + + // Step 3: encrypt and upload to CDN + const encrypted = encryptAesEcb(fileBuffer, aesKeyBytes); + const cdnEncryptParam = await uploadToCdn(uploadParam, filekey, encrypted); + + // Step 4: send message with image_item using CDN's x-encrypted-param + // aes_key: base64(raw 16 bytes) for images per protocol + const aesKeyBase64 = aesKeyBytes.toString('base64'); + + await sendMessage(baseUrl, token, { + to_user_id: to, + from_user_id: '', + client_id: randomUUID(), + message_type: MessageType.BOT, + message_state: MessageState.FINISH, + context_token: contextToken, + item_list: [ + { + type: MessageItemType.IMAGE, + image_item: { + media: { + encrypt_query_param: cdnEncryptParam, + aes_key: aesKeyBase64, + encrypt_type: 1, + }, + }, + }, + ], + }); +} diff --git a/packages/cli/src/acp-integration/session/Session.test.ts b/packages/cli/src/acp-integration/session/Session.test.ts index d324ee697..99f34d2eb 100644 --- a/packages/cli/src/acp-integration/session/Session.test.ts +++ b/packages/cli/src/acp-integration/session/Session.test.ts @@ -46,6 +46,20 @@ function createStreamWithChunks( })(); } +function expectCompressBeforeSend( + compressMock: ReturnType, + sendMock: ReturnType, + callIndex: number, +) { + expect(compressMock.mock.invocationCallOrder.length).toBeGreaterThan( + callIndex, + ); + expect(sendMock.mock.invocationCallOrder.length).toBeGreaterThan(callIndex); + expect(compressMock.mock.invocationCallOrder[callIndex]).toBeLessThan( + sendMock.mock.invocationCallOrder[callIndex], + ); +} + describe('Session', () => { let mockChat: GeminiChat; let mockConfig: Config; @@ -56,6 +70,10 @@ describe('Session', () => { let currentAuthType: AuthType; let switchModelSpy: ReturnType; let getAvailableCommandsSpy: ReturnType; + let mockGeminiClient: { + getChat: ReturnType; + tryCompressChat: ReturnType; + }; let mockToolRegistry: { getTool: ReturnType; ensureTool: ReturnType; @@ -75,6 +93,14 @@ describe('Session', () => { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), } as unknown as GeminiChat; + mockGeminiClient = { + getChat: vi.fn().mockReturnValue(mockChat), + tryCompressChat: vi.fn().mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: core.CompressionStatus.NOOP, + }), + }; mockToolRegistry = { getTool: vi.fn(), @@ -102,6 +128,7 @@ describe('Session', () => { recordUserMessage: vi.fn(), recordUiTelemetryEvent: vi.fn(), recordToolResult: vi.fn(), + recordSlashCommand: vi.fn(), }), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), // #buildInitialSystemReminders iterates listSubagents() on every @@ -117,9 +144,8 @@ describe('Session', () => { getDebugMode: vi.fn().mockReturnValue(false), getAuthType: vi.fn().mockImplementation(() => currentAuthType), isCronEnabled: vi.fn().mockReturnValue(false), - getGeminiClient: vi - .fn() - .mockReturnValue({ getChat: vi.fn().mockReturnValue(mockChat) }), + getSessionTokenLimit: vi.fn().mockReturnValue(0), + getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), } as unknown as Config; mockClient = { @@ -155,6 +181,7 @@ describe('Session', () => { mockConfig = undefined as unknown as Config; mockClient = undefined as unknown as AgentSideConnection; mockSettings = undefined as unknown as LoadedSettings; + mockGeminiClient = undefined as unknown as typeof mockGeminiClient; mockToolRegistry = undefined as unknown as typeof mockToolRegistry; vi.restoreAllMocks(); vi.clearAllTimers(); @@ -328,6 +355,1012 @@ describe('Session', () => { }); describe('prompt', () => { + describe('auto-compress', () => { + it('runs automatic compression before sending an ACP prompt', async () => { + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledWith( + 'test-session-id########1', + false, + expect.any(AbortSignal), + ); + + const sendMessageStream = mockChat.sendMessageStream as ReturnType< + typeof vi.fn + >; + expectCompressBeforeSend( + mockGeminiClient.tryCompressChat, + sendMessageStream, + 0, + ); + }); + + it('uses the current chat after automatic compression replaces it', async () => { + const compressedChat = { + sendMessageStream: vi.fn().mockResolvedValue(createEmptyStream()), + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + } as unknown as GeminiChat; + + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + mockGeminiClient.tryCompressChat.mockImplementation(async () => { + mockGeminiClient.getChat.mockReturnValue(compressedChat); + return { + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: core.CompressionStatus.COMPRESSED, + }; + }); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + expect(compressedChat.sendMessageStream).toHaveBeenCalledWith( + 'qwen3-code-plus', + { + message: expect.any(Array), + config: { abortSignal: expect.any(AbortSignal) }, + }, + 'test-session-id########1', + ); + }); + + it('emits an ACP-visible update when automatic compression succeeds', async () => { + mockGeminiClient.tryCompressChat.mockResolvedValueOnce({ + originalTokenCount: 1200, + newTokenCount: 450, + compressionStatus: core.CompressionStatus.COMPRESSED, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'IMPORTANT: This conversation approached the input token limit for qwen3-code-plus. ' + + 'A compressed context will be sent for future messages (compressed from: 1200 to 450 tokens).', + }, + }, + }); + }); + + it('continues sending when automatic compression fails', async () => { + mockGeminiClient.tryCompressChat.mockRejectedValueOnce( + new Error('compression rate limited'), + ); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledWith( + 'test-session-id########1', + false, + expect.any(AbortSignal), + ); + expect(mockChat.sendMessageStream).toHaveBeenCalledWith( + 'qwen3-code-plus', + { + message: expect.any(Array), + config: { abortSignal: expect.any(AbortSignal) }, + }, + 'test-session-id########1', + ); + }); + + it('does not use global UI telemetry when compression fails before local token counts exist', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + vi.spyOn( + core.uiTelemetryService, + 'getLastPromptTokenCount', + ).mockReturnValue(101); + mockGeminiClient.tryCompressChat.mockRejectedValueOnce( + new Error('compression rate limited'), + ); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(mockClient.sessionUpdate).not.toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'agent_message_chunk', + content: expect.objectContaining({ + text: expect.stringContaining('Session token limit exceeded'), + }), + }), + }), + ); + }); + + it('returns cancelled when automatic compression is aborted', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat.mockImplementation( + async (_promptId: string, _force: boolean, signal: AbortSignal) => + new Promise((_, reject) => { + signal.addEventListener('abort', () => { + const abortError = new Error('aborted'); + abortError.name = 'AbortError'; + reject(abortError); + }); + }), + ); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + const promptPromise = session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + await vi.waitFor(() => { + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalled(); + }); + + await session.cancelPendingPrompt(); + + await expect(promptPromise).resolves.toEqual({ + stopReason: 'cancelled', + }); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + expect(mockChat.addHistory).toHaveBeenCalledWith({ + role: 'user', + parts: expect.any(Array), + }); + expect(mockClient.sessionUpdate).not.toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'Session token limit exceeded: 101 tokens > 100 limit. ' + + 'Please start a new session or increase the sessionTokenLimit in your settings.json.', + }, + }, + }); + }); + + it('uses compression token info instead of global UI telemetry for the session limit', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + vi.spyOn( + core.uiTelemetryService, + 'getLastPromptTokenCount', + ).mockReturnValue(999); + mockGeminiClient.tryCompressChat.mockResolvedValueOnce({ + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: core.CompressionStatus.NOOP, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + }); + + it('falls back to the previous prompt token count when compression returns zero token info', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat.mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: core.CompressionStatus.NOOP, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + usageMetadata: { + totalTokenCount: 101, + promptTokenCount: 101, + }, + }, + }, + ]), + ) + .mockResolvedValueOnce(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'first' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'second' }], + }), + ).resolves.toEqual({ stopReason: 'max_tokens' }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + }); + + it('falls back to the previous prompt token count when compressed token info is zero', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat + .mockResolvedValueOnce({ + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: core.CompressionStatus.NOOP, + }) + .mockResolvedValueOnce({ + originalTokenCount: 1200, + newTokenCount: 0, + compressionStatus: core.CompressionStatus.COMPRESSED, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + usageMetadata: { + totalTokenCount: 101, + promptTokenCount: 101, + }, + }, + }, + ]), + ) + .mockResolvedValueOnce(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'first' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'second' }], + }), + ).resolves.toEqual({ stopReason: 'max_tokens' }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + }); + + it('records prompt token count instead of total token count for later session-limit checks', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat + .mockResolvedValueOnce({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: core.CompressionStatus.NOOP, + }) + .mockRejectedValueOnce(new Error('compression unavailable')); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + usageMetadata: { + totalTokenCount: 500, + promptTokenCount: 50, + }, + }, + }, + ]), + ) + .mockResolvedValueOnce(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'long response' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'next prompt' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + }); + + it('resets the session-local token count when the active chat instance changes', async () => { + const clearedChat = { + sendMessageStream: vi.fn().mockResolvedValue(createEmptyStream()), + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + } as unknown as GeminiChat; + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat + .mockResolvedValueOnce({ + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: core.CompressionStatus.NOOP, + }) + .mockRejectedValueOnce(new Error('compression unavailable')); + mockChat.sendMessageStream = vi.fn().mockResolvedValueOnce( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + usageMetadata: { + totalTokenCount: 500, + promptTokenCount: 101, + }, + }, + }, + ]), + ); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'before clear' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + + mockGeminiClient.getChat.mockReturnValue(clearedChat); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'after clear' }], + }), + ).resolves.toEqual({ stopReason: 'end_turn' }); + + expect(clearedChat.sendMessageStream).toHaveBeenCalledTimes(1); + }); + + it('continues sending when the compression notification fails', async () => { + mockGeminiClient.tryCompressChat.mockResolvedValueOnce({ + originalTokenCount: 1200, + newTokenCount: 450, + compressionStatus: core.CompressionStatus.COMPRESSED, + }); + mockClient.sessionUpdate = vi + .fn() + .mockRejectedValueOnce(new Error('client disconnected')); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + }); + + it('stops before sending when the compressed prompt exceeds the session token limit', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat.mockResolvedValueOnce({ + originalTokenCount: 1200, + newTokenCount: 101, + compressionStatus: core.CompressionStatus.COMPRESSED, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).resolves.toEqual({ stopReason: 'max_tokens' }); + + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalled(); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + expect(mockChat.addHistory).not.toHaveBeenCalled(); + expect(mockClient.sessionUpdate).not.toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'IMPORTANT: This conversation approached the input token limit for qwen3-code-plus. ' + + 'A compressed context will be sent for future messages (compressed from: 1200 to 101 tokens).', + }, + }, + }); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'Session token limit exceeded: 101 tokens > 100 limit. ' + + 'Please start a new session or increase the sessionTokenLimit in your settings.json.', + }, + }, + }); + }); + + it('stops without throwing when the token-limit diagnostic fails', async () => { + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat.mockResolvedValueOnce({ + originalTokenCount: 101, + newTokenCount: 101, + compressionStatus: core.CompressionStatus.NOOP, + }); + mockClient.sessionUpdate = vi + .fn() + .mockRejectedValueOnce(new Error('client disconnected')); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).resolves.toEqual({ stopReason: 'max_tokens' }); + + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + expect(mockChat.addHistory).not.toHaveBeenCalled(); + }); + + it('also runs automatic compression before tool response follow-up sends', async () => { + const executeSpy = vi.fn().mockResolvedValue({ + llmContent: 'file contents', + returnDisplay: 'file contents', + }); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + getDescription: vi.fn().mockReturnValue('Read file'), + toolLocations: vi.fn().mockReturnValue([]), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + mockConfig.getApprovalMode = vi.fn().mockReturnValue(ApprovalMode.YOLO); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ) + .mockResolvedValueOnce(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read file' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + 'test-session-id########1', + false, + expect.any(AbortSignal), + ); + + const sendMessageStream = mockChat.sendMessageStream as ReturnType< + typeof vi.fn + >; + expectCompressBeforeSend( + mockGeminiClient.tryCompressChat, + sendMessageStream, + 1, + ); + }); + + it('stops tool response follow-up before sending when the session token limit is exceeded', async () => { + const executeSpy = vi.fn().mockResolvedValue({ + llmContent: 'file contents', + returnDisplay: 'file contents', + }); + const tool = { + name: 'read_file', + kind: core.Kind.Read, + build: vi.fn().mockReturnValue({ + params: { path: '/tmp/test.txt' }, + getDefaultPermission: vi.fn().mockResolvedValue('allow'), + getDescription: vi.fn().mockReturnValue('Read file'), + toolLocations: vi.fn().mockReturnValue([]), + execute: executeSpy, + }), + }; + + mockToolRegistry.getTool.mockReturnValue(tool); + mockConfig.getApprovalMode = vi.fn().mockReturnValue(ApprovalMode.YOLO); + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat + .mockResolvedValueOnce({ + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: core.CompressionStatus.NOOP, + }) + .mockResolvedValueOnce({ + originalTokenCount: 101, + newTokenCount: 101, + compressionStatus: core.CompressionStatus.NOOP, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce( + createStreamWithChunks([ + { + type: core.StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'call-1', + name: 'read_file', + args: { path: '/tmp/test.txt' }, + }, + ], + }, + }, + ]), + ) + .mockResolvedValueOnce(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'read file' }], + }), + ).resolves.toEqual({ stopReason: 'max_tokens' }); + + expect(executeSpy).toHaveBeenCalledTimes(1); + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + 'test-session-id########1', + false, + expect.any(AbortSignal), + ); + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(mockChat.addHistory).toHaveBeenCalledWith({ + role: 'user', + parts: [ + expect.objectContaining({ + functionResponse: expect.objectContaining({ + id: 'call-1', + name: 'read_file', + }), + }), + ], + }); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'Session token limit exceeded: 101 tokens > 100 limit. ' + + 'Please start a new session or increase the sessionTokenLimit in your settings.json.', + }, + }, + }); + }); + + it('runs automatic compression before Stop-hook continuation sends', async () => { + const messageBus = { + request: vi + .fn() + .mockResolvedValueOnce({ + success: true, + output: { + decision: 'block', + reason: 'Continue after Stop hook', + }, + }) + .mockResolvedValueOnce({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi + .fn() + .mockImplementation((eventName: string) => eventName === 'Stop'); + mockChat.getHistory = vi + .fn() + .mockReturnValue([ + { role: 'model', parts: [{ text: 'response text' }] }, + ]); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockResolvedValueOnce(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + 'test-session-id########1_stop_hook_1', + false, + expect.any(AbortSignal), + ); + + const sendMessageStream = mockChat.sendMessageStream as ReturnType< + typeof vi.fn + >; + expectCompressBeforeSend( + mockGeminiClient.tryCompressChat, + sendMessageStream, + 1, + ); + }); + + it('skips automatic compression after the first Stop-hook continuation', async () => { + const messageBus = { + request: vi + .fn() + .mockResolvedValueOnce({ + success: true, + output: { + decision: 'block', + reason: 'Continue after first Stop hook', + }, + }) + .mockResolvedValueOnce({ + success: true, + output: { + decision: 'block', + reason: 'Continue after second Stop hook', + }, + }) + .mockResolvedValueOnce({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi + .fn() + .mockImplementation((eventName: string) => eventName === 'Stop'); + mockChat.getHistory = vi + .fn() + .mockReturnValue([ + { role: 'model', parts: [{ text: 'response text' }] }, + ]); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockResolvedValueOnce(createEmptyStream()) + .mockResolvedValueOnce(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(3); + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + 'test-session-id########1_stop_hook_1', + false, + expect.any(AbortSignal), + ); + expect(mockGeminiClient.tryCompressChat).not.toHaveBeenCalledWith( + 'test-session-id########1_stop_hook_2', + false, + expect.any(AbortSignal), + ); + + const sendMessageStream = mockChat.sendMessageStream as ReturnType< + typeof vi.fn + >; + expect(sendMessageStream.mock.calls[2]?.[2]).toBe( + 'test-session-id########1_stop_hook_2', + ); + }); + + it('stops Stop-hook continuation before sending when the session token limit is exceeded', async () => { + const messageBus = { + request: vi + .fn() + .mockResolvedValueOnce({ + success: true, + output: { + decision: 'block', + reason: 'Continue after Stop hook', + }, + }) + .mockResolvedValueOnce({ + success: true, + output: {}, + }), + }; + mockConfig.getMessageBus = vi.fn().mockReturnValue(messageBus); + mockConfig.getDisableAllHooks = vi.fn().mockReturnValue(false); + mockConfig.hasHooksForEvent = vi + .fn() + .mockImplementation((eventName: string) => eventName === 'Stop'); + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat + .mockResolvedValueOnce({ + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: core.CompressionStatus.NOOP, + }) + .mockResolvedValueOnce({ + originalTokenCount: 101, + newTokenCount: 101, + compressionStatus: core.CompressionStatus.NOOP, + }); + mockChat.getHistory = vi + .fn() + .mockReturnValue([ + { role: 'model', parts: [{ text: 'response text' }] }, + ]); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await expect( + session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }), + ).resolves.toEqual({ stopReason: 'max_tokens' }); + + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + 'test-session-id########1_stop_hook_1', + false, + expect.any(AbortSignal), + ); + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'Session token limit exceeded: 101 tokens > 100 limit. ' + + 'Please start a new session or increase the sessionTokenLimit in your settings.json.', + }, + }, + }); + }); + + it('runs automatic compression before cron-fired ACP prompt sends', async () => { + const scheduler = { + size: 1, + start: vi.fn((callback: (job: { prompt: string }) => void) => { + callback({ prompt: 'scheduled prompt' }); + }), + stop: vi.fn(), + getExitSummary: vi.fn().mockReturnValue(undefined), + }; + mockConfig.isCronEnabled = vi.fn().mockReturnValue(true); + mockConfig.getCronScheduler = vi.fn().mockReturnValue(scheduler); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValueOnce(createEmptyStream()) + .mockResolvedValueOnce(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + await vi.waitFor(() => { + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + }); + + expect(scheduler.start).toHaveBeenCalledTimes(1); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 1, + 'test-session-id########1', + false, + expect.any(AbortSignal), + ); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + expect.stringMatching(/^test-session-id########cron\d+$/), + false, + expect.any(AbortSignal), + ); + + const sendMessageStream = mockChat.sendMessageStream as ReturnType< + typeof vi.fn + >; + expectCompressBeforeSend( + mockGeminiClient.tryCompressChat, + sendMessageStream, + 1, + ); + }); + + it('stops cron-fired ACP prompt before sending when the session token limit is exceeded', async () => { + let cronCallback: ((job: { prompt: string }) => void) | undefined; + const scheduler = { + size: 1, + start: vi.fn((callback: (job: { prompt: string }) => void) => { + cronCallback = callback; + callback({ prompt: 'scheduled prompt' }); + }), + stop: vi.fn(), + getExitSummary: vi.fn().mockReturnValue(undefined), + }; + mockConfig.isCronEnabled = vi.fn().mockReturnValue(true); + mockConfig.getCronScheduler = vi.fn().mockReturnValue(scheduler); + mockConfig.getSessionTokenLimit = vi.fn().mockReturnValue(100); + mockGeminiClient.tryCompressChat + .mockResolvedValueOnce({ + originalTokenCount: 50, + newTokenCount: 50, + compressionStatus: core.CompressionStatus.NOOP, + }) + .mockResolvedValueOnce({ + originalTokenCount: 101, + newTokenCount: 101, + compressionStatus: core.CompressionStatus.NOOP, + }); + mockChat.sendMessageStream = vi + .fn() + .mockResolvedValue(createEmptyStream()); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: 'hello' }], + }); + + await vi.waitFor(() => { + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledTimes(2); + }); + + expect(scheduler.start).toHaveBeenCalledTimes(1); + expect(mockGeminiClient.tryCompressChat).toHaveBeenNthCalledWith( + 2, + expect.stringMatching(/^test-session-id########cron\d+$/), + false, + expect.any(AbortSignal), + ); + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: + 'Session token limit exceeded: 101 tokens > 100 limit. ' + + 'Please start a new session or increase the sessionTokenLimit in your settings.json.', + }, + }, + }); + expect(scheduler.stop).toHaveBeenCalledTimes(1); + await vi.waitFor(() => { + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'test-session-id', + update: { + sessionUpdate: 'agent_message_chunk', + content: { + type: 'text', + text: 'Cron jobs disabled for the rest of this session due to token limit. Restart the session to re-enable.', + }, + }, + }); + }); + + const sessionUpdateMock = mockClient.sessionUpdate as ReturnType< + typeof vi.fn + >; + const tokenLimitDiagnosticCount = () => + sessionUpdateMock.mock.calls.filter((call) => { + const notification = call[0] as { + update?: { + sessionUpdate?: string; + content?: { type?: string; text?: string }; + }; + }; + return ( + notification.update?.sessionUpdate === 'agent_message_chunk' && + notification.update.content?.type === 'text' && + notification.update.content.text?.includes( + 'Session token limit exceeded', + ) + ); + }).length; + const diagnosticCountBefore = tokenLimitDiagnosticCount(); + + cronCallback?.({ prompt: 'scheduled prompt again' }); + await Promise.resolve(); + + expect(mockGeminiClient.tryCompressChat).toHaveBeenCalledTimes(2); + expect(tokenLimitDiagnosticCount()).toBe(diagnosticCountBefore); + }); + + it('does not auto-compress slash commands handled without a model send', async () => { + vi.mocked( + nonInteractiveCliCommands.handleSlashCommand, + ).mockResolvedValueOnce({ + type: 'message', + messageType: 'info', + content: 'Already compressed.', + }); + mockChat.sendMessageStream = vi.fn(); + + await session.prompt({ + sessionId: 'test-session-id', + prompt: [{ type: 'text', text: '/compress' }], + }); + + expect(mockGeminiClient.tryCompressChat).not.toHaveBeenCalled(); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + }); + }); + it('passes resolved paths to read_many_files tool', async () => { const tempDir = await fs.mkdtemp( path.join(os.tmpdir(), 'qwen-acp-session-'), diff --git a/packages/cli/src/acp-integration/session/Session.ts b/packages/cli/src/acp-integration/session/Session.ts index 299754cd2..9b1c77e04 100644 --- a/packages/cli/src/acp-integration/session/Session.ts +++ b/packages/cli/src/acp-integration/session/Session.ts @@ -21,10 +21,13 @@ import type { HookExecutionRequest, HookExecutionResponse, MessageBus, + StreamEvent, + ChatCompressionInfo, } from '@qwen-code/qwen-code-core'; import { AuthType, ApprovalMode, + CompressionStatus, convertToFunctionResponse, createDebugLogger, DiscoveredMCPTool, @@ -108,6 +111,10 @@ import { const debugLogger = createDebugLogger('SESSION'); +type AutoCompressionSendResult = + | { responseStream: AsyncGenerator; stopReason?: never } + | { responseStream: null; stopReason: PromptResponse['stopReason'] }; + /** * Session represents an active conversation session with the AI model. * It uses modular components for consistent event emission: @@ -134,6 +141,9 @@ export class Session implements SessionContext { private cronProcessing = false; private cronAbortController: AbortController | null = null; private cronCompletion: Promise | null = null; + private cronDisabledByTokenLimit = false; + private lastPromptTokenCount = 0; + private lastPromptTokenCountChat: GeminiChat | null = null; // Modular components private readonly historyReplayer: HistoryReplayer; @@ -295,9 +305,6 @@ export class Session implements SessionContext { // Increment turn counter for each user prompt this.turn += 1; - // Always fetch the current chat from GeminiClient so that /clear's - // resetChat() (which replaces the chat instance) is reflected here. - const chat = this.config.getGeminiClient()!.getChat(); const promptId = this.config.getSessionId() + '########' + this.turn; // Extract text from all text blocks to construct the full prompt text for logging @@ -413,7 +420,7 @@ export class Session implements SessionContext { while (nextMessage !== null) { if (pendingSend.signal.aborted) { - chat.addHistory(nextMessage); + this.#getCurrentChat().addHistory(nextMessage); return { stopReason: 'cancelled' }; } @@ -422,16 +429,19 @@ export class Session implements SessionContext { const streamStartTime = Date.now(); try { - const responseStream = await chat.sendMessageStream( - this.config.getModel(), - { - message: nextMessage?.parts ?? [], - config: { - abortSignal: pendingSend.signal, - }, - }, + const sendResult = await this.#sendMessageStreamWithAutoCompression( promptId, + nextMessage?.parts ?? [], + pendingSend.signal, ); + if (!sendResult.responseStream) { + this.#preserveUnsentMessageHistory( + nextMessage, + sendResult.stopReason === 'cancelled', + ); + return { stopReason: sendResult.stopReason }; + } + const responseStream = sendResult.responseStream; nextMessage = null; for await (const resp of responseStream) { @@ -510,6 +520,7 @@ export class Session implements SessionContext { } if (usageMetadata) { + this.#recordPromptTokenCount(usageMetadata); // Kick off rewrite in background (non-blocking, runs parallel to tools) if (this.messageRewriter) { this.messageRewriter.flushTurn(pendingSend.signal); @@ -541,7 +552,6 @@ export class Session implements SessionContext { // Fire Stop hook loop (aligned with core path in client.ts) // This is triggered after model response completes with no pending tool calls return this.#handleStopHookLoop( - chat, pendingSend, promptId, hooksEnabled, @@ -557,20 +567,18 @@ export class Session implements SessionContext { * If a Stop hook requests continuation, it sends a follow-up message and loops back. * Maximum iterations (100) prevent infinite loops. * - * @param chat - The GeminiChat instance * @param pendingSend - The abort controller for the current prompt * @param promptId - The prompt ID for tracking * @param hooksEnabled - Whether hooks are enabled * @param messageBus - The MessageBus for hook communication (may be undefined) - * @returns The stop reason ('end_turn' or 'cancelled') + * @returns The ACP stop reason for the prompt. */ async #handleStopHookLoop( - chat: GeminiChat, pendingSend: AbortController, promptId: string, hooksEnabled: boolean, messageBus: MessageBus | undefined, - ): Promise<{ stopReason: 'end_turn' | 'cancelled' }> { + ): Promise<{ stopReason: PromptResponse['stopReason'] }> { const MAX_STOP_HOOK_ITERATIONS = 100; let stopHookIterationCount = 0; let stopHookReasons: string[] = []; @@ -586,7 +594,7 @@ export class Session implements SessionContext { } // Get response text from the chat history - const history = chat.getHistory(); + const history = this.#getCurrentChat().getHistory(); const lastModelMessage = history .filter((msg: Content) => msg.role === 'model') .pop(); @@ -666,16 +674,21 @@ export class Session implements SessionContext { const streamStartTime = Date.now(); try { - const continueResponseStream = await chat.sendMessageStream( - this.config.getModel(), - { - message: nextMessage?.parts ?? [], - config: { - abortSignal: pendingSend.signal, - }, - }, - promptId + '_stop_hook_' + stopHookIterationCount, - ); + const continueSendResult = + await this.#sendMessageStreamWithAutoCompression( + promptId + '_stop_hook_' + stopHookIterationCount, + nextMessage?.parts ?? [], + pendingSend.signal, + { skipCompression: stopHookIterationCount > 1 }, + ); + if (!continueSendResult.responseStream) { + this.#preserveUnsentMessageHistory( + nextMessage, + continueSendResult.stopReason === 'cancelled', + ); + return { stopReason: continueSendResult.stopReason }; + } + const continueResponseStream = continueSendResult.responseStream; nextMessage = null; for await (const resp of continueResponseStream) { @@ -749,6 +762,7 @@ export class Session implements SessionContext { } if (usageMetadata) { + this.#recordPromptTokenCount(usageMetadata); const durationMs = Date.now() - streamStartTime; await this.messageEmitter.emitUsageMetadata( usageMetadata, @@ -795,6 +809,245 @@ export class Session implements SessionContext { await this.client.sessionUpdate(params); } + #getCurrentChat(): GeminiChat { + return this.config.getGeminiClient()!.getChat(); + } + + /** + * Mirrors the core send path for ACP model sends. + * + * Attempts automatic chat compression first, checks the session token limit, + * emits an ACP-visible notice when compression succeeds, and returns the ACP + * stop reason when the provider send should be skipped because the request + * was cancelled or the session token limit was exceeded. + */ + async #sendMessageStreamWithAutoCompression( + promptId: string, + message: Part[], + abortSignal: AbortSignal, + options: { skipCompression?: boolean } = {}, + ): Promise { + const geminiClient = this.config.getGeminiClient()!; + let compressionDiagnostic: string | null = null; + let compressionInfo: ChatCompressionInfo | null = null; + if (!options.skipCompression) { + try { + const compressed = await geminiClient.tryCompressChat( + promptId, + false, + abortSignal, + ); + compressionInfo = compressed; + this.#recordCompressionTokenCount(compressed); + if (compressed.compressionStatus === CompressionStatus.COMPRESSED) { + compressionDiagnostic = + `IMPORTANT: This conversation approached the input token limit for ${this.config.getModel()}. ` + + `A compressed context will be sent for future messages (compressed from: ` + + `${compressed.originalTokenCount ?? 'unknown'} to ` + + `${compressed.newTokenCount ?? 'unknown'} tokens).`; + } + } catch (compressionError) { + if (abortSignal.aborted || this.#isAbortError(compressionError)) { + debugLogger.debug(`Auto-compression aborted for prompt ${promptId}`); + return { responseStream: null, stopReason: 'cancelled' }; + } + debugLogger.warn( + `Auto-compression failed for prompt ${promptId}; proceeding without compression: ` + + this.#formatError(compressionError), + ); + } + } + + if (abortSignal.aborted) { + debugLogger.debug(`Auto-compression aborted for prompt ${promptId}`); + return { responseStream: null, stopReason: 'cancelled' }; + } + + if (!compressionInfo) { + this.#syncPromptTokenCountWithCurrentChat(); + } + + const sessionTokenLimit = this.config.getSessionTokenLimit(); + if (sessionTokenLimit > 0) { + const lastPromptTokenCount = + this.#getPostCompressionTokenCount(compressionInfo); + if (lastPromptTokenCount > sessionTokenLimit) { + debugLogger.warn( + `Session token limit exceeded for prompt ${promptId}: ` + + `${lastPromptTokenCount} > ${sessionTokenLimit}. Send dropped.`, + ); + await this.#emitAgentDiagnosticMessageSafely( + `Session token limit exceeded: ${lastPromptTokenCount} tokens > ${sessionTokenLimit} limit. ` + + 'Please start a new session or increase the sessionTokenLimit in your settings.json.', + `Failed to emit token limit diagnostic for prompt ${promptId}`, + ); + return { responseStream: null, stopReason: 'max_tokens' }; + } + } + + if (compressionDiagnostic) { + await this.#emitAgentDiagnosticMessageSafely( + compressionDiagnostic, + `Failed to emit compression notification for prompt ${promptId}`, + ); + } + + if (abortSignal.aborted) { + debugLogger.debug( + `Send aborted after compression diagnostic for prompt ${promptId}`, + ); + return { responseStream: null, stopReason: 'cancelled' }; + } + + const responseStream = await this.#getCurrentChat().sendMessageStream( + this.config.getModel(), + { + message, + config: { + abortSignal, + }, + }, + promptId, + ); + return { responseStream }; + } + + #preserveUnsentMessageHistory( + message: Content | null, + preserveFullMessage: boolean, + ): void { + if (!message) return; + + if (preserveFullMessage) { + this.#getCurrentChat().addHistory(message); + return; + } + + const functionResponseParts = + message.parts?.filter( + (part: Part) => 'functionResponse' in part && part.functionResponse, + ) ?? []; + const droppedParts = + (message.parts?.length ?? 0) - functionResponseParts.length; + if (droppedParts > 0) { + debugLogger.debug( + `Dropping ${droppedParts} non-functionResponse part(s) from unsent ACP message after send was skipped.`, + ); + } + if (functionResponseParts.length > 0) { + this.#getCurrentChat().addHistory({ + ...message, + parts: functionResponseParts, + }); + } + } + + #recordCompressionTokenCount(info: ChatCompressionInfo): void { + this.#syncPromptTokenCountWithCurrentChat(); + const tokenCount = this.#extractCompressionTokenCount(info); + if (tokenCount !== null && tokenCount > 0) { + this.lastPromptTokenCount = tokenCount; + } + } + + #recordPromptTokenCount( + usageMetadata: GenerateContentResponseUsageMetadata, + ): void { + this.#syncPromptTokenCountWithCurrentChat(); + const tokenCount = + usageMetadata.promptTokenCount ?? usageMetadata.totalTokenCount; + if (tokenCount !== undefined && tokenCount > 0) { + this.lastPromptTokenCount = tokenCount; + } + } + + #getPostCompressionTokenCount(info: ChatCompressionInfo | null): number { + const tokenCount = this.#extractCompressionTokenCount(info); + if (tokenCount !== null) { + return tokenCount; + } + + return this.lastPromptTokenCount; + } + + #extractCompressionTokenCount( + info: ChatCompressionInfo | null, + ): number | null { + if (!info) { + return null; + } + if (info.compressionStatus === CompressionStatus.COMPRESSED) { + return info.newTokenCount > 0 ? info.newTokenCount : null; + } + const tokenCount = info.originalTokenCount ?? info.newTokenCount ?? null; + if (tokenCount === 0 && info.compressionStatus === CompressionStatus.NOOP) { + return null; + } + return tokenCount; + } + + #syncPromptTokenCountWithCurrentChat(): void { + const chat = this.#getCurrentChat(); + if ( + this.lastPromptTokenCountChat && + this.lastPromptTokenCountChat !== chat + ) { + this.lastPromptTokenCount = 0; + } + this.lastPromptTokenCountChat = chat; + } + + #isAbortError(error: unknown): boolean { + return ( + (error instanceof Error && error.name === 'AbortError') || + (typeof DOMException !== 'undefined' && + error instanceof DOMException && + error.name === 'AbortError') || + (typeof error === 'object' && + error !== null && + 'name' in error && + (error as { name?: unknown }).name === 'AbortError') + ); + } + + #formatError(error: unknown): string { + if (error instanceof Error) { + const parts = [error.message]; + const cause = (error as Error & { cause?: unknown }).cause; + if (cause instanceof Error) { + parts.push(`cause: ${cause.message}`); + } + const status = (error as Error & { status?: unknown }).status; + if (status !== undefined) { + parts.push(`status: ${String(status)}`); + } + return parts.join(' | '); + } + try { + return JSON.stringify(error) ?? String(error); + } catch { + return String(error); + } + } + + async #emitAgentDiagnosticMessageSafely( + text: string, + failureContext: string, + ): Promise { + try { + await this.#emitAgentDiagnosticMessage(text); + } catch (notifyError) { + debugLogger.warn(`${failureContext}: ${this.#formatError(notifyError)}`); + } + } + + async #emitAgentDiagnosticMessage(text: string): Promise { + await this.sendUpdate({ + sessionUpdate: 'agent_message_chunk', + content: { type: 'text', text }, + }); + } + /** * Starts the cron scheduler if cron is enabled and jobs exist. * The scheduler runs in the background, pushing fired prompts into @@ -802,10 +1055,12 @@ export class Session implements SessionContext { */ #startCronSchedulerIfNeeded(): void { if (!this.config.isCronEnabled()) return; + if (this.cronDisabledByTokenLimit) return; const scheduler = this.config.getCronScheduler(); if (scheduler.size === 0) return; scheduler.start((job: { prompt: string }) => { + if (this.cronDisabledByTokenLimit) return; this.cronQueue.push(job.prompt); void this.#drainCronQueue(); }); @@ -885,17 +1140,22 @@ export class Session implements SessionContext { null; const streamStartTime = Date.now(); - const responseStream = await this.config - .getGeminiClient()! - .getChat() - .sendMessageStream( - this.config.getModel(), - { - message: nextMessage.parts ?? [], - config: { abortSignal: ac.signal }, - }, - promptId, + const sendResult = await this.#sendMessageStreamWithAutoCompression( + promptId, + nextMessage.parts ?? [], + ac.signal, + ); + if (!sendResult.responseStream) { + this.#preserveUnsentMessageHistory( + nextMessage, + sendResult.stopReason === 'cancelled', ); + if (sendResult.stopReason === 'max_tokens') { + this.#stopCronAfterTokenLimit(); + } + return; + } + const responseStream = sendResult.responseStream; nextMessage = null; for await (const resp of responseStream) { @@ -933,6 +1193,7 @@ export class Session implements SessionContext { } if (usageMetadata) { + this.#recordPromptTokenCount(usageMetadata); // Kick off rewrite in background (non-blocking) if (this.messageRewriter) { this.messageRewriter.flushTurn(ac.signal); @@ -968,6 +1229,17 @@ export class Session implements SessionContext { ); } + #stopCronAfterTokenLimit(): void { + this.cronDisabledByTokenLimit = true; + this.cronQueue = []; + if (!this.config.isCronEnabled()) return; + this.config.getCronScheduler().stop(); + void this.#emitAgentDiagnosticMessageSafely( + 'Cron jobs disabled for the rest of this session due to token limit. Restart the session to re-enable.', + 'Failed to emit cron-disabled diagnostic', + ); + } + async sendAvailableCommandsUpdate(): Promise { const abortController = new AbortController(); try { diff --git a/packages/cli/src/i18n/locales/ca.js b/packages/cli/src/i18n/locales/ca.js index 7830aaaa6..3646f3746 100644 --- a/packages/cli/src/i18n/locales/ca.js +++ b/packages/cli/src/i18n/locales/ca.js @@ -1208,7 +1208,7 @@ export default { // ============================================================================ // Ordres - Model // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': 'Canviar el model per a aquesta sessió (--fast per al model de suggeriments)', 'Set a lighter model for prompt suggestions and speculative execution': 'Establir un model més lleuger per a suggeriments de missatges i execució especulativa', @@ -1217,6 +1217,8 @@ export default { 'Authentication type not available.': "Tipus d'autenticació no disponible.", 'No models available for the current authentication type ({{authType}}).': "No hi ha models disponibles per al tipus d'autenticació actual ({{authType}}).", + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Ordres - Netejar diff --git a/packages/cli/src/i18n/locales/de.js b/packages/cli/src/i18n/locales/de.js index 0f4164fcf..153f96b14 100644 --- a/packages/cli/src/i18n/locales/de.js +++ b/packages/cli/src/i18n/locales/de.js @@ -1031,7 +1031,7 @@ export default { // ============================================================================ // Commands - Model // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': 'Modell für diese Sitzung wechseln (--fast für Vorschlagsmodell)', 'Set a lighter model for prompt suggestions and speculative execution': 'Leichteres Modell für Eingabevorschläge und spekulative Ausführung festlegen', @@ -1041,6 +1041,8 @@ export default { 'Authentifizierungstyp nicht verfügbar.', 'No models available for the current authentication type ({{authType}}).': 'Keine Modelle für den aktuellen Authentifizierungstyp ({{authType}}) verfügbar.', + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Commands - Clear diff --git a/packages/cli/src/i18n/locales/en.js b/packages/cli/src/i18n/locales/en.js index 4415e0720..4ba34e373 100644 --- a/packages/cli/src/i18n/locales/en.js +++ b/packages/cli/src/i18n/locales/en.js @@ -1200,8 +1200,8 @@ export default { // ============================================================================ // Commands - Model // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': - 'Switch the model for this session (--fast for suggestion model)', + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).', 'Set a lighter model for prompt suggestions and speculative execution': 'Set a lighter model for prompt suggestions and speculative execution', 'Content generator configuration not available.': @@ -1209,6 +1209,8 @@ export default { 'Authentication type not available.': 'Authentication type not available.', 'No models available for the current authentication type ({{authType}}).': 'No models available for the current authentication type ({{authType}}).', + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Commands - Clear diff --git a/packages/cli/src/i18n/locales/fr.js b/packages/cli/src/i18n/locales/fr.js index 2f982c45f..45cb1337b 100644 --- a/packages/cli/src/i18n/locales/fr.js +++ b/packages/cli/src/i18n/locales/fr.js @@ -1175,7 +1175,7 @@ export default { // ============================================================================ // Commandes - Modèle // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': 'Changer le modèle pour cette session (--fast pour le modèle de suggestion)', 'Set a lighter model for prompt suggestions and speculative execution': "Définir un modèle plus léger pour les suggestions d'invite et l'exécution spéculative", @@ -1185,6 +1185,8 @@ export default { "Type d'authentification non disponible.", 'No models available for the current authentication type ({{authType}}).': "Aucun modèle disponible pour le type d'authentification actuel ({{authType}}).", + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Commandes - Effacer diff --git a/packages/cli/src/i18n/locales/ja.js b/packages/cli/src/i18n/locales/ja.js index 549fd3fbf..3faa5241a 100644 --- a/packages/cli/src/i18n/locales/ja.js +++ b/packages/cli/src/i18n/locales/ja.js @@ -785,7 +785,7 @@ export default { 'Failed to generate summary - no text content received from LLM response': 'サマリーの生成に失敗 - LLMレスポンスからテキストコンテンツを受信できませんでした', // Model - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': 'このセッションのモデルを切り替え(--fast で提案モデルを設定)', 'Set a lighter model for prompt suggestions and speculative execution': 'プロンプト提案と投機的実行用の軽量モデルを設定', @@ -794,6 +794,8 @@ export default { 'Authentication type not available.': '認証タイプが利用できません', 'No models available for the current authentication type ({{authType}}).': '現在の認証タイプ({{authType}})で利用可能なモデルはありません', + // Needs translation + ' (not in model registry)': ' (not in model registry)', // Clear 'Starting a new session, resetting chat, and clearing terminal.': '新しいセッションを開始し、チャットをリセットし、ターミナルをクリアしています', diff --git a/packages/cli/src/i18n/locales/pt.js b/packages/cli/src/i18n/locales/pt.js index c30d72b17..667e98fa2 100644 --- a/packages/cli/src/i18n/locales/pt.js +++ b/packages/cli/src/i18n/locales/pt.js @@ -1040,7 +1040,7 @@ export default { // ============================================================================ // Commands - Model // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': 'Trocar o modelo para esta sessão (--fast para modelo de sugestões)', 'Set a lighter model for prompt suggestions and speculative execution': 'Definir modelo mais leve para sugestões de prompt e execução especulativa', @@ -1049,6 +1049,8 @@ export default { 'Authentication type not available.': 'Tipo de autenticação não disponível.', 'No models available for the current authentication type ({{authType}}).': 'Nenhum modelo disponível para o tipo de autenticação atual ({{authType}}).', + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Commands - Clear diff --git a/packages/cli/src/i18n/locales/ru.js b/packages/cli/src/i18n/locales/ru.js index 8631a832f..f0e1a000a 100644 --- a/packages/cli/src/i18n/locales/ru.js +++ b/packages/cli/src/i18n/locales/ru.js @@ -1039,7 +1039,7 @@ export default { // ============================================================================ // Команды - Модель // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': 'Переключение модели для этой сессии (--fast для модели подсказок)', 'Set a lighter model for prompt suggestions and speculative execution': 'Установить облегчённую модель для подсказок и спекулятивного выполнения', @@ -1048,6 +1048,8 @@ export default { 'Authentication type not available.': 'Тип авторизации недоступен.', 'No models available for the current authentication type ({{authType}}).': 'Нет доступных моделей для текущего типа авторизации ({{authType}}).', + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Команды - Очистка diff --git a/packages/cli/src/i18n/locales/zh-TW.js b/packages/cli/src/i18n/locales/zh-TW.js index a86bce271..de32fdfc6 100644 --- a/packages/cli/src/i18n/locales/zh-TW.js +++ b/packages/cli/src/i18n/locales/zh-TW.js @@ -1027,7 +1027,7 @@ export default { 'Generating project summary...': '正在生成項目摘要...', 'Failed to generate summary - no text content received from LLM response': '生成摘要失敗 - 未從 LLM 響應中接收到文本內容', - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': '切換此會話的模型(--fast 可設置建議模型)', 'Set a lighter model for prompt suggestions and speculative execution': '設置用於輸入建議和推測執行的輕量模型', @@ -1035,6 +1035,8 @@ export default { 'Authentication type not available.': '認證類型不可用', 'No models available for the current authentication type ({{authType}}).': '當前認證類型 ({{authType}}) 沒有可用的模型', + // Needs translation + ' (not in model registry)': ' (not in model registry)', 'Starting a new session, resetting chat, and clearing terminal.': '正在開始新會話,重置聊天並清屏。', 'Starting a new session and clearing.': '正在開始新會話並清屏。', diff --git a/packages/cli/src/i18n/locales/zh.js b/packages/cli/src/i18n/locales/zh.js index ccdbec812..a1c4abb54 100644 --- a/packages/cli/src/i18n/locales/zh.js +++ b/packages/cli/src/i18n/locales/zh.js @@ -1142,7 +1142,7 @@ export default { // ============================================================================ // Commands - Model // ============================================================================ - 'Switch the model for this session (--fast for suggestion model)': + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).': '切换此会话的模型(--fast 可设置建议模型)', 'Set a lighter model for prompt suggestions and speculative execution': '设置用于输入建议和推测执行的轻量模型', @@ -1150,6 +1150,8 @@ export default { 'Authentication type not available.': '认证类型不可用', 'No models available for the current authentication type ({{authType}}).': '当前认证类型 ({{authType}}) 没有可用的模型', + // Needs translation + ' (not in model registry)': ' (not in model registry)', // ============================================================================ // Commands - Clear diff --git a/packages/cli/src/ui/commands/modelCommand.test.ts b/packages/cli/src/ui/commands/modelCommand.test.ts index 7da549519..6a6cf131c 100644 --- a/packages/cli/src/ui/commands/modelCommand.test.ts +++ b/packages/cli/src/ui/commands/modelCommand.test.ts @@ -34,7 +34,7 @@ describe('modelCommand', () => { it('should have the correct name and description', () => { expect(modelCommand.name).toBe('model'); expect(modelCommand.description).toBe( - 'Switch the model for this session (--fast for suggestion model)', + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).', ); }); diff --git a/packages/cli/src/ui/commands/modelCommand.ts b/packages/cli/src/ui/commands/modelCommand.ts index 1ee5e6083..c9e17afe0 100644 --- a/packages/cli/src/ui/commands/modelCommand.ts +++ b/packages/cli/src/ui/commands/modelCommand.ts @@ -14,15 +14,29 @@ import { CommandKind } from './types.js'; import { t } from '../../i18n/index.js'; import { getPersistScopeForModelSelection } from '../../config/modelProvidersScope.js'; +// Get an array of the available model IDs as strings +function getAvailableModelIds(context: CommandContext) { + const { services } = context; + const { config } = services; + if (!config) { + return []; + } + const availableModels = config.getAvailableModels(); + // Convert AvailableModel[] to string[] on AvailableModel.id + return availableModels.map((model) => model.id); +} + export const modelCommand: SlashCommand = { name: 'model', completionPriority: 100, get description() { - return t('Switch the model for this session (--fast for suggestion model)'); + return t( + 'Switch the model for this session (--fast for suggestion model, [model-id] to switch immediately).', + ); }, kind: CommandKind.BUILT_IN, supportedModes: ['interactive', 'non_interactive', 'acp'] as const, - completion: async (_context, partialArg) => { + completion: async (context, partialArg) => { if (partialArg && '--fast'.startsWith(partialArg)) { return [ { @@ -32,8 +46,14 @@ export const modelCommand: SlashCommand = { ), }, ]; + } else if (partialArg.trim()) { + // Include model IDs matching the partial argument + return getAvailableModelIds(context).filter((id) => + id.startsWith(partialArg.trim()), + ); + } else { + return null; } - return null; }, action: async ( context: CommandContext, @@ -110,10 +130,47 @@ export const modelCommand: SlashCommand = { }; } + // Handle modelName argument: immediately switch to the provided model + if (args !== '' && context.executionMode === 'interactive') { + const modelName = args.trim().split(' ')[0]; + if (modelName) { + // Use first argument only, avoids later syntax confusion and/or use of model names with spaces + // Ignore argument if it is empty, e.g. to avoid confusion with trailing whitespace + if (!settings) { + return { + type: 'message', + messageType: 'error', + content: t('Settings service not available.'), + }; + } + await config.setModel(modelName); + settings.setValue( + getPersistScopeForModelSelection(settings), + 'model.name', + modelName, + ); + + if (config.getModelsConfig().hasModel(authType, modelName)) { + return { + type: 'message', + messageType: 'info', + content: t('Model') + ': ' + modelName, + }; + } else { + return { + type: 'message', + messageType: 'info', + content: + t('Model') + ': ' + modelName + t(' (not in model registry)'), + }; + } + } + } + // Non-interactive/ACP: set model if an arg was provided, otherwise show current model if (context.executionMode !== 'interactive') { - const modelName = args.trim(); - if (modelName) { + const modelName = args.trim().split(' ')[0]; + if (modelName.trim()) { // /model — set the main model if (!settings) { return { @@ -122,12 +179,12 @@ export const modelCommand: SlashCommand = { content: t('Settings service not available.'), }; } + await config.setModel(modelName); settings.setValue( getPersistScopeForModelSelection(settings), 'model.name', modelName, ); - await config.setModel(modelName); return { type: 'message', messageType: 'info', diff --git a/packages/cli/src/ui/components/Footer.tsx b/packages/cli/src/ui/components/Footer.tsx index dde9c28c5..b0539f360 100644 --- a/packages/cli/src/ui/components/Footer.tsx +++ b/packages/cli/src/ui/components/Footer.tsx @@ -5,7 +5,6 @@ */ import type React from 'react'; -import { useCallback, useSyncExternalStore } from 'react'; import { Box, Text } from 'ink'; import { theme } from '../semantic-colors.js'; import { ContextUsageDisplay } from './ContextUsageDisplay.js'; @@ -25,39 +24,12 @@ import { ApprovalMode } from '@qwen-code/qwen-code-core'; import { GeminiSpinner } from './GeminiRespondingSpinner.js'; import { t } from '../../i18n/index.js'; -/** - * Returns true while any dream task for the current project is in - * 'pending' or 'running' state. Uses MemoryManager's subscribe/notify - * mechanism so there is zero polling overhead. - */ -function useDreamRunning(projectRoot: string): boolean { - const config = useConfig(); - - const subscribe = useCallback( - (onStoreChange: () => void) => - config.getMemoryManager().subscribe(onStoreChange), - [config], - ); - - const getSnapshot = useCallback( - () => - config - .getMemoryManager() - .listTasksByType('dream', projectRoot) - .some((task) => task.status === 'pending' || task.status === 'running'), - [config, projectRoot], - ); - - return useSyncExternalStore(subscribe, getSnapshot); -} - export const Footer: React.FC = () => { const uiState = useUIState(); const config = useConfig(); const { vimEnabled, vimMode } = useVimMode(); const { lines: statusLineLines } = useStatusLine(); const configInitMessage = useConfigInitMessage(uiState.isConfigInitialized); - const dreamRunning = useDreamRunning(config.getProjectRoot()); const { promptTokenCount, showAutoAcceptIndicator } = { promptTokenCount: uiState.sessionStats.lastPromptTokenCount, @@ -134,12 +106,10 @@ export const Footer: React.FC = () => { node: Debug Mode, }); } - if (dreamRunning) { - rightItems.push({ - key: 'dream', - node: {t('✦ dreaming')}, - }); - } + // Dream tasks now surface via the BackgroundTasksPill (e.g. "1 dream") + // alongside the other background-task kinds. The previous `✦ dreaming` + // right-column indicator was removed to avoid two simultaneous signals + // for the same underlying state. if (promptTokenCount > 0 && contextWindowSize) { rightItems.push({ key: 'context', diff --git a/packages/cli/src/ui/components/HistoryItemDisplay.tsx b/packages/cli/src/ui/components/HistoryItemDisplay.tsx index 5344da6a1..4c00e7b2a 100644 --- a/packages/cli/src/ui/components/HistoryItemDisplay.tsx +++ b/packages/cli/src/ui/components/HistoryItemDisplay.tsx @@ -210,6 +210,7 @@ const HistoryItemDisplayComponent: React.FC = ({ availableTerminalHeight={availableTerminalHeight} contentWidth={contentWidth} isFocused={isFocused} + isPending={isPending} activeShellPtyId={activeShellPtyId} embeddedShellFocused={embeddedShellFocused} memoryWriteCount={itemForDisplay.memoryWriteCount} diff --git a/packages/cli/src/ui/components/InputPrompt.test.tsx b/packages/cli/src/ui/components/InputPrompt.test.tsx index c01c2a54e..ffbb8d34c 100644 --- a/packages/cli/src/ui/components/InputPrompt.test.tsx +++ b/packages/cli/src/ui/components/InputPrompt.test.tsx @@ -8,7 +8,7 @@ import { renderWithProviders } from '../../test-utils/render.js'; import { waitFor, act } from '@testing-library/react'; import type { InputPromptProps } from './InputPrompt.js'; import { InputPrompt } from './InputPrompt.js'; -import type { TextBuffer } from './shared/text-buffer.js'; +import { useTextBuffer, type TextBuffer } from './shared/text-buffer.js'; import type { Config } from '@qwen-code/qwen-code-core'; import { ApprovalMode } from '@qwen-code/qwen-code-core'; import * as path from 'node:path'; @@ -78,6 +78,38 @@ const mockSlashCommands: SlashCommand[] = [ }, ], }, + { + name: 'export', + kind: CommandKind.BUILT_IN, + description: 'Export session', + action: vi.fn(), + subCommands: [ + { + name: 'html', + kind: CommandKind.BUILT_IN, + description: 'Export HTML', + action: vi.fn(), + }, + { + name: 'md', + kind: CommandKind.BUILT_IN, + description: 'Export Markdown', + action: vi.fn(), + }, + { + name: 'json', + kind: CommandKind.BUILT_IN, + description: 'Export JSON', + action: vi.fn(), + }, + { + name: 'jsonl', + kind: CommandKind.BUILT_IN, + description: 'Export JSONL', + action: vi.fn(), + }, + ], + }, ]; describe('InputPrompt', () => { @@ -725,6 +757,747 @@ describe('InputPrompt', () => { unmount(); }); + it('should submit a perfect match on Enter when suggestions were not navigated', async () => { + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\r'); + await wait(); + + expect(props.onSubmit).toHaveBeenCalledWith('/export'); + expect(mockCommandCompletion.handleAutocomplete).not.toHaveBeenCalled(); + unmount(); + }); + + it('should fill and submit an export format selected with arrow navigation', async () => { + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); + await wait(); + + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + expect(mockCommandCompletion.handleAutocomplete).not.toHaveBeenCalled(); + expect(props.onSubmit).not.toHaveBeenCalled(); + + stdin.write('\r'); + await wait(); + + expect(props.onSubmit).toHaveBeenCalledWith('/export md'); + unmount(); + }); + + it('should keep cycling export formats after arrow navigation fills input', async () => { + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); + await wait(); + stdin.write('\u001B[B'); + await wait(); + + expect(props.buffer.setText).toHaveBeenNthCalledWith(2, '/export md'); + expect(props.buffer.setText).toHaveBeenNthCalledWith(3, '/export json'); + expect(mockInputHistory.navigateDown).not.toHaveBeenCalled(); + unmount(); + }); + + it('should keep export format suggestions visible after arrow navigation fills input', async () => { + const exportSuggestions = [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ]; + mockedUseCommandCompletion.mockImplementation((buffer) => { + const isExportRoot = buffer.text.trim() === '/export'; + return { + ...mockCommandCompletion, + showSuggestions: isExportRoot, + suggestions: isExportRoot ? exportSuggestions : [], + activeSuggestionIndex: 0, + isPerfectMatch: isExportRoot, + }; + }); + const TestHarness = () => { + const buffer = useTextBuffer({ + initialText: '/export', + viewport: { width: 80, height: 20 }, + isValidPath: () => false, + onChange: () => {}, + }); + return ; + }; + + const { stdin, lastFrame, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); + await wait(); + + const output = stripAnsi(lastFrame() ?? ''); + expect(output).toContain('/export md'); + expect(output).toContain('html'); + expect(output).toContain('md'); + expect(output).toContain('json'); + expect(output).toContain('jsonl'); + expect(output).toContain('Export Markdown'); + unmount(); + }); + + it('should not clobber manually edited buffer when arrow is pressed after export fill', async () => { + // Regression for PR #3701 review: exportCompletionSelectionIndexRef + // leaked across buffer edits, so arrow keys would overwrite user-typed + // text after the user moved away from an "/export " input. + mockedUseCommandCompletion.mockImplementation((buffer) => { + const isExportRoot = buffer.text.trim() === '/export'; + return { + ...mockCommandCompletion, + showSuggestions: isExportRoot, + suggestions: isExportRoot + ? [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ] + : [], + activeSuggestionIndex: 0, + isPerfectMatch: isExportRoot, + }; + }); + + const TestHarness = () => { + const buffer = useTextBuffer({ + initialText: '/export', + viewport: { width: 80, height: 20 }, + isValidPath: () => false, + onChange: () => {}, + }); + return ; + }; + + const { stdin, lastFrame, unmount } = renderWithProviders(); + await wait(); + + // Phase 1 + 2: Down fills "/export md". + stdin.write('\u001B[B'); + await wait(); + expect(stripAnsi(lastFrame() ?? '')).toContain('/export md'); + + // User clears buffer and types a different command manually. + stdin.write('\u0015'); // Ctrl+U: clear line + await wait(); + // Pin the intermediate state: Ctrl+U must actually clear the buffer + // before we type the new command, so a future useTextBuffer/hook change + // can't make this test pass for the wrong reason. + expect(stripAnsi(lastFrame() ?? '')).not.toContain('/export'); + stdin.write('/help'); + await wait(); + const afterEditFrame = stripAnsi(lastFrame() ?? ''); + expect(afterEditFrame).toContain('/help'); + + // Pressing Down now must NOT overwrite "/help" with an export format. + stdin.write('\u001B[B'); + await wait(); + const afterArrowFrame = stripAnsi(lastFrame() ?? ''); + expect(afterArrowFrame).toContain('/help'); + expect(afterArrowFrame).not.toMatch(/\/export\s+(html|md|json|jsonl)/); + unmount(); + }); + + it('should wrap to jsonl when pressing Up from the /export Phase 1 popup', async () => { + // Regression for PR #3701 second-round review: Phase 1 Up-arrow path + // (including 0 -> lastIndex wrap) had no test coverage. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[A'); + await wait(); + + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export jsonl'); + expect(mockCommandCompletion.handleAutocomplete).not.toHaveBeenCalled(); + unmount(); + }); + + it('should wrap Phase 2 cycling backward when pressing Up repeatedly', async () => { + // Regression for PR #3701 second-round review: Phase 2 Up-arrow wrap + // logic had no test coverage (existing tests only used Down). + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Phase 1 Down -> /export md (ref=1). + stdin.write('\u001B[B'); + await wait(); + // Phase 2 Up -> /export html (ref=0). + stdin.write('\u001B[A'); + await wait(); + // Phase 2 Up wraps from index 0 to last index -> /export jsonl (ref=3). + stdin.write('\u001B[A'); + await wait(); + + expect(props.buffer.setText).toHaveBeenNthCalledWith(2, '/export md'); + expect(props.buffer.setText).toHaveBeenNthCalledWith(3, '/export html'); + expect(props.buffer.setText).toHaveBeenNthCalledWith(4, '/export jsonl'); + unmount(); + }); + + it('should seed Phase 2 cycling when Tab accepts a format in the /export popup', async () => { + // Regression for PR #3701 second-round review (Suggestion): Tab in the + // Phase 1 popup must run the export-specific path so that + // exportCompletionSelectionIndexRef is seeded and subsequent arrow/Tab + // keys can continue cycling. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Tab in Phase 1 popup fills /export html and seeds the ref. + stdin.write('\t'); + await wait(); + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export html'); + expect(mockCommandCompletion.handleAutocomplete).not.toHaveBeenCalled(); + + // Phase 2 Down now cycles forward from the seeded ref. + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + + // Phase 2 Tab should also cycle (covers isCompletionTabKey branch). + stdin.write('\t'); + await wait(); + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export json'); + unmount(); + }); + + it('should not overwrite /export html with extra args when Down is pressed', async () => { + // Regression for PR #3701 second-round review (Critical): Phase 2 cycling + // guard used startsWith('/export '), which matched inputs like + // '/export html --verbose' and silently wiped out the extra arguments. + // The strict getExportFormatFromInput guard must let such inputs fall + // through without overwriting the buffer. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Seed Phase 2 state: Down fills /export md and sets ref=1. + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + + // Simulate the user appending extra arguments to the export input. + props.buffer.setText('/export md --verbose'); + (props.buffer.setText as ReturnType).mockClear(); + + // Pressing Down must NOT replace the buffer with '/export json'. + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).not.toHaveBeenCalled(); + unmount(); + }); + + it('should reset export cycling state on Escape so arrows no longer cycle', async () => { + // Regression for PR #3701 third-round review (Suggestion): ESC resets + // exportCompletionSelectionIndexRef but this path had no test coverage, + // so a regression could silently break the reset. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Phase 1 Down -> /export md (enters Phase 2). + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + (props.buffer.setText as ReturnType).mockClear(); + + // Press Escape — should reset the cycling state. + stdin.write('\x1B'); + await wait(); + + // Subsequent Down must NOT overwrite the buffer with an export format. + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export json'); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export html'); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export jsonl'); + expect(props.buffer.setText).not.toHaveBeenCalled(); + unmount(); + }); + + it('should reset export cycling state on Ctrl+C so new input is not overwritten', async () => { + // Regression for PR #3701 third-round review (Suggestion): Ctrl+C resets + // exportCompletionSelectionIndexRef but this path had no test coverage, + // so the ref could leak into the new input. Verify that after Ctrl+C + // the user can type a completely unrelated command without arrow keys + // clobbering it with export formats. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Phase 1 Down -> /export md (enters Phase 2). + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + (props.buffer.setText as ReturnType).mockClear(); + + // Ctrl+C clears the buffer. + stdin.write('\x03'); + await wait(); + + // Set a completely different command into the buffer. + props.buffer.setText('/help'); + (props.buffer.setText as ReturnType).mockClear(); + + // Pressing Down must NOT overwrite '/help' with an export format. + stdin.write('\u001B[B'); + await wait(); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export json'); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export html'); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export jsonl'); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export md'); + expect(props.buffer.setText).not.toHaveBeenCalled(); + unmount(); + }); + + it('should cycle export format on Down when /export was typed manually (not via popup)', async () => { + // Regression for PR #3701 fifth-round review: users who type + // "/export md" directly (without going through the Phase-1 popup) + // must still get Phase-2 cycling on arrow keys, but programmatic + // buffer.setText/history restores must not arm the same state. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: false, // popup is closed for direct input + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: false, + }); + + const TestHarness = () => { + const buffer = useTextBuffer({ + initialText: '', + viewport: { width: 80, height: 20 }, + isValidPath: () => false, + onChange: () => {}, + }); + return ; + }; + + const { stdin, lastFrame, unmount } = renderWithProviders(); + await wait(); + + stdin.write('/export md'); + await wait(); + expect(stripAnsi(lastFrame() ?? '')).toContain('/export md'); + + // Pressing Down must cycle to the NEXT format (json). + stdin.write('\u001B[B'); + await wait(); + expect(stripAnsi(lastFrame() ?? '')).toContain('/export json'); + unmount(); + }); + + it('should not arm export cycling from restored history text', async () => { + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: false, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: false, + }); + + const TestHarness = () => { + const buffer = useTextBuffer({ + initialText: '/export md', + viewport: { width: 80, height: 20 }, + isValidPath: () => false, + onChange: () => {}, + }); + return ; + }; + + const { stdin, lastFrame, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); + await wait(); + + expect(mockInputHistory.navigateDown).toHaveBeenCalled(); + expect(stripAnsi(lastFrame() ?? '')).toContain('/export md'); + expect(stripAnsi(lastFrame() ?? '')).not.toContain('/export json'); + unmount(); + }); + + it('should trigger export-specific arrow navigation even when completion suggestions are a superset', async () => { + // Regression for PR #3701 review (hasExportFormatSuggestions superset + // matching): when extra non-export items appear alongside all export + // formats, Phase 1 must still route arrow keys to setExportCompletionInput + // instead of silently falling through to generic navigateDown. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + { label: 'report', value: 'report' }, // extra suggestion + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); // Down + await wait(); + + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + expect(mockCommandCompletion.navigateDown).not.toHaveBeenCalled(); + unmount(); + }); + + it('should fall through to generic accept when Tab targets a non-export item in the /export superset popup', async () => { + // Regression for PR #3701 review: when the active suggestion in the + // Phase 1 superset popup is a non-export item, ACCEPT_SUGGESTION must + // NOT call setExportCompletionInput — it must fall through to the + // generic acceptActiveCompletionSuggestion. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + { label: 'report', value: 'report' }, + ], + activeSuggestionIndex: 4, // the non-export item + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\t'); // Tab (ACCEPT_SUGGESTION) + await wait(); + + expect(mockCommandCompletion.handleAutocomplete).toHaveBeenCalledWith(4); + // Must NOT write "/export report" to the buffer. + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export report'); + unmount(); + }); + + it('should fall through to generic completion when suggestions are missing an export format', async () => { + // Regression for PR #3701 review: when completion suggestions do NOT + // include all export formats, hasExportFormatSuggestions must be false + // and arrow keys must go through the generic completion.navigateDown + // path instead of setExportCompletionInput. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + // jsonl intentionally missing + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); // Down + await wait(); + + expect(mockCommandCompletion.navigateDown).toHaveBeenCalled(); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export md'); + expect(props.buffer.setText).not.toHaveBeenCalledWith('/export json'); + unmount(); + }); + + it('should trigger Phase 1 export popup even when /export has trailing spaces', async () => { + // Regression: trailing whitespace after "/export" must be treated the + // same as plain "/export" — trim() should normalise the buffer so + // hasExportFormatSuggestions activates and the popup triggers. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/export '); // trailing spaces + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\u001B[B'); // Down + await wait(); + + expect(props.buffer.setText).toHaveBeenLastCalledWith('/export md'); + expect(mockCommandCompletion.navigateDown).not.toHaveBeenCalled(); + unmount(); + }); + + it('should autocomplete on Enter when user arrow-navigated a perfect-match suggestion list', async () => { + // Regression for PR #3701 review: the isPerfectMatch + navigated + Enter + // branch was not covered by tests. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'show', value: 'show' }, + { label: 'add', value: 'add' }, + { label: 'refresh', value: 'refresh' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/memory'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Arrow-navigate so completionSelectionWasNavigatedRef flips to true. + stdin.write('\u001B[B'); + await wait(); + expect(mockCommandCompletion.navigateDown).toHaveBeenCalled(); + + // Enter should autocomplete the active suggestion, NOT submit the raw buffer. + stdin.write('\r'); + await wait(); + + expect(mockCommandCompletion.handleAutocomplete).toHaveBeenCalledWith(0); + expect(props.onSubmit).not.toHaveBeenCalled(); + unmount(); + }); + + it('should submit directly on Enter after arrow-navigate + backspace + retype to perfect match', async () => { + // Regression for PR #3701 review (issue #5): navigate → backspace + // → retype → Enter must submit the raw buffer, not autocomplete. + // If the popup persists across backspace+retype and the navigated flag + // is not cleared on buffer.text changes, Enter would autocomplete the + // first sub-command instead of submitting the perfect match. + mockedUseCommandCompletion.mockImplementation((buf) => { + const text = buf.text; + const isMemory = text === '/memory'; + return { + ...mockCommandCompletion, + showSuggestions: true, // popup stays visible throughout + suggestions: [ + { label: 'show', value: 'show' }, + { label: 'add', value: 'add' }, + { label: 'refresh', value: 'refresh' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: isMemory, + }; + }); + props.buffer.setText('/memory'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + // Arrow-navigate so completionSelectionWasNavigatedRef flips to true. + stdin.write('\u001B[B'); + await wait(); + expect(mockCommandCompletion.navigateDown).toHaveBeenCalled(); + + // Simulate backspace to /memor — popup stays visible in this test. + // Use unmount + re-render to force useEffect re-evaluation on the + // changed buffer.text (direct setText on the mock object doesn't + // trigger React state updates, so effects don't fire). + props.buffer.setText('/memor'); + unmount(); + const { unmount: unmountAfterEdit } = renderWithProviders( + , + ); + await wait(); + + // Retype: /memor → /memory. + props.buffer.setText('/memory'); + unmountAfterEdit(); + const { stdin: stdinFinal, unmount: unmountFinal } = renderWithProviders( + , + ); + await wait(); + + // Enter must submit '/memory', NOT autocomplete 'show'. + stdinFinal.write('\r'); + await wait(); + + expect(props.onSubmit).toHaveBeenCalledWith('/memory'); + expect(mockCommandCompletion.handleAutocomplete).not.toHaveBeenCalled(); + unmountFinal(); + }); + + it('should submit directly on Enter for a perfect match without prior arrow navigation', async () => { + // Control: with no arrow navigation, Enter on a perfect match must submit. + mockedUseCommandCompletion.mockReturnValue({ + ...mockCommandCompletion, + showSuggestions: true, + suggestions: [ + { label: 'show', value: 'show' }, + { label: 'add', value: 'add' }, + ], + activeSuggestionIndex: 0, + isPerfectMatch: true, + }); + props.buffer.setText('/memory'); + + const { stdin, unmount } = renderWithProviders(); + await wait(); + + stdin.write('\r'); + await wait(); + + expect(props.onSubmit).toHaveBeenCalledWith('/memory'); + expect(mockCommandCompletion.handleAutocomplete).not.toHaveBeenCalled(); + unmount(); + }); + it('should reset history navigation after submitting on Enter', async () => { mockedUseCommandCompletion.mockReturnValue({ ...mockCommandCompletion, @@ -2074,6 +2847,66 @@ describe('InputPrompt', () => { unmount(); }); + it('shows command search suggestions over active export suggestions', async () => { + props.shellModeActive = false; + const exportSuggestions = [ + { label: 'html', value: 'html' }, + { label: 'md', value: 'md' }, + { label: 'json', value: 'json' }, + { label: 'jsonl', value: 'jsonl' }, + ]; + + mockedUseCommandCompletion.mockImplementation((buffer) => { + const isExportRoot = buffer.text.trim() === '/export'; + return { + ...mockCommandCompletion, + showSuggestions: isExportRoot, + suggestions: isExportRoot ? exportSuggestions : [], + activeSuggestionIndex: 0, + isPerfectMatch: isExportRoot, + }; + }); + vi.mocked(useReverseSearchCompletion).mockImplementation( + (_buffer, _data, isActive) => ({ + ...mockReverseSearchCompletion, + suggestions: isActive + ? [{ label: 'git status', value: 'git status' }] + : [], + showSuggestions: !!isActive, + activeSuggestionIndex: isActive ? 0 : -1, + }), + ); + + const TestHarness = () => { + const buffer = useTextBuffer({ + initialText: '/export', + viewport: { width: 80, height: 20 }, + isValidPath: () => false, + onChange: () => {}, + }); + return ; + }; + + const { stdin, lastFrame, unmount } = renderWithProviders( + , + ); + await wait(); + + stdin.write('\u001B[B'); + await wait(); + expect(stripAnsi(lastFrame() ?? '')).toContain('/export md'); + expect(stripAnsi(lastFrame() ?? '')).toContain('jsonl'); + + stdin.write('\x12'); // Ctrl+R + await wait(); + + const frame = stripAnsi(lastFrame() ?? ''); + expect(frame).toContain('(r:)'); + expect(frame).toContain('git status'); + expect(frame).not.toContain('jsonl'); + unmount(); + }); + it.skip('expands and collapses long suggestion via Right/Left arrows', async () => { props.shellModeActive = false; const longValue = 'l'.repeat(200); diff --git a/packages/cli/src/ui/components/InputPrompt.tsx b/packages/cli/src/ui/components/InputPrompt.tsx index 9e8eb31c9..684ae69d8 100644 --- a/packages/cli/src/ui/components/InputPrompt.tsx +++ b/packages/cli/src/ui/components/InputPrompt.tsx @@ -17,6 +17,7 @@ import chalk from 'chalk'; import { useShellHistory } from '../hooks/useShellHistory.js'; import { useReverseSearchCompletion } from '../hooks/useReverseSearchCompletion.js'; import { useCommandCompletion } from '../hooks/useCommandCompletion.js'; +import { useExportCompletion } from '../hooks/useExportCompletion.js'; import { useFollowupSuggestionsCLI } from '../hooks/useFollowupSuggestions.js'; import type { Config } from '@qwen-code/qwen-code-core'; import type { Key } from '../hooks/useKeypress.js'; @@ -65,6 +66,7 @@ export interface Attachment { } const debugLogger = createDebugLogger('INPUT_PROMPT'); + export interface InputPromptProps { buffer: TextBuffer; onSubmit: (value: string) => void; @@ -189,6 +191,7 @@ export const InputPrompt: React.FC = ({ ]); const [expandedSuggestionIndex, setExpandedSuggestionIndex] = useState(-1); + const exportCompletion = useExportCompletion(buffer, slashCommands); const shellHistory = useShellHistory(config.getProjectRoot()); const shellHistoryData = shellHistory.history; @@ -302,6 +305,7 @@ export const InputPrompt: React.FC = ({ const handleSubmitAndClear = useCallback( (submittedValue: string) => { + exportCompletion.reset(); // Expand any large paste placeholders to their full content before submitting let finalValue = submittedValue; if (pendingPastes.size > 0) { @@ -353,6 +357,7 @@ export const InputPrompt: React.FC = ({ resetReverseSearchCompletionState(); }, [ + exportCompletion, onSubmit, buffer, resetCompletionState, @@ -615,6 +620,7 @@ export const InputPrompt: React.FC = ({ } if (keyMatchers[Command.ESCAPE](key)) { + exportCompletion.reset(); const cancelSearch = ( setActive: (active: boolean) => void, resetCompletion: () => void, @@ -780,8 +786,41 @@ export const InputPrompt: React.FC = ({ } } + // Export-specific arrow/Tab/Enter handling (Phase 1 + Phase 2). + if (exportCompletion.handleExportInput(key, completion)) { + return true; + } + + const acceptActiveCompletionSuggestion = () => { + if (completion.suggestions.length === 0) { + return false; + } + + const targetIndex = + completion.activeSuggestionIndex === -1 + ? 0 + : completion.activeSuggestionIndex; + if (targetIndex >= completion.suggestions.length) { + return false; + } + + completion.handleAutocomplete(targetIndex); + exportCompletion.navigatedRef.current = false; + setExpandedSuggestionIndex(-1); + return true; + }; + // If the command is a perfect match, pressing enter should execute it. if (completion.isPerfectMatch && keyMatchers[Command.RETURN](key)) { + if ( + completion.showSuggestions && + exportCompletion.navigatedRef.current && + exportCompletion.navigatedTextRef.current === buffer.text && + acceptActiveCompletionSuggestion() + ) { + return true; + } + handleSubmitAndClear(buffer.text); return true; } @@ -819,29 +858,26 @@ export const InputPrompt: React.FC = ({ if (completion.showSuggestions) { if (completion.suggestions.length > 1) { - if (keyMatchers[Command.COMPLETION_UP](key)) { + const isCompletionUpKey = keyMatchers[Command.COMPLETION_UP](key); + const isCompletionDownKey = keyMatchers[Command.COMPLETION_DOWN](key); + if (isCompletionUpKey) { completion.navigateUp(); - setExpandedSuggestionIndex(-1); // Reset expansion when navigating + exportCompletion.navigatedRef.current = true; + exportCompletion.navigatedTextRef.current = buffer.text; + setExpandedSuggestionIndex(-1); return true; } - if (keyMatchers[Command.COMPLETION_DOWN](key)) { + if (isCompletionDownKey) { completion.navigateDown(); - setExpandedSuggestionIndex(-1); // Reset expansion when navigating + exportCompletion.navigatedRef.current = true; + exportCompletion.navigatedTextRef.current = buffer.text; + setExpandedSuggestionIndex(-1); return true; } } if (keyMatchers[Command.ACCEPT_SUGGESTION](key) && !key.paste) { - if (completion.suggestions.length > 0) { - const targetIndex = - completion.activeSuggestionIndex === -1 - ? 0 // Default to the first if none is active - : completion.activeSuggestionIndex; - if (targetIndex < completion.suggestions.length) { - completion.handleAutocomplete(targetIndex); - setExpandedSuggestionIndex(-1); // Reset expansion after selection - } - } + acceptActiveCompletionSuggestion(); return true; } } @@ -1068,8 +1104,16 @@ export const InputPrompt: React.FC = ({ // No placeholder matched — fall through to BaseTextInput's default backspace } + // Ctrl+U (clear-line) — reset export cycling state so a subsequent + // manual typing of "/export " doesn't mistakenly show the + // persistent suggestion panel as if the user had cycled. + if (key.ctrl && key.name === 'u') { + exportCompletion.reset(); + } + // Ctrl+C with completion active — also reset completion state if (keyMatchers[Command.CLEAR_INPUT](key)) { + exportCompletion.reset(); if (buffer.text.length > 0) { resetCompletionState(); } @@ -1090,6 +1134,23 @@ export const InputPrompt: React.FC = ({ followup.dismiss(); onPromptSuggestionDismiss?.(); } + + if ( + !key.ctrl && + !key.meta && + !key.paste && + ((key.sequence && key.sequence.length === 1) || + key.name === 'backspace' || + key.name === 'delete') + ) { + exportCompletion.markNextTextChangeAsUserInput(); + } + // NOTE: the former unconditional + // `exportCompletion.reset();` + // at this fallthrough was removed — the phase-2 buffer-text guard above + // already prevents stale state from affecting non-/export input, and + // the blanket reset was wiping selection on cursor-only keys such as + // Home / End / Ctrl+A. return false; }, [ @@ -1137,6 +1198,7 @@ export const InputPrompt: React.FC = ({ setBgPillFocused, followup, onPromptSuggestionDismiss, + exportCompletion, ], ); @@ -1255,7 +1317,20 @@ export const InputPrompt: React.FC = ({ }; const activeCompletion = getActiveCompletion(); - const shouldShowSuggestions = activeCompletion.showSuggestions; + const shouldUseExportSuggestions = + !commandSearchActive && !reverseSearchActive; + const suggestionDisplayProps = + shouldUseExportSuggestions && exportCompletion.suggestionDisplayProps + ? exportCompletion.suggestionDisplayProps + : { + suggestions: activeCompletion.suggestions, + activeIndex: activeCompletion.activeSuggestionIndex, + isLoading: activeCompletion.isLoadingSuggestions, + scrollOffset: activeCompletion.visibleStartIndex, + }; + const shouldShowSuggestions = + (shouldUseExportSuggestions && exportCompletion.shouldShowSuggestions) || + activeCompletion.showSuggestions; // Notify parent about suggestions visibility changes useEffect(() => { @@ -1354,11 +1429,11 @@ export const InputPrompt: React.FC = ({ {shouldShowSuggestions && ( ({ return entry.shellId; case 'monitor': return entry.monitorId; + case 'dream': + return entry.dreamId; default: { const _exhaustive: never = entry; throw new Error( @@ -53,7 +56,7 @@ vi.mock('../../hooks/useKeypress.js', () => ({ const mockedUseBackgroundTaskView = vi.mocked(useBackgroundTaskView); const mockedUseKeypress = vi.mocked(useKeypress); -function entry(overrides: Partial = {}): DialogEntry { +function entry(overrides: Partial = {}): AgentDialogEntry { return { kind: 'agent', agentId: 'a', @@ -62,7 +65,21 @@ function entry(overrides: Partial = {}): DialogEntry { startTime: 0, abortController: new AbortController(), ...overrides, - } as DialogEntry; + } as AgentDialogEntry; +} + +function dreamEntry( + overrides: Partial = {}, +): DreamDialogEntry { + return { + kind: 'dream', + dreamId: 'd-1', + status: 'running', + startTime: 0, + sessionCount: 7, + progressText: 'Scheduled managed auto-memory dream.', + ...overrides, + }; } function monitorEntry(overrides: Partial = {}): DialogEntry { @@ -94,6 +111,7 @@ interface Harness { resume: ReturnType; abandon: ReturnType; monitorCancel: ReturnType; + dreamCancelTask: ReturnType; setEntries: (next: readonly DialogEntry[]) => void; pressKey: (key: { name?: string; sequence?: string }) => void; call: (fn: () => void) => void; @@ -112,6 +130,7 @@ function setup(initial: readonly DialogEntry[]): Harness { const resume = vi.fn(); const abandon = vi.fn(); const monitorCancel = vi.fn(); + const dreamCancelTask = vi.fn(); // Stub registry that resolves `.get(agentId)` against the current entries // snapshot — the dialog now re-reads agent entries via `.get()` to pick up // live activity/stats mutations the snapshot misses. @@ -138,6 +157,9 @@ function setup(initial: readonly DialogEntry[]): Harness { return match; }, }), + getMemoryManager: () => ({ + cancelTask: dreamCancelTask, + }), resumeBackgroundAgent: resume, abandonBackgroundAgent: abandon, } as unknown as Config; @@ -182,14 +204,21 @@ function setup(initial: readonly DialogEntry[]): Harness { resume, abandon, monitorCancel, + dreamCancelTask, setEntries(next) { handlers.length = 0; currentEntries = next; act(() => handle.current!.setEntries(next)); }, pressKey(key) { + // Real `useKeypress` unbinds the previous callback on rerender, so + // only the most recently registered closure should run. Calling all + // accumulated handlers misses state updates that happened between + // renders (the older closures see stale state) — the symptom looks + // like a re-render race in production code that doesn't exist. act(() => { - for (const h of handlers) h(key); + const latest = handlers[handlers.length - 1]; + if (latest) latest(key); }); }, call(fn) { @@ -273,6 +302,114 @@ describe('BackgroundTasksDialog', () => { expect(h.probe.current!.state.dialogMode).toBe('detail'); }); + it('foreground cancel requires two `x` presses to confirm (one-press is a no-op)', () => { + // Foreground entries block the parent's tool-call: cancelling one ends + // the current turn with a partial result for that subagent. The dialog + // gates the destructive action behind a confirm step so the user can't + // wipe out their turn with a stray keypress. + const fg = entry({ + agentId: 'fg-1', + status: 'running', + flavor: 'foreground', + }); + const h = setup([fg]); + + h.call(() => h.probe.current!.actions.openDialog()); + + h.pressKey({ sequence: 'x' }); + expect(h.cancel).not.toHaveBeenCalled(); + + h.pressKey({ sequence: 'x' }); + expect(h.cancel).toHaveBeenCalledWith('fg-1'); + }); + + it('background cancel still fires on the first `x` press (no confirm)', () => { + // Backwards compatibility: the existing background-only cancel UX + // stays one-shot. Adding a confirm there would regress every workflow + // that relies on quickly cancelling a long-running async agent. + const bg = entry({ + agentId: 'bg-1', + status: 'running', + flavor: 'background', + }); + const h = setup([bg]); + + h.call(() => h.probe.current!.actions.openDialog()); + + h.pressKey({ sequence: 'x' }); + expect(h.cancel).toHaveBeenCalledWith('bg-1'); + }); + + it('ignores `x` on a terminal foreground entry (no arm, no cancel call)', () => { + // A foreground entry briefly stays visible after settling but before + // the tool-call's finally path unregisters it. The dialog's hint + // footer drops "x stop" once status leaves 'running', but without + // gating handleCancelKey itself, the first `x` would still arm a + // confirm step on the (now-terminal) entry — surfacing a misleading + // "x again to confirm stop" line that does nothing. + const completed = entry({ + agentId: 'fg-done', + status: 'completed', + flavor: 'foreground', + }); + const h = setup([completed]); + + h.call(() => h.probe.current!.actions.openDialog()); + + h.pressKey({ sequence: 'x' }); + expect(h.lastFrame()).not.toContain('x again to confirm stop'); + + h.pressKey({ sequence: 'x' }); + expect(h.cancel).not.toHaveBeenCalled(); + }); + + it('detail-mode left clears any armed foreground cancel before exiting', () => { + // Detail-mode `x` arms the foreground confirm step on the focused + // entry. If the user presses `left` to back out without confirming, + // the armed state must NOT carry into list mode — otherwise the + // hint bar still shows "x again to confirm stop" and the next `x` + // unintentionally cancels the run. + const fg = entry({ + agentId: 'fg-1', + status: 'running', + flavor: 'foreground', + }); + const h = setup([fg]); + + h.call(() => h.probe.current!.actions.openDialog()); + h.call(() => h.probe.current!.actions.enterDetail()); + + h.pressKey({ sequence: 'x' }); + h.pressKey({ name: 'left' }); + expect(h.probe.current!.state.dialogMode).toBe('list'); + + // Back in list mode, the next `x` arms again rather than confirming + // a stale armed state inherited from detail mode. + h.pressKey({ sequence: 'x' }); + expect(h.cancel).not.toHaveBeenCalled(); + }); + + it('Esc backs out of an armed foreground cancel without closing the dialog', () => { + const fg = entry({ + agentId: 'fg-1', + status: 'running', + flavor: 'foreground', + }); + const h = setup([fg]); + + h.call(() => h.probe.current!.actions.openDialog()); + + h.pressKey({ sequence: 'x' }); + h.pressKey({ name: 'escape' }); + // Dialog still open — Esc on the armed cancel resets the confirm + // state instead of nuking the dialog. + expect(h.probe.current!.state.dialogOpen).toBe(true); + + // After the Esc reset, the next `x` arms again rather than confirming. + h.pressKey({ sequence: 'x' }); + expect(h.cancel).not.toHaveBeenCalled(); + }); + it('clamps selectedIndex when entries shrink', () => { const a = entry({ agentId: 'a' }); const b = entry({ agentId: 'b' }); @@ -447,4 +584,119 @@ describe('BackgroundTasksDialog', () => { expect(f).not.toContain('Stopped because'); }); }); + + describe('dream entries', () => { + // Coverage for the dream task kind in the unified pill / dialog + // plumbing — list rendering, detail body, hint visibility, and + // cancellation routing. Mirrors the agent / shell / monitor + // coverage profile so each kind has parity in this test file. + it('renders the [dream] row with session count in list mode', () => { + const h = setup([dreamEntry({ sessionCount: 7 })]); + h.call(() => h.probe.current!.actions.openDialog()); + + const f = h.lastFrame() ?? ''; + expect(f).toContain('[dream]'); + expect(f).toContain('memory consolidation'); + expect(f).toContain('reviewing 7 sessions'); + }); + + it('renders DreamDetailBody with sessions / progress / topics on detail view', () => { + const h = setup([ + dreamEntry({ + status: 'completed', + sessionCount: 5, + progressText: 'Managed auto-memory dream completed.', + touchedTopics: ['user', 'project', 'feedback'], + }), + ]); + h.call(() => h.probe.current!.actions.openDialog()); + h.call(() => h.probe.current!.actions.enterDetail()); + + const f = h.lastFrame() ?? ''; + expect(f).toContain('Dream'); + expect(f).toContain('Sessions reviewing'); + expect(f).toContain('5'); + expect(f).toContain('Progress'); + expect(f).toContain('Managed auto-memory dream completed.'); + expect(f).toContain('Topics touched (3)'); + expect(f).toContain('user'); + expect(f).toContain('project'); + expect(f).toContain('feedback'); + }); + + it('shows the "x stop" hint for a running dream entry', () => { + const h = setup([dreamEntry({ status: 'running' })]); + h.call(() => h.probe.current!.actions.openDialog()); + const f = h.lastFrame() ?? ''; + expect(f).toContain('x stop'); + }); + + it("routes 'x' on a running dream to MemoryManager.cancelTask(dreamId)", () => { + // Pin the dream-cancel branch in `cancelSelected` — flipping it + // to anything else (e.g. shell's `requestCancel`) would silently + // break the only path the user has to stop a runaway dream + // consolidation, since the hint already advertises the action. + const h = setup([dreamEntry({ dreamId: 'd-zzz', status: 'running' })]); + h.call(() => h.probe.current!.actions.openDialog()); + h.pressKey({ sequence: 'x' }); + expect(h.dreamCancelTask).toHaveBeenCalledWith('d-zzz'); + // Belt-and-braces — the registry-side cancel paths must not fire + // for a dream entry, otherwise the wrong AbortController gets + // signalled. + expect(h.cancel).not.toHaveBeenCalled(); + expect(h.monitorCancel).not.toHaveBeenCalled(); + }); + + it('omits the topics block entirely while the dream is still running', () => { + // Topics only get populated via metadata.touchedTopics on + // completion; mid-run the body should hide the section instead of + // rendering an empty header. + const h = setup([dreamEntry({ status: 'running', touchedTopics: [] })]); + h.call(() => h.probe.current!.actions.openDialog()); + h.call(() => h.probe.current!.actions.enterDetail()); + const f = h.lastFrame() ?? ''; + expect(f).not.toContain('Topics touched'); + }); + + it('renders the Error block on failed status with a "+ Stopped because" verb', () => { + // Dream failures need to surface — they are the user's only signal + // that consolidation didn't happen as expected (success path + // already produces a memory_saved toast in useGeminiStream). + const h = setup([ + dreamEntry({ + status: 'failed', + error: 'Dream agent failed: model timeout', + }), + ]); + h.call(() => h.probe.current!.actions.openDialog()); + h.call(() => h.probe.current!.actions.enterDetail()); + const f = h.lastFrame() ?? ''; + expect(f).toContain('Failed'); + expect(f).toContain('Error'); + expect(f).toContain('Dream agent failed: model timeout'); + }); + + it('caps visible topics at 8 and renders a "+N more" tail for overflow', () => { + // Real consolidations can touch many memory files; the body must + // not push the hint footer off-screen. Cap mirrors MAX_TOPICS in + // DreamDetailBody. + const manyTopics = Array.from({ length: 12 }, (_, i) => `topic-${i + 1}`); + const h = setup([ + dreamEntry({ status: 'completed', touchedTopics: manyTopics }), + ]); + h.call(() => h.probe.current!.actions.openDialog()); + h.call(() => h.probe.current!.actions.enterDetail()); + const f = h.lastFrame() ?? ''; + // First 8 visible. + expect(f).toContain('topic-1'); + expect(f).toContain('topic-8'); + // Past the cap — must NOT be inlined. + expect(f).not.toContain('topic-9'); + expect(f).not.toContain('topic-12'); + // Tail summary. + expect(f).toContain('+4 more'); + // Header still reflects the full count, not the capped slice. + expect(f).toContain('Topics touched (12)'); + }); + }); }); diff --git a/packages/cli/src/ui/components/background-view/BackgroundTasksDialog.tsx b/packages/cli/src/ui/components/background-view/BackgroundTasksDialog.tsx index 3bf40798f..a63b2e386 100644 --- a/packages/cli/src/ui/components/background-view/BackgroundTasksDialog.tsx +++ b/packages/cli/src/ui/components/background-view/BackgroundTasksDialog.tsx @@ -33,6 +33,7 @@ import { formatDuration, formatTokenCount } from '../../utils/formatters.js'; import { type AgentDialogEntry, type DialogEntry, + type DreamDialogEntry, entryId, } from '../../hooks/useBackgroundTaskView.js'; @@ -103,10 +104,20 @@ function terminalStatusPresentation( } } +// Foreground agent rows get this prefix so users can tell at a glance +// that cancelling one will end the parent's current turn — a much heavier +// consequence than cancelling a truly async background entry. +const FOREGROUND_ROW_PREFIX = '[in turn]'; +const SHELL_ROW_PREFIX = '[shell]'; + function rowLabel(entry: DialogEntry): string { switch (entry.kind) { - case 'agent': - return buildBackgroundEntryLabel(entry, { includePrefix: false }); + case 'agent': { + const label = buildBackgroundEntryLabel(entry, { includePrefix: false }); + return entry.flavor === 'foreground' + ? `${FOREGROUND_ROW_PREFIX} ${label}` + : label; + } case 'shell': // Shell / monitor prefixes mirror the dialog's "section" visual hint // without needing per-kind section headers (which would complicate @@ -115,9 +126,16 @@ function rowLabel(entry: DialogEntry): string { // is acceptable for the dialog's information-density profile — // adding `wrap="truncate-end"` here would hide context the user // explicitly opened the dialog to see. - return `[shell] ${entry.command}`; + return `${SHELL_ROW_PREFIX} ${entry.command}`; case 'monitor': return `[monitor] ${entry.description}`; + case 'dream': { + const sessionsHint = + entry.sessionCount !== undefined + ? ` reviewing ${entry.sessionCount} session${entry.sessionCount === 1 ? '' : 's'}` + : ''; + return `[dream] memory consolidation${sessionsHint}`; + } default: { const _exhaustive: never = entry; throw new Error( @@ -282,6 +300,14 @@ const DetailBody: React.FC<{ maxWidth={maxWidth} /> ); + case 'dream': + return ( + + ); default: { const _exhaustive: never = entry; throw new Error( @@ -291,6 +317,189 @@ const DetailBody: React.FC<{ } }; +// ─── Dream detail body ───────────────────────────────────── +// +// Shows what the agent is reviewing (session count), what it has +// touched (topic files, only populated on completion), and the latest +// progress text from MemoryManager. Cancellation is wired through the +// shared `x stop` keystroke (handled by `cancelSelected` in the +// context, which routes dream entries to `MemoryManager.cancelTask`). +// In-flight progress is still static — the dream's fork agent reports +// only at schedule + completion via MemoryManager.update; live +// per-turn phase reporting requires extending runForkedAgent's +// AgentPathParams with an onAssistantMessage callback (separate PR). +// +// Layout follows the Shell/Monitor convention — flat children of +// MaxSizedBox separated by empty `` spacers (nesting a +// `flexDirection="column"` container inside MaxSizedBox eats the +// children silently). +const DreamDetailBody: React.FC<{ + entry: DreamDialogEntry; + maxHeight: number; + maxWidth: number; +}> = ({ entry, maxHeight, maxWidth }) => { + const title = 'Dream'; + const terminal = terminalStatusPresentation(entry.status); + const dimSubtitleParts: string[] = [elapsedFor(entry)]; + if (entry.sessionCount !== undefined) { + dimSubtitleParts.push( + `${entry.sessionCount} session${entry.sessionCount === 1 ? '' : 's'}`, + ); + } + if (entry.touchedTopics && entry.touchedTopics.length > 0) { + dimSubtitleParts.push( + `${entry.touchedTopics.length} topic${entry.touchedTopics.length === 1 ? '' : 's'}`, + ); + } + + // Topic file lists can grow for an active session sweep; cap the + // displayed slice and add a "+N more" tail rather than letting the + // dialog body push the hint footer off-screen. + const MAX_TOPICS = 8; + const topics = entry.touchedTopics ?? []; + const visibleTopics = topics.slice(0, MAX_TOPICS); + const hiddenTopicCount = Math.max(0, topics.length - visibleTopics.length); + const hasError = Boolean(entry.error); + + return ( + + + + {title} + + + + {terminal && ( + + {`${terminal.icon} ${STATUS_VERBS[entry.status]} · `} + + )} + {dimSubtitleParts.join(' · ')} + + + {entry.sessionCount !== undefined && ( + + + + + Sessions reviewing + + + + {String(entry.sessionCount)} + + + )} + + {entry.progressText && ( + + + + + Progress + + + + {entry.progressText} + + + )} + + {topics.length > 0 && ( + + + + + {`Topics touched (${topics.length})`} + + + {visibleTopics.map((topic) => ( + + {` · ${topic}`} + + ))} + {hiddenTopicCount > 0 && ( + + {` · +${hiddenTopicCount} more`} + + )} + + )} + + {hasError && ( + + + + + Error + + + + + {entry.error} + + + + )} + + {/* + Lock-release / metadata-write warnings on a successfully- + completed dream. Rendered as warnings (not errors) so the + terminal status stays Completed; explains why subsequent + dreams may be silently skipped as 'locked' (lock release + failure) or why the scheduler gate isn't picking up the + latest run (metadata write failure). + */} + {entry.lockReleaseError && ( + + + + + Lock release warning + + + + + {entry.lockReleaseError} + + + + + {`Subsequent dreams may be skipped as locked until the next session's staleness sweep cleans the file.`} + + + + )} + {entry.metadataWriteError && ( + + + + + Metadata write warning + + + + + {entry.metadataWriteError} + + + + + {`The scheduler gate did not see this dream's timestamp; the next dream cycle may re-fire sooner than usual.`} + + + + )} + + ); +}; + const AgentDetailBody: React.FC<{ entry: AgentDialogEntry; maxHeight: number; @@ -640,6 +849,15 @@ export const BackgroundTasksDialog: React.FC = ({ // those transitions — so we re-read from the registry here. const [activityTick, setActivityTick] = useState(0); + // Two-step cancel for foreground entries: cancelling one ends the + // parent's current turn with a partial result for that subagent — + // a much heavier consequence than cancelling a background async task. + // `pendingCancelEntryId` records the entry that has been armed for + // cancellation; the next `x` press confirms. Esc resets. + const [pendingCancelEntryId, setPendingCancelEntryId] = useState< + string | null + >(null); + const selectedEntry = useMemo(() => { const fromSnapshot = entries[selectedIndex] ?? null; if (!fromSnapshot) return fromSnapshot; @@ -748,6 +966,32 @@ export const BackgroundTasksDialog: React.FC = ({ } }, [dialogOpen, dialogMode, selectedEntryId, selectedStatus, exitDetail]); + // Encapsulates the cancel flow with the foreground confirm-step. + // Foreground entries: first `x` arms; second `x` confirms. Background + // and shell entries: one-shot cancel (no behavior change). + const handleCancelKey = () => { + if (!selectedEntry) return; + // `x` only has a meaning for entries the user can still act on: + // `running` → cancel, `paused` (agent kind) → abandon. Terminal + // statuses (completed/failed/cancelled) ignore the keypress so a + // foreground entry that just settled can't display the misleading + // "x again to confirm stop" line during the brief window before it + // unregisters. + const isCancelable = selectedEntry.status === 'running'; + const isAbandonable = + selectedEntry.kind === 'agent' && selectedEntry.status === 'paused'; + if (!isCancelable && !isAbandonable) return; + const entryKey = entryId(selectedEntry); + const isForegroundAgent = + selectedEntry.kind === 'agent' && selectedEntry.flavor === 'foreground'; + if (isForegroundAgent && pendingCancelEntryId !== entryKey) { + setPendingCancelEntryId(entryKey); + return; + } + setPendingCancelEntryId(null); + cancelSelected(); + }; + useKeypress( (key) => { if (!dialogOpen) return; @@ -755,10 +999,12 @@ export const BackgroundTasksDialog: React.FC = ({ if (dialogMode === 'list') { if (key.name === 'up') { moveSelectionUp(); + setPendingCancelEntryId(null); return; } if (key.name === 'down') { moveSelectionDown(); + setPendingCancelEntryId(null); return; } if (key.name === 'return') { @@ -766,6 +1012,11 @@ export const BackgroundTasksDialog: React.FC = ({ return; } if (key.name === 'escape' || key.name === 'left') { + if (pendingCancelEntryId) { + // Esc backs out of the confirm step before closing the dialog. + setPendingCancelEntryId(null); + return; + } closeDialog(); return; } @@ -774,7 +1025,7 @@ export const BackgroundTasksDialog: React.FC = ({ return; } if (key.sequence === 'x' && !key.ctrl && !key.meta) { - cancelSelected(); + handleCancelKey(); return; } // Note: the "stop all agents" chord (ctrl+x ctrl+k in claw-code) @@ -787,6 +1038,10 @@ export const BackgroundTasksDialog: React.FC = ({ // detail mode if (key.name === 'left') { + // Reset the foreground confirm-step before leaving detail so the + // armed state can't carry into list mode and turn a stray `x` into + // an unintended cancel on the same entry. + setPendingCancelEntryId(null); exitDetail(); return; } @@ -795,6 +1050,10 @@ export const BackgroundTasksDialog: React.FC = ({ key.name === 'return' || key.name === 'space' ) { + if (pendingCancelEntryId && key.name === 'escape') { + setPendingCancelEntryId(null); + return; + } closeDialog(); return; } @@ -803,7 +1062,7 @@ export const BackgroundTasksDialog: React.FC = ({ return; } if (key.sequence === 'x' && !key.ctrl && !key.meta) { - cancelSelected(); + handleCancelKey(); return; } }, @@ -818,8 +1077,18 @@ export const BackgroundTasksDialog: React.FC = ({ !selectedEntry.resumeBlockedReason; // Hint footer — context-sensitive. + const selectedEntryKey = selectedEntry ? entryId(selectedEntry) : null; + const showCancelConfirmHint = + pendingCancelEntryId !== null && pendingCancelEntryId === selectedEntryKey; const hints: string[] = []; - if (dialogMode === 'list') { + if (showCancelConfirmHint) { + // Force the confirmation step into the hint row so the user sees + // exactly what the next `x` will do. + hints.push( + 'x again to confirm stop \u00b7 ends current turn', + 'Esc cancel', + ); + } else if (dialogMode === 'list') { hints.push('\u2191/\u2193 select', 'Enter view'); if (selectedEntry?.status === 'running') hints.push('x stop'); if (selectedEntryAllowsResume) hints.push('r resume'); diff --git a/packages/cli/src/ui/components/background-view/BackgroundTasksPill.test.tsx b/packages/cli/src/ui/components/background-view/BackgroundTasksPill.test.tsx index 5721d890e..2647056e7 100644 --- a/packages/cli/src/ui/components/background-view/BackgroundTasksPill.test.tsx +++ b/packages/cli/src/ui/components/background-view/BackgroundTasksPill.test.tsx @@ -34,6 +34,17 @@ function shellEntry(overrides: Partial = {}): DialogEntry { } as DialogEntry; } +function dreamEntry(overrides: Partial = {}): DialogEntry { + return { + kind: 'dream', + dreamId: 'd-1', + status: 'running', + startTime: 0, + sessionCount: 5, + ...overrides, + } as DialogEntry; +} + function monitorEntry(overrides: Partial = {}): DialogEntry { return { kind: 'monitor', @@ -153,4 +164,44 @@ describe('getPillLabel', () => { ]), ).toBe('2 tasks done'); }); + + it('uses singular form for one running dream', () => { + expect(getPillLabel([dreamEntry({ dreamId: 'd-1' })])).toBe('1 dream'); + }); + + it('uses plural form for multiple running dreams', () => { + expect( + getPillLabel([ + dreamEntry({ dreamId: 'd-1' }), + dreamEntry({ dreamId: 'd-2' }), + ]), + ).toBe('2 dreams'); + }); + + it('places dream last in the kind ordering (shell, agent, monitor, dream)', () => { + // Ordering is asserted explicitly because it's a UX choice — dream + // is system-initiated (not user-triggered) and the user is least + // likely to need it at a glance, so it sits to the right of the + // user-launched kinds. + expect( + getPillLabel([ + dreamEntry({ dreamId: 'd-1' }), + agentEntry({ agentId: 'a' }), + shellEntry({ shellId: 'bg_a' }), + monitorEntry({ monitorId: 'mon-a' }), + ]), + ).toBe('1 shell, 1 local agent, 1 monitor, 1 dream'); + }); + + it('counts only running dreams when terminal dreams mix in', () => { + // Mirrors the existing monitor + agent terminal-mix tests so dream + // gets the same coverage profile. + expect( + getPillLabel([ + dreamEntry({ dreamId: 'd-a', status: 'running' }), + dreamEntry({ dreamId: 'd-b', status: 'completed' }), + dreamEntry({ dreamId: 'd-c', status: 'failed' }), + ]), + ).toBe('1 dream'); + }); }); diff --git a/packages/cli/src/ui/components/background-view/BackgroundTasksPill.tsx b/packages/cli/src/ui/components/background-view/BackgroundTasksPill.tsx index 99b08887c..347b80fa9 100644 --- a/packages/cli/src/ui/components/background-view/BackgroundTasksPill.tsx +++ b/packages/cli/src/ui/components/background-view/BackgroundTasksPill.tsx @@ -19,6 +19,7 @@ const KIND_NAMES = { agent: { singular: 'local agent', plural: 'local agents' }, shell: { singular: 'shell', plural: 'shells' }, monitor: { singular: 'monitor', plural: 'monitors' }, + dream: { singular: 'dream', plural: 'dreams' }, } as const; /** @@ -48,15 +49,17 @@ export function getPillLabel(entries: readonly DialogEntry[]): string { } function groupAndFormat(entries: readonly DialogEntry[]): string { - const counts = { agent: 0, shell: 0, monitor: 0 }; + const counts = { agent: 0, shell: 0, monitor: 0, dream: 0 }; for (const e of entries) counts[e.kind]++; const parts: string[] = []; // Order: shell first (matches Claude Code's pill convention), then - // agent, then monitor. Monitor sits last because it tends to be the - // longest-lived entry and least urgent to glance at. + // agent, then monitor, then dream. Dream sits last because it is + // system-initiated (not user-triggered) and the user is least likely + // to need it at a glance. if (counts.shell > 0) parts.push(formatCount('shell', counts.shell)); if (counts.agent > 0) parts.push(formatCount('agent', counts.agent)); if (counts.monitor > 0) parts.push(formatCount('monitor', counts.monitor)); + if (counts.dream > 0) parts.push(formatCount('dream', counts.dream)); return parts.join(', '); } diff --git a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx index 2520b19b0..4a78ccdb5 100644 --- a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx @@ -48,6 +48,14 @@ interface ToolGroupMessageProps { availableTerminalHeight?: number; contentWidth: number; isFocused?: boolean; + /** + * True when this tool group is being rendered live (in + * `pendingHistoryItems`). False once it commits to Ink's ``. + * The subagent renderer uses this to suppress the live frame for + * foreground subagents (the pill+dialog handle live drill-down) while + * keeping the committed scrollback render unchanged. + */ + isPending?: boolean; activeShellPtyId?: number | null; embeddedShellFocused?: boolean; onShellInputSubmit?: (input: string) => void; @@ -70,6 +78,7 @@ export const ToolGroupMessage: React.FC = ({ availableTerminalHeight, contentWidth, isFocused = true, + isPending = false, activeShellPtyId, embeddedShellFocused, memoryWriteCount, @@ -319,6 +328,7 @@ export const ToolGroupMessage: React.FC = ({ isAgentWithPendingConfirmation(tool.resultDisplay) } isFocused={isSubagentFocused} + isPending={isPending} isWaitingForOtherApproval={isWaitingForOtherApproval} /> diff --git a/packages/cli/src/ui/components/messages/ToolMessage.test.tsx b/packages/cli/src/ui/components/messages/ToolMessage.test.tsx index 8e2fb0d33..568c9c12c 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.test.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.test.tsx @@ -114,6 +114,13 @@ vi.mock('../subagents/index.js', () => ({ ); }, })); +vi.mock('./ToolConfirmationMessage.js', () => ({ + ToolConfirmationMessage: function MockToolConfirmationMessage() { + // Sentinel string lets `isPending && pendingConfirmation` tests + // assert the banner renders (instead of being suppressed). + return MockApprovalPrompt; + }, +})); // Mock settings const mockSettings: LoadedSettings = { @@ -316,6 +323,147 @@ describe('', () => { expect(output).toContain('Search for files matching pattern'); // Actual task description }); + describe('subagent live-render gating (isPending)', () => { + // The redesign hides the inline AgentExecutionDisplay while a + // foreground subagent runs (the pill+dialog handle drill-down). + // Only an active, focused approval prompt renders inline. + const buildProps = (overrides: { + data: { + subagentName: string; + taskDescription: string; + taskPrompt: string; + status: 'running' | 'completed'; + pendingConfirmation?: object; + }; + isPending?: boolean; + isFocused?: boolean; + isWaitingForOtherApproval?: boolean; + }): ToolMessageProps => { + // Spread the existing typed defaults so any future required field + // on `ToolMessageProps` becomes a compile-time miss instead of + // silently defaulting to undefined. Only the agent-specific + // `resultDisplay` shape uses a cast — its `pendingConfirmation` + // intentionally accepts a loose `object` fixture for these tests. + const resultDisplay = { + type: 'task_execution' as const, + ...overrides.data, + } as ToolMessageProps['resultDisplay']; + return { + ...baseProps, + name: 'task', + description: 'Delegate task to subagent', + resultDisplay, + status: ToolCallStatus.Executing, + callId: 'gated-task-call', + forceShowResult: true, // mirror ToolGroupMessage's forceShowResult + isPending: overrides.isPending, + isFocused: overrides.isFocused, + isWaitingForOtherApproval: overrides.isWaitingForOtherApproval, + }; + }; + + it('isPending && no pendingConfirmation → no inline frame', () => { + const { lastFrame } = renderWithContext( + , + StreamingState.Responding, + ); + const output = lastFrame() ?? ''; + // The mocked AgentExecutionDisplay tags itself with '🤖' — its + // absence proves the inline frame was suppressed. + expect(output).not.toContain('🤖'); + expect(output).not.toContain('MockApprovalPrompt'); + }); + + it('isPending && pendingConfirmation && isFocused → renders banner with agent label', () => { + const { lastFrame } = renderWithContext( + , + StreamingState.Responding, + ); + const output = lastFrame() ?? ''; + // Banner shows up with the originating agent identified, and the + // approval prompt itself renders. + expect(output).toContain('Approval requested by'); + expect(output).toContain('fg-agent'); + expect(output).toContain('MockApprovalPrompt'); + // The full agent frame (header / tool-call list) stays suppressed. + expect(output).not.toContain('🤖'); + }); + + it('isPending && pendingConfirmation && !isFocused → renders queued marker (one-line)', () => { + // Without this marker, a subagent waiting on another subagent's + // approval would be invisible in the main view — the user would + // have no inline signal that an approval is queued and would have + // to open the dialog to discover it. + const { lastFrame } = renderWithContext( + , + StreamingState.Responding, + ); + const output = lastFrame() ?? ''; + expect(output).toContain('Queued approval:'); + expect(output).toContain('queued-agent'); + // The full prompt + frame stay suppressed — only the focus-holder + // renders the active prompt above this row. + expect(output).not.toContain('Approval requested by'); + expect(output).not.toContain('MockApprovalPrompt'); + expect(output).not.toContain('🤖'); + }); + + it('!isPending → committed render shows full inline frame', () => { + const { lastFrame } = renderWithContext( + , + StreamingState.Idle, + ); + const output = lastFrame() ?? ''; + // -rendered scrollback: full frame, no flicker concern. + expect(output).toContain('🤖'); + expect(output).toContain('committed-agent'); + }); + }); + it('renders AnsiOutputText for AnsiOutput results', () => { const ansiResult: AnsiOutput = [ [ diff --git a/packages/cli/src/ui/components/messages/ToolMessage.tsx b/packages/cli/src/ui/components/messages/ToolMessage.tsx index a7a78f1a9..ebdfed812 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.tsx @@ -24,6 +24,7 @@ import type { McpToolProgressData, } from '@qwen-code/qwen-code-core'; import { AgentExecutionDisplay } from '../subagents/index.js'; +import { ToolConfirmationMessage } from './ToolConfirmationMessage.js'; import { PlanSummaryDisplay } from '../PlanSummaryDisplay.js'; import { ShellInputPrompt } from '../ShellInputPrompt.js'; import { SHELL_COMMAND_NAME, SHELL_NAME } from '../../constants.js'; @@ -238,7 +239,19 @@ const PlanResultRenderer: React.FC<{ ); /** - * Component to render subagent execution results + * Component to render subagent execution results. + * + * Live (`isPending===true`): the inline frame is suppressed — running + * subagents are surfaced through the footer pill + dialog instead, which + * removes the live-area flicker that occurred when the frame's tool-call + * list grew past the terminal height. The one exception is an active + * approval prompt that holds the focus lock: that renders as a small + * banner with an agent-name label, since hiding it would block the run + * silently. + * + * Committed (`isPending===false`): renders the full `AgentExecutionDisplay` + * exactly as before. Ink's `` is append-only, so committed frames + * never flicker even when verbose. */ const SubagentExecutionRenderer: React.FC<{ data: AgentResultDisplay; @@ -246,6 +259,7 @@ const SubagentExecutionRenderer: React.FC<{ childWidth: number; config: Config; isFocused?: boolean; + isPending?: boolean; isWaitingForOtherApproval?: boolean; }> = ({ data, @@ -253,17 +267,62 @@ const SubagentExecutionRenderer: React.FC<{ childWidth, config, isFocused, + isPending, isWaitingForOtherApproval, -}) => ( - -); +}) => { + if (isPending) { + if (data.pendingConfirmation && isFocused) { + // Active approval prompt for the focus-holding subagent — render + // inline so the user can act on it without opening the dialog. + const agentLabel = data.subagentName || 'agent'; + return ( + + + Approval requested by + + {agentLabel} + + : + + + + ); + } + if (data.pendingConfirmation) { + // Queued approval — another subagent currently holds the focus lock. + // A one-line marker keeps the user aware that something is waiting + // without opening the dialog; the full prompt renders on the + // focus-holder above and inside `BackgroundTasksDialog`. + const agentLabel = data.subagentName || 'agent'; + return ( + + + ⏳ Queued approval:{' '} + + {agentLabel} + + ); + } + return null; + } + return ( + + ); +}; /** * Component to render string results (markdown or plain text) @@ -347,6 +406,13 @@ export interface ToolMessageProps extends IndividualToolCallDisplay { * Ctrl+E/Ctrl+F display shortcuts. */ isFocused?: boolean; + /** + * True when rendering inside `pendingHistoryItems` (live area), false once + * committed to ``. Foreground subagents suppress their inline + * frame in the live phase — the pill+dialog handle drill-down — but + * always render in scrollback. + */ + isPending?: boolean; /** Whether another subagent's approval currently holds the focus lock, blocking this one. */ isWaitingForOtherApproval?: boolean; } @@ -366,6 +432,7 @@ export const ToolMessage: React.FC = ({ config, forceShowResult, isFocused, + isPending, isWaitingForOtherApproval, executionStartTime, }) => { @@ -521,6 +588,7 @@ export const ToolMessage: React.FC = ({ childWidth={innerWidth} config={config} isFocused={isFocused} + isPending={isPending} isWaitingForOtherApproval={isWaitingForOtherApproval} /> )} diff --git a/packages/cli/src/ui/contexts/BackgroundTaskViewContext.tsx b/packages/cli/src/ui/contexts/BackgroundTaskViewContext.tsx index cc95d4f7b..4fb02f5f9 100644 --- a/packages/cli/src/ui/contexts/BackgroundTaskViewContext.tsx +++ b/packages/cli/src/ui/contexts/BackgroundTaskViewContext.tsx @@ -19,12 +19,14 @@ import { useMemo, useState, } from 'react'; -import { type Config } from '@qwen-code/qwen-code-core'; +import { type Config, createDebugLogger } from '@qwen-code/qwen-code-core'; import { type DialogEntry, useBackgroundTaskView, } from '../hooks/useBackgroundTaskView.js'; +const debugLogger = createDebugLogger('BG_TASK_VIEW'); + // ─── Types ────────────────────────────────────────────────── export type BackgroundDialogMode = 'closed' | 'list' | 'detail'; @@ -194,6 +196,28 @@ export function BackgroundTaskViewProvider({ case 'monitor': config.getMonitorRegistry().cancel(target.monitorId); break; + case 'dream': { + // Aborts the dream fork-agent via MemoryManager.cancelTask; + // the manager flips status to 'cancelled' before aborting, and + // the runDream finally block releases the consolidation lock as + // the agent unwinds. Same one-shot fire-and-forget shape as + // shell.requestCancel above. + // + // cancelTask returns false in the contract-violation path + // (running record without an AbortController). Today this is + // unreachable because the controller is registered before + // storeWith fires the notify, but if a future refactor + // breaks the invariant a silent ignore here would let the + // user think the cancel took. Log + leave the dialog open. + const ok = config.getMemoryManager().cancelTask(target.dreamId); + if (!ok) { + debugLogger.warn( + `cancelSelected: dream task ${target.dreamId} could not be cancelled ` + + `(internal state inconsistency — see MemoryManager.cancelTask warn).`, + ); + } + break; + } default: { const _exhaustive: never = target; throw new Error( diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 7b963cfd2..074490b4c 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -6,7 +6,10 @@ import { act, renderHook, waitFor } from '@testing-library/react'; import { vi, describe, it, expect, beforeEach } from 'vitest'; -import { useSlashCommandProcessor } from './slashCommandProcessor.js'; +import { + useSlashCommandProcessor, + type SlashCommandProcessorActions, +} from './slashCommandProcessor.js'; import type { CommandContext, ConfirmShellCommandsActionReturn, @@ -121,6 +124,33 @@ describe('useSlashCommandProcessor', () => { }); const mockSettings = {} as LoadedSettings; + const createMockActions = (): SlashCommandProcessorActions => ({ + openAuthDialog: mockOpenAuthDialog, + openArenaDialog: vi.fn(), + openThemeDialog: mockOpenThemeDialog, + openEditorDialog: vi.fn(), + openMemoryDialog: mockOpenMemoryDialog, + openSettingsDialog: vi.fn(), + openModelDialog: mockOpenModelDialog, + openManageModelsDialog: vi.fn(), + openTrustDialog: vi.fn(), + openPermissionsDialog: vi.fn(), + openApprovalModeDialog: vi.fn(), + openResumeDialog: vi.fn(), + handleResume: vi.fn(), + openDeleteDialog: vi.fn(), + quit: mockSetQuittingMessages, + setDebugMessage: vi.fn(), + dispatchExtensionStateUpdate: vi.fn(), + addConfirmUpdateExtensionRequest: vi.fn(), + openSubagentCreateDialog: vi.fn(), + openAgentsManagerDialog: vi.fn(), + openExtensionsManagerDialog: vi.fn(), + openMcpDialog: vi.fn(), + openHooksDialog: vi.fn(), + openRewindSelector: vi.fn(), + }); + beforeEach(() => { vi.clearAllMocks(); vi.mocked(BuiltinCommandLoader).mockClear(); @@ -154,24 +184,7 @@ describe('useSlashCommandProcessor', () => { setIsProcessing, { current: true }, // isIdleRef vi.fn(), // setGeminiMdFileCount - { - openAuthDialog: mockOpenAuthDialog, - openThemeDialog: mockOpenThemeDialog, - openEditorDialog: vi.fn(), - openMemoryDialog: mockOpenMemoryDialog, - openSettingsDialog: vi.fn(), - openModelDialog: mockOpenModelDialog, - openTrustDialog: vi.fn(), - openPermissionsDialog: vi.fn(), - openApprovalModeDialog: vi.fn(), - openResumeDialog: vi.fn(), - quit: mockSetQuittingMessages, - setDebugMessage: vi.fn(), - dispatchExtensionStateUpdate: vi.fn(), - addConfirmUpdateExtensionRequest: vi.fn(), - openSubagentCreateDialog: vi.fn(), - openAgentsManagerDialog: vi.fn(), - }, + createMockActions(), new Map(), // extensionsUpdateState true, // isConfigInitialized null, // logger @@ -270,6 +283,21 @@ describe('useSlashCommandProcessor', () => { ); }); + it('should let slash-prefixed file paths fall through to the model', async () => { + const result = setupProcessorHook(); + await waitFor(() => expect(result.current.slashCommands).toBeDefined()); + + let actionResult; + await act(async () => { + actionResult = await result.current.handleSlashCommand( + '/api/apiFunction/接口的实现', + ); + }); + + expect(actionResult).toBe(false); + expect(mockAddItem).not.toHaveBeenCalled(); + }); + it('should display help for a parent command invoked without a subcommand', async () => { const parentCommand: SlashCommand = { name: 'parent', @@ -968,24 +996,7 @@ describe('useSlashCommandProcessor', () => { vi.fn(), // setIsProcessing { current: true }, // isIdleRef vi.fn(), // setGeminiMdFileCount - { - openAuthDialog: mockOpenAuthDialog, - openThemeDialog: mockOpenThemeDialog, - openEditorDialog: vi.fn(), - openMemoryDialog: mockOpenMemoryDialog, - openSettingsDialog: vi.fn(), - openModelDialog: vi.fn(), - openTrustDialog: vi.fn(), - openPermissionsDialog: vi.fn(), - openApprovalModeDialog: vi.fn(), - openResumeDialog: vi.fn(), - quit: mockSetQuittingMessages, - setDebugMessage: vi.fn(), - dispatchExtensionStateUpdate: vi.fn(), - addConfirmUpdateExtensionRequest: vi.fn(), - openSubagentCreateDialog: vi.fn(), - openAgentsManagerDialog: vi.fn(), - }, + createMockActions(), new Map(), // extensionsUpdateState true, // isConfigInitialized null, // logger diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index f9d685b78..55586ed0c 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -46,7 +46,10 @@ import { FileCommandLoader } from '../../services/FileCommandLoader.js'; import { McpPromptLoader } from '../../services/McpPromptLoader.js'; import { SkillCommandLoader } from '../../services/SkillCommandLoader.js'; import { parseSlashCommand } from '../../utils/commands.js'; -import { isBtwCommand } from '../utils/commandUtils.js'; +import { + hasSlashCommandPathSeparator, + isBtwCommand, +} from '../utils/commandUtils.js'; import { clearScreen } from '../../utils/stdioHelpers.js'; import { useKeypress } from './useKeypress.js'; import { @@ -78,7 +81,7 @@ const SLASH_COMMANDS_SKIP_RECORDING = new Set([ 'btw', ]); -interface SlashCommandProcessorActions { +export interface SlashCommandProcessorActions { openAuthDialog: () => void; openArenaDialog?: (type: Exclude) => void; openThemeDialog: () => void; @@ -448,6 +451,9 @@ export const useSlashCommandProcessor = ( if (!trimmed.startsWith('/') && !trimmed.startsWith('?')) { return false; } + if (trimmed.startsWith('/') && hasSlashCommandPathSeparator(trimmed)) { + return false; + } const recordedItems: Array> = []; const recordItem = (item: Omit) => { diff --git a/packages/cli/src/ui/hooks/useBackgroundTaskView.test.ts b/packages/cli/src/ui/hooks/useBackgroundTaskView.test.ts index 7c71f22d8..eb4d8aabe 100644 --- a/packages/cli/src/ui/hooks/useBackgroundTaskView.test.ts +++ b/packages/cli/src/ui/hooks/useBackgroundTaskView.test.ts @@ -25,14 +25,52 @@ function makeFakeRegistry(): FakeRegistry { }; } +interface FakeMemoryManager { + subscribe: ReturnType; + unsubscribe: ReturnType; + /** Captured opts from the most recent subscribe() call (the hook + * passes `{ taskType: 'dream' }` to skip per-extract notifies). */ + lastSubscribeOpts: { taskType?: 'extract' | 'dream' } | undefined; + /** Test helper — invokes the currently-subscribed listener. */ + fire: () => void; +} + +function makeFakeMemoryManager(): FakeMemoryManager { + let listener: (() => void) | undefined; + const ref: { lastSubscribeOpts: FakeMemoryManager['lastSubscribeOpts'] } = { + lastSubscribeOpts: undefined, + }; + const unsubscribe = vi.fn(() => { + listener = undefined; + }); + const subscribe = vi.fn( + (next: () => void, opts?: { taskType?: 'extract' | 'dream' }) => { + listener = next; + ref.lastSubscribeOpts = opts; + return unsubscribe; + }, + ); + return { + subscribe, + unsubscribe, + get lastSubscribeOpts() { + return ref.lastSubscribeOpts; + }, + fire: () => listener?.(), + }; +} + function makeConfig(opts: { agents: () => unknown[]; shells: () => unknown[]; monitors: () => unknown[]; + dreams?: () => unknown[]; }) { const agentReg = makeFakeRegistry(); const shellReg = makeFakeRegistry(); const monitorReg = makeFakeRegistry(); + const memoryMgr = makeFakeMemoryManager(); + const dreams = opts.dreams ?? (() => []); const config = { getBackgroundTaskRegistry: () => ({ @@ -47,9 +85,16 @@ function makeConfig(opts: { ...monitorReg, getAll: opts.monitors, }), + getMemoryManager: () => ({ + subscribe: memoryMgr.subscribe, + // Hook only ever requests dream-typed records; ignore the type arg + // and return whatever the test provided. + listTasksByType: (_type: string, _projectRoot?: string) => dreams(), + }), + getProjectRoot: () => '/test/project', } as unknown as Config; - return { config, agentReg, shellReg, monitorReg }; + return { config, agentReg, shellReg, monitorReg, memoryMgr }; } const agent = (id: string, startTime: number) => ({ @@ -84,6 +129,38 @@ const monitor = (id: string, startTime: number) => ({ droppedLines: 0, }); +// Mirror the MemoryTaskRecord shape that MemoryManager.listTasksByType +// returns. Status defaults to 'running'; tests override to exercise the +// filter (`pending` / `skipped` records must be excluded; `cancelled` +// flows through the same terminal-cap path as `completed` / `failed` +// once the task_stop / dialog cancel keystroke lands one). +const dream = ( + id: string, + startTimeMs: number, + overrides: Partial<{ + status: + | 'pending' + | 'running' + | 'completed' + | 'failed' + | 'cancelled' + | 'skipped'; + progressText: string; + error: string; + metadata: Record; + }> = {}, +) => ({ + id, + taskType: 'dream' as const, + projectRoot: '/test/project', + status: overrides.status ?? ('running' as const), + createdAt: new Date(startTimeMs).toISOString(), + updatedAt: new Date(startTimeMs).toISOString(), + progressText: overrides.progressText, + error: overrides.error, + metadata: overrides.metadata, +}); + describe('useBackgroundTaskView', () => { it('returns empty entries when config is null', () => { const { result } = renderHook(() => useBackgroundTaskView(null)); @@ -154,7 +231,7 @@ describe('useBackgroundTaskView', () => { }); it('clears all three subscriptions on unmount', () => { - const { config, agentReg, shellReg, monitorReg } = makeConfig({ + const { config, agentReg, shellReg, monitorReg, memoryMgr } = makeConfig({ agents: () => [], shells: () => [], monitors: () => [], @@ -178,5 +255,169 @@ describe('useBackgroundTaskView', () => { [expect.any(Function)], [undefined], ]); + // MemoryManager uses subscribe()/unsubscribe rather than the + // setCallback pattern; the unsubscribe returned from subscribe must + // run on cleanup or stale dream listeners leak across remounts. + expect(memoryMgr.subscribe).toHaveBeenCalledTimes(1); + expect(memoryMgr.unsubscribe).toHaveBeenCalledTimes(1); + }); + + it('surfaces dream tasks with kind=dream and skips pending/skipped records', () => { + const { config } = makeConfig({ + agents: () => [], + shells: () => [], + monitors: () => [], + // Three dream records covering: a pre-fire pending record (must + // not surface — would flood the dialog with one row per + // UserQuery), a running fire (must surface), and a skipped + // gate-miss (must not surface — same flood concern). + dreams: () => [ + dream('d-pending', 100, { status: 'pending' }), + dream('d-running', 200), + dream('d-skipped', 300, { status: 'skipped' }), + ], + }); + const { result } = renderHook(() => useBackgroundTaskView(config)); + expect(result.current.entries).toHaveLength(1); + const [only] = result.current.entries; + expect(only.kind).toBe('dream'); + expect(only.status).toBe('running'); + expect(entryId(only)).toBe('d-running'); + }); + + it('caps retained terminal dream entries at 3 most-recent (by updatedAt) plus all running', () => { + // MemoryManager has no eviction; without the cap, accumulating + // completed dreams across a long session would blow up the dialog. + // The cap keeps the dialog glanceable while still surfacing the + // most recent outcomes (mirrors MonitorRegistry's terminal cap). + const baseMs = Date.parse('2026-05-04T12:00:00.000Z'); + const completed = (id: string, mtime: number) => ({ + id, + taskType: 'dream' as const, + projectRoot: '/test/project', + status: 'completed' as const, + createdAt: new Date(baseMs + mtime - 1000).toISOString(), + updatedAt: new Date(baseMs + mtime).toISOString(), + }); + const { config } = makeConfig({ + agents: () => [], + shells: () => [], + monitors: () => [], + dreams: () => [ + completed('d-old-1', 1_000), + completed('d-old-2', 2_000), + completed('d-mid', 3_000), + completed('d-recent', 4_000), + completed('d-newest', 5_000), + // Plus a running entry that must always survive the cap (caps + // only trim terminals; running dreams are uncapped). + dream('d-running-now', baseMs + 6_000, { status: 'running' }), + ], + }); + const { result } = renderHook(() => useBackgroundTaskView(config)); + const ids = result.current.entries.map(entryId).sort(); + // Surviving terminal entries: d-newest, d-recent, d-mid (top 3 by + // updatedAt desc). The two oldest (d-old-1, d-old-2) get dropped. + // The running dream survives unconditionally. + expect(ids).toEqual( + ['d-mid', 'd-newest', 'd-recent', 'd-running-now'].sort(), + ); + }); + + it('surfaces a cancelled dream with kind=dream so the dialog can render the terminal status', () => { + // `'cancelled'` arrives via the dialog `x stop` / `task_stop` path + // which routes through `MemoryManager.cancelTask`. The view-model + // must accept it the same way it accepts `'completed'` / `'failed'`, + // because the dialog's terminal-cap window depends on showing the + // user the outcome of the abort they just triggered. + const { config } = makeConfig({ + agents: () => [], + shells: () => [], + monitors: () => [], + dreams: () => [dream('d-stopped', 100, { status: 'cancelled' })], + }); + const { result } = renderHook(() => useBackgroundTaskView(config)); + expect(result.current.entries).toHaveLength(1); + const [only] = result.current.entries; + expect(only.kind).toBe('dream'); + expect(only.status).toBe('cancelled'); + }); + + it('subscribes to MemoryManager with a dream taskType filter so extract notifies are skipped at the source', () => { + // The taskType filter on MemoryManager.subscribe() is the + // primary perf guard — it prevents the per-UserQuery extract + // notify from waking the bg-tasks UI listener at all (avoids the + // O(n) dream-snapshot fetch + signature compare that would + // otherwise run on every extract transition). Pin the filter so + // a future refactor that drops the opts arg fails the test + // rather than silently re-introducing the wakeups. + const { config, memoryMgr } = makeConfig({ + agents: () => [], + shells: () => [], + monitors: () => [], + }); + renderHook(() => useBackgroundTaskView(config)); + expect(memoryMgr.subscribe).toHaveBeenCalledTimes(1); + expect(memoryMgr.lastSubscribeOpts).toEqual({ taskType: 'dream' }); + }); + + it('skips setEntries when the memory listener fires with unchanged dream content', () => { + // MemoryManager.subscribe() fires for ALL task transitions, including + // extract task records that have no dialog surface. Without the + // dream-signature dedup, every extract notify would trigger a full + // re-merge + a fresh array reference into setEntries — re-rendering + // the dialog and pill on entries that are byte-identical to the + // previous snapshot. This test pins the dedup by firing the memory + // listener while the dream snapshot stays unchanged and asserting + // that the entries reference is preserved. + const dreams: Array> = [dream('d-only', 100)]; + const { config, memoryMgr } = makeConfig({ + agents: () => [], + shells: () => [], + monitors: () => [], + dreams: () => dreams, + }); + const { result } = renderHook(() => useBackgroundTaskView(config)); + const before = result.current.entries; + expect(before.map(entryId)).toEqual(['d-only']); + + // Fire the memory listener without mutating `dreams`. With the + // signature-dedup in place, this must NOT call setEntries; React + // will then preserve the existing array reference. + act(() => memoryMgr.fire()); + expect(result.current.entries).toBe(before); + + // Sanity check the inverse path: when dreams DO change, the + // listener must propagate. A flipped status should change the + // signature and force a fresh setEntries. + dreams.splice(0, 1, dream('d-only', 100, { status: 'completed' })); + act(() => memoryMgr.fire()); + expect(result.current.entries).not.toBe(before); + expect(result.current.entries[0]?.status).toBe('completed'); + }); + + it('refreshes entries when the memory manager fires its subscribe listener', () => { + const dreams: Array> = []; + const { config, memoryMgr } = makeConfig({ + agents: () => [], + shells: () => [], + monitors: () => [], + dreams: () => dreams, + }); + const { result } = renderHook(() => useBackgroundTaskView(config)); + expect(result.current.entries).toEqual([]); + + dreams.push(dream('d-1', 100)); + act(() => memoryMgr.fire()); + expect(result.current.entries.map(entryId)).toEqual(['d-1']); + + // A subsequent terminal state update must propagate the new status + // (running → completed) and survive the filter (only pending / + // skipped get dropped). + dreams.splice(0, dreams.length, dream('d-1', 100, { status: 'completed' })); + act(() => memoryMgr.fire()); + const [only] = result.current.entries; + expect(only.kind).toBe('dream'); + expect(only.status).toBe('completed'); }); }); diff --git a/packages/cli/src/ui/hooks/useBackgroundTaskView.ts b/packages/cli/src/ui/hooks/useBackgroundTaskView.ts index 7edd3c48a..dd6c3c496 100644 --- a/packages/cli/src/ui/hooks/useBackgroundTaskView.ts +++ b/packages/cli/src/ui/hooks/useBackgroundTaskView.ts @@ -5,11 +5,15 @@ */ /** - * useBackgroundTaskView — subscribes to all three registries (background - * subagents, managed shells, and event monitors) and merges them into a - * single ordered snapshot of `DialogEntry`s. Each registry fires + * useBackgroundTaskView — subscribes to the three background-task + * registries (background subagents, managed shells, and event monitors) + * AND to `MemoryManager` for dream consolidation tasks, merging them + * into a single ordered snapshot of `DialogEntry`s. Each registry fires * `statusChange` on register too, so a single subscription per registry * is enough to keep the snapshot fresh for new + transitioning entries. + * The `MemoryManager.subscribe({ taskType: 'dream' })` filter routes + * dream-task transitions to the same refresh path while skipping the + * per-UserQuery extract notifies that have no dialog surface. * * Surfaces that only care about live work (the footer pill, the * composer's Down-arrow route) filter for `running` themselves. @@ -26,24 +30,75 @@ import { type BackgroundTaskEntry, type BackgroundShellEntry, type Config, + type MemoryTaskRecord, type MonitorEntry, } from '@qwen-code/qwen-code-core'; +// Cap on retained terminal dream entries surfaced via the dialog. +// `MemoryManager.tasks` has no eviction; without this cap the list +// grows unboundedly with completed dreams over the project's lifetime. +// 3 is small enough to stay glanceable yet keeps the most recent +// outcomes visible across rapid succession (e.g. the user opening the +// dialog right after two dreams completed). +const MAX_RETAINED_TERMINAL_DREAMS = 3; + export type AgentDialogEntry = BackgroundTaskEntry & { kind: 'agent'; resumeBlockedReason?: string; }; +/** + * Dream-task adapter. MemoryManager owns its own task records + * (MemoryTaskRecord) and intentionally lives outside the registry trio; + * this view-model wraps the subset of fields the dialog needs and + * narrows status to the four values that ever appear in the dialog + * (skipped/pending records are filtered out at the source). + */ +export type DreamDialogEntry = { + kind: 'dream'; + /** MemoryTaskRecord.id — used as React key + lookup. */ + dreamId: string; + status: 'running' | 'completed' | 'failed' | 'cancelled'; + startTime: number; + /** + * Wall-clock instant the record's `status` last changed. For + * `completed` / `failed` this is when the dream actually finished; + * for `cancelled` this is the moment `cancelTask` ran (NOT when + * the fork agent finishes unwinding — that can lag by seconds for + * agents mid-tool-call). The dialog renders elapsed from this + * value, so a freshly-cancelled record snaps to "Stopped · Ns" + * even while the underlying fork is still releasing the lock. + */ + endTime?: number; + progressText?: string; + error?: string; + /** Number of sessions the dream is reviewing — populated on schedule. */ + sessionCount?: number; + /** Memory topic files written — populated on completion. */ + touchedTopics?: readonly string[]; + /** + * Best-effort warnings populated by `runDream` when post-fork + * housekeeping fails (gating-metadata write or consolidation-lock + * release). The dream itself completed successfully — these are + * informational so the user can explain why subsequent dreams may + * be silently skipped as `'locked'` or why the scheduler gate + * isn't seeing the most recent dream's timestamp. + */ + lockReleaseError?: string; + metadataWriteError?: string; +}; + /** * A unified view-model entry the dialog/pill/context render against. * Discriminated by `kind`; per-kind fields are inlined verbatim so * renderer code can stay mechanical (`entry.kind === 'agent'` / - * `'shell'` / `'monitor'` guard, then access fields directly). + * `'shell'` / `'monitor'` / `'dream'` guard, then access fields directly). */ export type DialogEntry = | AgentDialogEntry | (BackgroundShellEntry & { kind: 'shell' }) - | (MonitorEntry & { kind: 'monitor' }); + | (MonitorEntry & { kind: 'monitor' }) + | DreamDialogEntry; export interface UseBackgroundTaskViewResult { entries: readonly DialogEntry[]; @@ -58,6 +113,8 @@ export function entryId(entry: DialogEntry): string { return entry.shellId; case 'monitor': return entry.monitorId; + case 'dream': + return entry.dreamId; default: { const _exhaustive: never = entry; throw new Error( @@ -77,8 +134,27 @@ export function useBackgroundTaskView( const agentRegistry = config.getBackgroundTaskRegistry(); const shellRegistry = config.getBackgroundShellRegistry(); const monitorRegistry = config.getMonitorRegistry(); + const memoryManager = config.getMemoryManager(); + const projectRoot = config.getProjectRoot(); + // Dream snapshot signature, kept as a defense-in-depth dedup for + // the dream-filtered memory listener below. The taskType filter + // already skips the listener entirely on extract notifies; this + // signature additionally absorbs the rare case where dream + // metadata is updated without an observable dialog change. + let lastDreamSig = ''; - const refresh = () => { + // Declared before `refresh` so the function ordering can't trip + // the temporal-dead-zone if a future refactor adds a synchronous + // call to refresh between the two `const` bindings. + const computeDreamSig = (dreams: readonly MemoryTaskRecord[]): string => + dreams.map((t) => `${t.id}:${t.status}:${t.updatedAt}`).join('|'); + + // refresh accepts a pre-fetched dream snapshot so the memory + // listener can reuse the same array it computed for its dedup + // check — avoids a second listTasksByType call AND eliminates the + // race window where the listener's gate sig and the entries it + // builds would otherwise come from two separate snapshots. + const refresh = (dreamSnapshot?: readonly MemoryTaskRecord[]) => { const agentEntries: DialogEntry[] = agentRegistry .getAll() .map((e) => ({ ...e, kind: 'agent' as const })); @@ -88,25 +164,119 @@ export function useBackgroundTaskView( const monitorEntries: DialogEntry[] = monitorRegistry .getAll() .map((e) => ({ ...e, kind: 'monitor' as const })); + // Dream entries: only surface tasks that actually fired. + // `pending` is a sub-second transition state and `skipped` + // records arise from the rare race where the schedule-time + // lock check passed but `acquireDreamLock` then hit EEXIST in + // runDream — these never reflect user-visible work, so filter + // them out. (Most gate misses don't create a record at all; + // scheduleDream returns `{status: 'skipped'}` early without + // touching the task map.) Extract tasks also intentionally + // stay out of this view — they fire on every UserQuery and + // their completion is already covered by the `memory_saved` + // toast in useGeminiStream. + // + // Cap retained terminal entries — MemoryManager.tasks Map has no + // eviction path, so completed/failed dreams accumulate forever + // (every fired dream over the project's lifetime). Without this + // cap the dialog would grow unbounded; with it the user sees all + // running dreams plus the most recent few terminal results + // (mirrors MonitorRegistry.MAX_RETAINED_TERMINAL_MONITORS). + const allDreams = + dreamSnapshot ?? memoryManager.listTasksByType('dream', projectRoot); + const runningDreams = allDreams.filter((t) => t.status === 'running'); + const terminalDreams = allDreams + .filter( + (t) => + t.status === 'completed' || + t.status === 'failed' || + t.status === 'cancelled', + ) + .sort((a, b) => b.updatedAt.localeCompare(a.updatedAt)) + .slice(0, MAX_RETAINED_TERMINAL_DREAMS); + const dreamEntries: DialogEntry[] = [ + ...runningDreams, + ...terminalDreams, + ].map((t) => { + const sessionCount = t.metadata?.['sessionCount']; + const touchedTopics = t.metadata?.['touchedTopics']; + const lockReleaseError = t.metadata?.['lockReleaseError']; + const metadataWriteError = t.metadata?.['metadataWriteError']; + return { + kind: 'dream' as const, + dreamId: t.id, + status: t.status as 'running' | 'completed' | 'failed' | 'cancelled', + startTime: Date.parse(t.createdAt), + endTime: t.status === 'running' ? undefined : Date.parse(t.updatedAt), + progressText: t.progressText, + error: t.error, + sessionCount: + typeof sessionCount === 'number' ? sessionCount : undefined, + touchedTopics: Array.isArray(touchedTopics) + ? (touchedTopics.filter((s) => typeof s === 'string') as string[]) + : undefined, + lockReleaseError: + typeof lockReleaseError === 'string' ? lockReleaseError : undefined, + metadataWriteError: + typeof metadataWriteError === 'string' + ? metadataWriteError + : undefined, + }; + }); // Merge by startTime so the order matches launch order across all - // registries (matters when an agent, shell, and monitor are + // sources (matters when an agent, shell, monitor, and dream are // launched alternately). - const merged = [...agentEntries, ...shellEntries, ...monitorEntries].sort( - (a, b) => a.startTime - b.startTime, - ); + const merged = [ + ...agentEntries, + ...shellEntries, + ...monitorEntries, + ...dreamEntries, + ].sort((a, b) => a.startTime - b.startTime); + // Cache the dream signature derived from the freshly-built + // entries — the memory listener uses this to skip redundant + // setEntries calls when an extract notify fires (extract has no + // dialog surface, so the merged result is identical). Computed + // from the same `allDreams` snapshot used to build dreamEntries + // so the gate value can never desync from what's on screen. + lastDreamSig = computeDreamSig(allDreams); setEntries(merged); }; + // Wrap registry callbacks in a thunk so React's setStatusChange + // signature (no-arg) doesn't accidentally pass an entry into + // refresh's `dreamSnapshot` parameter. + const refreshFromRegistry = () => refresh(); + refresh(); - agentRegistry.setStatusChangeCallback(refresh); - shellRegistry.setStatusChangeCallback(refresh); - monitorRegistry.setStatusChangeCallback(refresh); + agentRegistry.setStatusChangeCallback(refreshFromRegistry); + shellRegistry.setStatusChangeCallback(refreshFromRegistry); + monitorRegistry.setStatusChangeCallback(refreshFromRegistry); + + // Memory listener fires only on dream-task transitions — + // `subscribe({ taskType: 'dream' })` skips the per-extract notify + // entirely so we don't pay the per-UserQuery O(n) signature cost + // for transitions we have no surface for. The dream-content + // signature dedup remains as a second-line guard against the rare + // case where dream metadata is updated without observable changes + // to the dialog (e.g. a future progressText-only patch on the + // same status). The fetched snapshot is forwarded to refresh so + // both the gate and the rendered dreamEntries come from one read. + const memoryListener = () => { + const dreams = memoryManager.listTasksByType('dream', projectRoot); + const sig = computeDreamSig(dreams); + if (sig === lastDreamSig) return; + refresh(dreams); + }; + const unsubscribeMemory = memoryManager.subscribe(memoryListener, { + taskType: 'dream', + }); return () => { agentRegistry.setStatusChangeCallback(undefined); shellRegistry.setStatusChangeCallback(undefined); monitorRegistry.setStatusChangeCallback(undefined); + unsubscribeMemory(); }; }, [config]); diff --git a/packages/cli/src/ui/hooks/useExportCompletion.test.ts b/packages/cli/src/ui/hooks/useExportCompletion.test.ts new file mode 100644 index 000000000..bb86a900f --- /dev/null +++ b/packages/cli/src/ui/hooks/useExportCompletion.test.ts @@ -0,0 +1,259 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** @vitest-environment jsdom */ + +import { act, renderHook } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; +import { CommandKind, type SlashCommand } from '../commands/types.js'; +import type { TextBuffer } from '../components/shared/text-buffer.js'; +import type { Key } from './useKeypress.js'; +import type { UseCommandCompletionReturn } from './useCommandCompletion.js'; +import { + getExportFormatFromInput, + getNextExportCompletionIndex, + useExportCompletion, +} from './useExportCompletion.js'; + +const EXPORT_FORMATS = ['html', 'md', 'json', 'jsonl'] as const; + +const exportSlashCommands: readonly SlashCommand[] = [ + { + name: 'export', + kind: CommandKind.BUILT_IN, + description: 'Export conversation', + subCommands: EXPORT_FORMATS.map((name) => ({ + name, + kind: CommandKind.BUILT_IN, + description: `Export ${name}`, + })), + }, +]; + +function createTextBuffer(initialText: string): TextBuffer { + let text = initialText; + const buffer = { + get text() { + return text; + }, + get lines() { + return [text]; + }, + get cursor() { + return [0, text.length] as [number, number]; + }, + setText: vi.fn((nextText: string) => { + text = nextText; + }), + }; + + return buffer as unknown as TextBuffer; +} + +function createKey(name: string): Key { + return { + name, + ctrl: false, + meta: false, + shift: false, + paste: false, + sequence: '', + }; +} + +function createCompletion( + overrides: Partial = {}, +): UseCommandCompletionReturn { + return { + suggestions: EXPORT_FORMATS.map((format) => ({ + label: format, + value: format, + })), + activeSuggestionIndex: 0, + visibleStartIndex: 0, + showSuggestions: false, + isLoadingSuggestions: false, + isPerfectMatch: false, + setActiveSuggestionIndex: vi.fn(), + setShowSuggestions: vi.fn(), + resetCompletionState: vi.fn(), + navigateUp: vi.fn(), + navigateDown: vi.fn(), + handleAutocomplete: vi.fn(), + midInputGhostText: null, + ...overrides, + }; +} + +describe('getExportFormatFromInput', () => { + it.each([ + ['', null], + ['/export', null], + ['/export ', null], + ['/export yaml', null], + ['/export md extra', null], + ['/help md', null], + ['/export md', 'md'], + [' /export jsonl ', 'jsonl'], + ])('parses %j as %j', (input, expected) => { + expect(getExportFormatFromInput(input, EXPORT_FORMATS)).toBe(expected); + }); + + it('returns null when there are no valid formats', () => { + expect(getExportFormatFromInput('/export md', [])).toBeNull(); + }); +}); + +describe('getNextExportCompletionIndex', () => { + it('returns the current index for an empty format list', () => { + expect(getNextExportCompletionIndex([], 3, 'down')).toBe(3); + }); + + it('wraps downward at the end of the list', () => { + expect(getNextExportCompletionIndex(EXPORT_FORMATS, 3, 'down')).toBe(0); + }); + + it('wraps upward at the start of the list', () => { + expect(getNextExportCompletionIndex(EXPORT_FORMATS, 0, 'up')).toBe(3); + }); + + it('moves from out-of-range indexes back into the cycle', () => { + expect(getNextExportCompletionIndex(EXPORT_FORMATS, -1, 'down')).toBe(0); + expect(getNextExportCompletionIndex(EXPORT_FORMATS, 99, 'down')).toBe(0); + expect(getNextExportCompletionIndex(EXPORT_FORMATS, -1, 'up')).toBe(3); + }); + + it('keeps a single-item list on the only index', () => { + expect(getNextExportCompletionIndex(['html'], 0, 'down')).toBe(0); + expect(getNextExportCompletionIndex(['html'], 0, 'up')).toBe(0); + }); +}); + +describe('useExportCompletion', () => { + it('returns null display props outside export cycling', () => { + const buffer = createTextBuffer('/export'); + const { result } = renderHook(() => + useExportCompletion(buffer, exportSlashCommands), + ); + + expect(result.current.shouldShowSuggestions).toBe(false); + expect(result.current.suggestionDisplayProps).toBeNull(); + }); + + it('does not seed cycling state from buffer text alone', () => { + const buffer = createTextBuffer('/export md'); + const completion = createCompletion(); + const { result } = renderHook(() => + useExportCompletion(buffer, exportSlashCommands), + ); + + let consumed = true; + act(() => { + consumed = result.current.handleExportInput( + createKey('down'), + completion, + ); + }); + + expect(consumed).toBe(false); + expect(buffer.setText).not.toHaveBeenCalled(); + }); + + it('seeds cycling state after a marked user text edit', () => { + const buffer = createTextBuffer(''); + const completion = createCompletion(); + const { result, rerender } = renderHook( + ({ textBuffer }) => useExportCompletion(textBuffer, exportSlashCommands), + { initialProps: { textBuffer: buffer } }, + ); + + act(() => { + result.current.markNextTextChangeAsUserInput(); + buffer.setText('/export md'); + }); + rerender({ textBuffer: buffer }); + vi.mocked(buffer.setText).mockClear(); + + let consumed = false; + act(() => { + consumed = result.current.handleExportInput( + createKey('down'), + completion, + ); + }); + + expect(consumed).toBe(true); + expect(buffer.setText).toHaveBeenLastCalledWith('/export json'); + }); + + it('clears refs and cycling state on reset', () => { + const buffer = createTextBuffer('/export md'); + const completion = createCompletion(); + const { result } = renderHook(() => + useExportCompletion(buffer, exportSlashCommands), + ); + + act(() => { + result.current.navigatedRef.current = true; + result.current.navigatedTextRef.current = '/memory'; + result.current.reset(); + }); + + expect(result.current.navigatedRef.current).toBe(false); + expect(result.current.navigatedTextRef.current).toBe(''); + + let consumed = true; + act(() => { + consumed = result.current.handleExportInput( + createKey('down'), + completion, + ); + }); + + expect(consumed).toBe(false); + expect(buffer.setText).not.toHaveBeenCalled(); + }); + + it('shows export suggestions after phase-1 cycling updates the buffer', () => { + const buffer = createTextBuffer('/export'); + const completion = createCompletion({ + showSuggestions: true, + isPerfectMatch: true, + }); + const { result, rerender } = renderHook( + ({ textBuffer }) => useExportCompletion(textBuffer, exportSlashCommands), + { initialProps: { textBuffer: buffer } }, + ); + + act(() => { + result.current.handleExportInput(createKey('down'), completion); + }); + rerender({ textBuffer: buffer }); + + expect(result.current.shouldShowSuggestions).toBe(true); + expect(result.current.suggestionDisplayProps).toMatchObject({ + activeIndex: 1, + isLoading: false, + scrollOffset: 0, + }); + expect( + result.current.suggestionDisplayProps?.suggestions.map((s) => s.value), + ).toEqual(['html', 'md', 'json', 'jsonl']); + }); + + it('keeps the returned object stable when dependencies do not change', () => { + const buffer = createTextBuffer(''); + const { result, rerender } = renderHook( + ({ textBuffer }) => useExportCompletion(textBuffer, exportSlashCommands), + { initialProps: { textBuffer: buffer } }, + ); + const firstResult = result.current; + + rerender({ textBuffer: buffer }); + + expect(result.current).toBe(firstResult); + }); +}); diff --git a/packages/cli/src/ui/hooks/useExportCompletion.ts b/packages/cli/src/ui/hooks/useExportCompletion.ts new file mode 100644 index 000000000..f69b1e9cf --- /dev/null +++ b/packages/cli/src/ui/hooks/useExportCompletion.ts @@ -0,0 +1,342 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useCallback, useEffect, useMemo, useRef } from 'react'; +import type { TextBuffer } from '../components/shared/text-buffer.js'; +import type { Suggestion } from '../components/SuggestionsDisplay.js'; +import type { SlashCommand } from '../commands/types.js'; +import type { Key } from './useKeypress.js'; +import { keyMatchers, Command } from '../keyMatchers.js'; +import type { UseCommandCompletionReturn } from './useCommandCompletion.js'; + +const EXPORT_COMMAND_INPUT = '/export'; + +/** + * Parse a single export format from an input buffer. + * + * The valid-format list is passed in so that adding a new "/export " + * sub-command to slashCommands automatically enables Phase-2 cycling for it, + * without requiring a synchronous hard-coded regex update. + * + * Uses a simple slice-based approach (no regex) for two reasons: + * 1. No escaping concerns when format names contain regex metacharacters. + * 2. O(1) cost after the cheap startsWith prefix guard. + */ +export const getExportFormatFromInput = ( + input: string, + validFormats: readonly string[], +): string | null => { + const trimmed = input.trim(); + if (!trimmed.startsWith(EXPORT_COMMAND_INPUT + ' ')) { + return null; + } + const rest = trimmed.slice(EXPORT_COMMAND_INPUT.length + 1); + if (!rest || rest.includes(' ')) { + return null; + } + return validFormats.includes(rest) ? rest : null; +}; + +/** + * Compute the next index for export format cycling (round-robin). + * Extracted as a module-level pure function to avoid per-keystroke + * closure recreation inside handleInput. + */ +export const getNextExportCompletionIndex = ( + formatList: readonly string[], + currentIndex: number, + direction: 'up' | 'down', +) => { + const total = formatList.length; + if (total === 0) { + return currentIndex; + } + const lastIndex = total - 1; + if (direction === 'up') { + return currentIndex <= 0 ? lastIndex : currentIndex - 1; + } + return currentIndex >= lastIndex ? 0 : currentIndex + 1; +}; + +export interface ExportCompletionResult { + /** Whether the suggestions panel should be visible (export-specific). */ + shouldShowSuggestions: boolean; + /** + * Display props for the SuggestionsDisplay component when export + * suggestions are active, or null if the caller should fall back + * to the generic completion state. + */ + suggestionDisplayProps: { + suggestions: Suggestion[]; + activeIndex: number; + isLoading: boolean; + scrollOffset: number; + } | null; + /** + * Handle a keypress for export-specific completion logic. + * Returns true if the key was consumed, false if the caller should + * fall through to generic completion handling. + */ + handleExportInput: ( + key: Key, + completion: UseCommandCompletionReturn, + ) => boolean; + /** Reset all export cycling state (call on ESC / Ctrl+C / Ctrl+U / submit). */ + reset: () => void; + /** + * Allow the next buffer text change to seed export cycling if it becomes + * exactly "/export ". Call this only for direct user text edits. + */ + markNextTextChangeAsUserInput: () => void; + /** + * Shared "has navigated" flag. The generic completion path sets this + * to true on arrow navigation and the isPerfectMatch + Enter path reads + * it. Owned by this hook so both the export-specific and generic paths + * share a single source of truth. + */ + navigatedRef: React.MutableRefObject; + /** + * Buffer text snapshot captured when navigatedRef was last set to true. + * Used by the caller to detect stale navigation state when the buffer + * has been modified externally (e.g. via setText in tests). + */ + navigatedTextRef: React.MutableRefObject; +} + +export function useExportCompletion( + buffer: TextBuffer, + slashCommands: readonly SlashCommand[], +): ExportCompletionResult { + const navigatedRef = useRef(false); + const navigatedTextRef = useRef(''); + const cyclingActiveRef = useRef(false); + const nextTextChangeWasUserInputRef = useRef(false); + + // Derive the canonical export format list from slashCommands so adding a + // new "/export " sub-command automatically enables arrow/Tab cycling. + const exportFormatSuggestions = useMemo(() => { + const exportCommand = slashCommands.find( + (command) => command.name === EXPORT_COMMAND_INPUT.slice(1), + ); + const subCommands = exportCommand?.subCommands; + if (subCommands && subCommands.length > 0) { + return subCommands.map((command) => ({ + label: command.name, + value: command.name, + description: command.description, + commandKind: command.kind, + })); + } + return []; + }, [slashCommands]); + + // Cache the export format names (keys only) so the cycle logic inside + // handleInput does not call .map() on every keystroke. + const exportCycleFormats = useMemo( + () => exportFormatSuggestions.map((s) => s.value), + [exportFormatSuggestions], + ); + + const markNextTextChangeAsUserInput = useCallback(() => { + nextTextChangeWasUserInputRef.current = true; + }, []); + + // Seed cyclingActiveRef only for text changes that InputPrompt marked as + // direct user edits. History navigation and programmatic setText() calls can + // also produce "/export ", but they must not steal the next Up/Down key + // from the normal history/navigation handlers. + useEffect(() => { + const fmt = getExportFormatFromInput(buffer.text, exportCycleFormats); + if ( + nextTextChangeWasUserInputRef.current && + fmt !== null && + !cyclingActiveRef.current + ) { + cyclingActiveRef.current = true; + } + nextTextChangeWasUserInputRef.current = false; + }, [buffer.text, exportCycleFormats]); + + // Reset navigated flag on every popup visibility transition (true↔false) + // and on every buffer text change, to prevent flag stickiness when the + // user navigates, then backspaces and retypes the command. + useEffect(() => { + navigatedRef.current = false; + navigatedTextRef.current = ''; + }, [buffer.text]); + + const reset = useCallback(() => { + cyclingActiveRef.current = false; + nextTextChangeWasUserInputRef.current = false; + navigatedRef.current = false; + navigatedTextRef.current = ''; + }, []); + + const getExportIndexForActiveSuggestion = useCallback( + (completion: UseCommandCompletionReturn): number => { + const idx = completion.activeSuggestionIndex; + if (idx < 0 || idx >= completion.suggestions.length) { + return -1; + } + return exportCycleFormats.indexOf(completion.suggestions[idx].value); + }, + [exportCycleFormats], + ); + + const setExportCompletionInput = useCallback( + (index: number): boolean => { + const format = exportCycleFormats[index]; + if (!format) return false; + buffer.setText(`${EXPORT_COMMAND_INPUT} ${format}`); + cyclingActiveRef.current = true; + navigatedRef.current = false; + return true; + }, + [buffer, exportCycleFormats], + ); + + const handleExportInput = useCallback( + (key: Key, completion: UseCommandCompletionReturn): boolean => { + const isCompletionUpKey = keyMatchers[Command.COMPLETION_UP](key); + const isCompletionDownKey = keyMatchers[Command.COMPLETION_DOWN](key); + const isCompletionTabKey = + key.name === 'tab' && + !key.shift && + !key.ctrl && + !key.meta && + !key.paste; + + // ---- Phase 1 detection (popup is showing pure "/export") ---- + const hasExportFormatSuggestions = + buffer.text.trim() === EXPORT_COMMAND_INPUT && + completion.suggestions.length > 0 && + exportCycleFormats.length > 0 && + exportCycleFormats.every((format) => + completion.suggestions.some((s) => s.value === format), + ); + + // ---- Phase 2 guard ---- + const parsedFormat = getExportFormatFromInput( + buffer.text, + exportCycleFormats, + ); + + // Phase-2 cycling: buffer is "/export " and cycling is active. + if ( + cyclingActiveRef.current && + parsedFormat !== null && + !key.ctrl && + !key.meta && + !key.paste && + (isCompletionUpKey || isCompletionDownKey || isCompletionTabKey) + ) { + const direction = isCompletionUpKey ? 'up' : 'down'; + const currentIndex = exportCycleFormats.indexOf(parsedFormat); + const nextIndex = getNextExportCompletionIndex( + exportCycleFormats, + currentIndex, + direction, + ); + setExportCompletionInput(nextIndex); + return true; + } + + if (!completion.showSuggestions) { + return false; + } + + // ---- Phase 1: popup is visible ---- + if (completion.suggestions.length > 1) { + if (isCompletionUpKey || isCompletionDownKey) { + if (hasExportFormatSuggestions) { + const currentIdx = getExportIndexForActiveSuggestion(completion); + if (currentIdx !== -1) { + const nextIdx = getNextExportCompletionIndex( + exportCycleFormats, + currentIdx, + isCompletionUpKey ? 'up' : 'down', + ); + setExportCompletionInput(nextIdx); + return true; + } + } + } + } + + if (keyMatchers[Command.ACCEPT_SUGGESTION](key) && !key.paste) { + if ( + hasExportFormatSuggestions && + !(completion.isPerfectMatch && keyMatchers[Command.RETURN](key)) + ) { + const exportIdx = getExportIndexForActiveSuggestion(completion); + if (exportIdx !== -1) { + setExportCompletionInput(exportIdx); + return true; + } + } + } + + return false; + }, + [ + buffer, + exportCycleFormats, + getExportIndexForActiveSuggestion, + setExportCompletionInput, + ], + ); + + // ---- Render-time derivations ---- + const selectedExportFormat = getExportFormatFromInput( + buffer.text, + exportCycleFormats, + ); + const selectedExportFormatIndex = + selectedExportFormat === null + ? -1 + : exportFormatSuggestions.findIndex( + (s) => s.value === selectedExportFormat, + ); + + const shouldShowSuggestions = + !cyclingActiveRef.current || selectedExportFormatIndex === -1 + ? false + : true; + + const suggestionDisplayProps = useMemo< + ExportCompletionResult['suggestionDisplayProps'] + >( + () => + shouldShowSuggestions + ? { + suggestions: exportFormatSuggestions, + activeIndex: selectedExportFormatIndex, + isLoading: false, + scrollOffset: 0, + } + : null, + [exportFormatSuggestions, selectedExportFormatIndex, shouldShowSuggestions], + ); + + return useMemo( + () => ({ + shouldShowSuggestions, + suggestionDisplayProps, + handleExportInput, + reset, + markNextTextChangeAsUserInput, + navigatedRef, + navigatedTextRef, + }), + [ + shouldShowSuggestions, + suggestionDisplayProps, + handleExportInput, + reset, + markNextTextChangeAsUserInput, + ], + ); +} diff --git a/packages/cli/src/ui/utils/commandUtils.test.ts b/packages/cli/src/ui/utils/commandUtils.test.ts index 4273bc1ea..37c6b3322 100644 --- a/packages/cli/src/ui/utils/commandUtils.test.ts +++ b/packages/cli/src/ui/utils/commandUtils.test.ts @@ -118,6 +118,15 @@ describe('commandUtils', () => { expect(isSlashCommand('/*\n * Multi-line comment\n */')).toBe(false); expect(isSlashCommand('/*comment without space*/')).toBe(false); }); + + it('should return false for slash-prefixed file paths', () => { + expect(isSlashCommand('/api/apiFunction/接口的实现')).toBe(false); + expect(isSlashCommand('/Users/me/project/src/index.ts')).toBe(false); + expect(isSlashCommand('/var/log/syslog check this')).toBe(false); + expect(isSlashCommand('/home/user/.qwen/settings.json')).toBe(false); + expect(isSlashCommand('/tmp/test.txt')).toBe(false); + expect(isSlashCommand('/tmp\\test.txt')).toBe(false); + }); }); describe('copyToClipboard', () => { diff --git a/packages/cli/src/ui/utils/commandUtils.ts b/packages/cli/src/ui/utils/commandUtils.ts index 18f74015c..ed023b971 100644 --- a/packages/cli/src/ui/utils/commandUtils.ts +++ b/packages/cli/src/ui/utils/commandUtils.ts @@ -38,9 +38,18 @@ export const isAtCommand = (query: string): boolean => // Check if starts with @ OR has a space, then @ query.startsWith('@') || /\s@/.test(query); +const SLASH_PATH_SEPARATOR_RE = /[/\\]/; + +const getSlashCommandFirstToken = (query: string): string => + query.slice(1).trimStart().split(/\s+/)[0] ?? ''; + +export const hasSlashCommandPathSeparator = (query: string): boolean => + SLASH_PATH_SEPARATOR_RE.test(getSlashCommandFirstToken(query)); + /** * Checks if a query string potentially represents an '/' command. - * It triggers if the query starts with '/' but excludes code comments like '//' and '/*'. + * It triggers if the query starts with '/' but excludes code comments like '//' + * and '/*', and file paths where the first token contains a path separator. * * @param query The input query string. * @returns True if the query looks like an '/' command, false otherwise. @@ -60,6 +69,10 @@ export const isSlashCommand = (query: string): boolean => { return false; } + if (hasSlashCommandPathSeparator(query)) { + return false; + } + return true; }; diff --git a/packages/cli/src/ui/utils/historyMapping.test.ts b/packages/cli/src/ui/utils/historyMapping.test.ts index 84c18a2ff..8f6426a6d 100644 --- a/packages/cli/src/ui/utils/historyMapping.test.ts +++ b/packages/cli/src/ui/utils/historyMapping.test.ts @@ -208,6 +208,27 @@ describe('computeApiTruncationIndex', () => { // Slash '/help' (id=3) should not be counted expect(computeApiTruncationIndex(ui, 5, api)).toBe(2); }); + + it('counts path-like slash prompts that were sent to the model', () => { + const ui: HistoryItem[] = [ + userItem(1, 'hello'), + geminiItem(2), + userItem(3, '/api/apiFunction/接口的实现'), + geminiItem(4), + userItem(5, 'world'), + geminiItem(6), + ]; + const api: Content[] = [ + userContent('hello'), + modelContent('response 1'), + userContent('/api/apiFunction/接口的实现'), + modelContent('response 2'), + userContent('world'), + modelContent('response 3'), + ]; + + expect(computeApiTruncationIndex(ui, 5, api)).toBe(4); + }); }); describe('single turn', () => { @@ -233,6 +254,15 @@ describe('isRealUserTurn', () => { expect(isRealUserTurn(userItem(1, '/stats'))).toBe(false); }); + it('returns true for path-like slash prompts', () => { + expect(isRealUserTurn(userItem(1, '/api/apiFunction/接口的实现'))).toBe( + true, + ); + expect(isRealUserTurn(userItem(1, '/Users/name/project 帮我安装'))).toBe( + true, + ); + }); + it('returns false for ? commands', () => { expect(isRealUserTurn(userItem(1, '?help'))).toBe(false); }); diff --git a/packages/cli/src/ui/utils/historyMapping.ts b/packages/cli/src/ui/utils/historyMapping.ts index 389acf2e9..6b5bb9d41 100644 --- a/packages/cli/src/ui/utils/historyMapping.ts +++ b/packages/cli/src/ui/utils/historyMapping.ts @@ -6,6 +6,7 @@ import type { HistoryItem } from '../types.js'; import type { Content } from '@google/genai'; +import { isSlashCommand } from './commandUtils.js'; /** * Returns true when the history item represents a real user prompt that was @@ -15,7 +16,7 @@ import type { Content } from '@google/genai'; */ export function isRealUserTurn(item: HistoryItem): boolean { if (item.type !== 'user' || !item.text) return false; - return !item.text.startsWith('/') && !item.text.startsWith('?'); + return !isSlashCommand(item.text) && !item.text.startsWith('?'); } /** diff --git a/packages/core/src/agents/background-tasks.test.ts b/packages/core/src/agents/background-tasks.test.ts index 44f92025c..3c781ec28 100644 --- a/packages/core/src/agents/background-tasks.test.ts +++ b/packages/core/src/agents/background-tasks.test.ts @@ -5,7 +5,10 @@ */ import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { BackgroundTaskRegistry } from './background-tasks.js'; +import { + BackgroundTaskRegistry, + type BackgroundTaskEntry, +} from './background-tasks.js'; import * as transcript from './agent-transcript.js'; describe('BackgroundTaskRegistry', () => { @@ -909,4 +912,216 @@ describe('BackgroundTaskRegistry', () => { expect(modelText).not.toContain(''); }); }); + + describe('foreground flavor', () => { + it('does not emit a task-notification on complete', () => { + const callback = vi.fn(); + registry.setNotificationCallback(callback); + + registry.register({ + agentId: 'fg-1', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + registry.complete('fg-1', 'result text'); + + // Foreground entries deliver their result through the parent's normal + // tool-result channel; emitting the XML envelope on top would feed + // the parent model the same payload twice. + expect(callback).not.toHaveBeenCalled(); + // The status mutation still happens — internal invariants intact. + expect(registry.get('fg-1')!.status).toBe('completed'); + expect(registry.get('fg-1')!.notified).toBe(true); + }); + + it('does not emit a task-notification on fail', () => { + const callback = vi.fn(); + registry.setNotificationCallback(callback); + + registry.register({ + agentId: 'fg-2', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + registry.fail('fg-2', 'oops'); + + expect(callback).not.toHaveBeenCalled(); + }); + + it('is excluded from hasUnfinalizedTasks()', () => { + registry.register({ + agentId: 'fg-3', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + // A still-running foreground entry must NOT keep the headless + // event loop alive — the parent's tool-call await already does that. + expect(registry.hasUnfinalizedTasks()).toBe(false); + }); + + it('cancel does not schedule the grace timer', () => { + // The grace-timer fallback only matters for background entries that + // might not see their natural completion handler fire. Foreground + // entries unregister themselves in agent.ts's finally path. + vi.useFakeTimers(); + try { + const callback = vi.fn(); + registry.setNotificationCallback(callback); + + registry.register({ + agentId: 'fg-4', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + registry.cancel('fg-4'); + + // Advance well past the 5s grace window — no notification should fire. + vi.advanceTimersByTime(60_000); + expect(callback).not.toHaveBeenCalled(); + } finally { + vi.useRealTimers(); + } + }); + + it('unregisterForeground removes the entry and emits a status change', () => { + const onStatusChange = vi.fn(); + registry.setStatusChangeCallback(onStatusChange); + + registry.register({ + agentId: 'fg-5', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + onStatusChange.mockClear(); + + registry.unregisterForeground('fg-5'); + + expect(registry.get('fg-5')).toBeUndefined(); + expect(onStatusChange).toHaveBeenCalledTimes(1); + }); + + it('unregisterForeground throws if asked to remove a background entry', () => { + // Background entries must terminate via complete/fail/finalizeCancelled + // so the task-notification + headless holdback invariants stay intact. + // A silent no-op would mask caller bugs, so this throws. + registry.register({ + agentId: 'bg-1', + description: 'async agent', + flavor: 'background', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + expect(() => registry.unregisterForeground('bg-1')).toThrow( + /non-foreground entry bg-1/, + ); + expect(registry.get('bg-1')).toBeDefined(); + }); + + it('unregisterForeground is a no-op for unknown agent ids', () => { + // Idempotent for already-unregistered/never-registered ids — the + // foreground finally path runs unconditionally and shouldn't throw + // if a parallel cancel already cleared the entry. + expect(() => registry.unregisterForeground('missing')).not.toThrow(); + }); + + it('does not invoke the register callback for foreground entries', () => { + // Non-interactive bridges setRegisterCallback to a `task_started` + // SDK event. Foreground entries never produce a paired terminal + // task-notification (see emitNotification's flavor gate), so letting + // them fire `task_started` would leak orphaned in-flight tasks to + // SDK consumers. + const onRegister = vi.fn(); + registry.setRegisterCallback(onRegister); + + registry.register({ + agentId: 'fg-no-register-cb', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + expect(onRegister).not.toHaveBeenCalled(); + + // Background entries still fire it. + registry.register({ + agentId: 'bg-fires-register-cb', + description: 'async agent', + flavor: 'background', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + expect(onRegister).toHaveBeenCalledTimes(1); + expect(onRegister.mock.calls[0]![0].agentId).toBe('bg-fires-register-cb'); + }); + + it('unregisterForeground emits status change before removing the entry', () => { + // Mirrors the ordering used by complete/fail/cancel/finalize so a + // statusChange callback that re-reads `registry.get(agentId)` from + // inside the callback sees the entry across every terminal path. + registry.register({ + agentId: 'fg-unregister-order', + description: 'sync agent', + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + let observedFromCallback: BackgroundTaskEntry | undefined; + registry.setStatusChangeCallback((entry) => { + if (entry?.agentId === 'fg-unregister-order') { + observedFromCallback = registry.get(entry.agentId); + } + }); + + registry.unregisterForeground('fg-unregister-order'); + + expect(observedFromCallback).toBeDefined(); + expect(observedFromCallback!.agentId).toBe('fg-unregister-order'); + expect(registry.get('fg-unregister-order')).toBeUndefined(); + }); + + it('default flavor (absent) behaves as background for emitNotification', () => { + // Older callers omit the flavor field. Backwards compatibility: + // missing flavor is treated as background everywhere. + const callback = vi.fn(); + registry.setNotificationCallback(callback); + + registry.register({ + agentId: 'legacy-1', + description: 'legacy agent', + status: 'running', + startTime: Date.now(), + abortController: new AbortController(), + }); + + registry.complete('legacy-1', 'done'); + + expect(callback).toHaveBeenCalledOnce(); + }); + }); }); diff --git a/packages/core/src/agents/background-tasks.ts b/packages/core/src/agents/background-tasks.ts index 66c7df463..8ef752f70 100644 --- a/packages/core/src/agents/background-tasks.ts +++ b/packages/core/src/agents/background-tasks.ts @@ -5,11 +5,19 @@ */ /** - * @fileoverview BackgroundTaskRegistry — tracks background (async) sub-agents. + * @fileoverview BackgroundTaskRegistry — tracks background (async) sub-agents + * and, with `flavor: 'foreground'`, the currently-running synchronous + * sub-agents whose UI is routed through the same pill+dialog while the + * parent turn waits on them. The two flavors share the registry (and the + * dialog wiring) but differ in lifecycle: * - * When the Agent tool is called with `run_in_background: true`, the sub-agent - * runs asynchronously. This registry tracks the lifecycle of each background - * agent so the parent can be notified on completion. + * - `background` entries persist across turns, emit a `` + * on terminal status (the parent's only return channel), and contribute to + * `hasUnfinalizedTasks()` so headless callers keep their loop alive. + * - `foreground` entries live for the duration of the parent's tool-call, + * are unregistered as soon as `execute()` returns, deliver their result + * through the normal tool-result channel (no XML envelope), and don't + * participate in the headless holdback. */ import { createDebugLogger } from '../utils/debugLogger.js'; @@ -94,10 +102,19 @@ export interface BackgroundActivity { at: number; } +export type BackgroundTaskFlavor = 'foreground' | 'background'; + export interface BackgroundTaskEntry { agentId: string; description: string; subagentType?: string; + /** + * `'background'` — async, persists across turns, emits XML notification. + * `'foreground'` — synchronous, unregistered when the tool-call returns, + * delivers results via the normal tool-result channel. + * Defaults to `'background'` when absent (older callers). + */ + flavor?: BackgroundTaskFlavor; status: BackgroundTaskStatus; startTime: number; endTime?: number; @@ -197,7 +214,12 @@ export class BackgroundTaskRegistry { this.agents.set(entry.agentId, entry); debugLogger.info(`Registered background agent: ${entry.agentId}`); - if (this.registerCallback) { + // Foreground entries are paired with a synchronous tool-call result on + // the parent's response and never emit a terminal `task_notification` + // (see emitNotification's flavor gate). Letting them fire the register + // callback would emit a `task_started` SDK event without a matching + // completion event, breaking the lifecycle contract for SDK consumers. + if (entry.flavor !== 'foreground' && this.registerCallback) { try { this.registerCallback(entry); } catch (error) { @@ -235,6 +257,32 @@ export class BackgroundTaskRegistry { this.emitStatusChange(entry); } + /** + * Remove a foreground entry from the registry without emitting any + * terminal notification. Called by the foreground tool-call's `finally` + * path, which has already delivered the result through the tool-result + * channel — the registry entry has served its UI-surfacing purpose. + * Background entries must go through complete/fail/finalizeCancelled + * instead, so this throws if asked to remove one. + */ + unregisterForeground(agentId: string): void { + const entry = this.agents.get(agentId); + if (!entry) return; + if (entry.flavor !== 'foreground') { + throw new Error( + `unregisterForeground called on non-foreground entry ${agentId} ` + + `(flavor=${entry.flavor ?? 'undefined'}). ` + + `Background entries must terminate via complete/fail/finalizeCancelled.`, + ); + } + // Emit before delete so any future BackgroundStatusChangeCallback that + // re-reads `registry.get(agentId)` from inside the callback sees the + // entry, matching the ordering used by complete/fail/cancel/finalize. + this.emitStatusChange(entry); + this.agents.delete(agentId); + debugLogger.info(`Unregistered foreground agent: ${agentId}`); + } + // See complete() for the cancelled → terminal path rationale. fail(agentId: string, error: string, stats?: AgentCompletionStats): void { const entry = this.agents.get(agentId); @@ -279,6 +327,11 @@ export class BackgroundTaskRegistry { debugLogger.info(`Background agent cancelled: ${agentId}`); this.emitStatusChange(entry); + // Foreground entries don't emit XML notifications and unregister + // themselves in the tool-call's finally path, so the grace timer + // would only ever no-op for them. + if (entry.flavor === 'foreground') return; + if (options.notify === false) { // Session reset paths intentionally suppress the old task's terminal // notification so it cannot leak into a new conversation. @@ -386,6 +439,11 @@ export class BackgroundTaskRegistry { */ hasUnfinalizedTasks(): boolean { for (const entry of this.agents.values()) { + // Foreground entries block the parent tool-call synchronously, so the + // headless event loop is already pinned by the `await` on the caller's + // promise — counting them here would be redundant and would also keep + // the loop alive for entries that don't even emit a notification. + if (entry.flavor === 'foreground') continue; if (entry.status === 'running') return true; if (entry.status === 'cancelled' && !entry.notified) return true; } @@ -493,6 +551,12 @@ export class BackgroundTaskRegistry { if (entry.notified) return; entry.notified = true; + // Foreground entries return their result through the parent's normal + // tool-result channel (the `returnDisplay` field on the synchronous + // tool-call). Emitting the XML envelope on top would feed the parent + // model the same payload twice. + if (entry.flavor === 'foreground') return; + if (!this.notificationCallback) return; const statusText = diff --git a/packages/core/src/agents/runtime/agent-core.ts b/packages/core/src/agents/runtime/agent-core.ts index 8f8202965..bb959f4d0 100644 --- a/packages/core/src/agents/runtime/agent-core.ts +++ b/packages/core/src/agents/runtime/agent-core.ts @@ -317,11 +317,17 @@ export class AgentCore { } try { - return new GeminiChat( + const chat = new GeminiChat( this.runtimeContext, generationConfig, startHistory, ); + // Seed the per-chat token count so the auto-compaction threshold + // gate sees the inherited history's true size on the first send. + // Without this, fork subagents start at 0 and the gate NOOPs even + // when `startHistory` is already huge — first API call can 400. + chat.setLastPromptTokenCount(this.lastPromptTokenCount); + return chat; } catch (error) { await reportError( error, @@ -540,6 +546,18 @@ export class AgentCore { continue; } + // GeminiChat already mutated its own history; surface to the debug + // log so subagent compactions show up alongside the main session's. + if (streamEvent.type === 'compressed') { + this.runtimeContext + .getDebugLogger() + .debug( + `[AGENT-COMPACT] subagent=${this.subagentId} round=${turnCounter} ` + + `tokens ${streamEvent.info.originalTokenCount} -> ${streamEvent.info.newTokenCount}`, + ); + continue; + } + // Handle chunk events if (streamEvent.type === 'chunk') { const resp = streamEvent.value; diff --git a/packages/core/src/agents/runtime/agent-headless.test.ts b/packages/core/src/agents/runtime/agent-headless.test.ts index 9f3f329a2..d54a7af18 100644 --- a/packages/core/src/agents/runtime/agent-headless.test.ts +++ b/packages/core/src/agents/runtime/agent-headless.test.ts @@ -280,6 +280,7 @@ describe('subagent.ts', () => { () => ({ sendMessageStream: mockSendMessageStream, + setLastPromptTokenCount: vi.fn(), }) as unknown as GeminiChat, ); @@ -958,6 +959,7 @@ describe('subagent.ts', () => { () => ({ sendMessageStream: mockSendMessageStream, + setLastPromptTokenCount: vi.fn(), }) as unknown as GeminiChat, ); @@ -997,6 +999,7 @@ describe('subagent.ts', () => { () => ({ sendMessageStream: mockSendMessageStream, + setLastPromptTokenCount: vi.fn(), }) as unknown as GeminiChat, ); @@ -1061,6 +1064,7 @@ describe('subagent.ts', () => { () => ({ sendMessageStream: mockSendMessageStream, + setLastPromptTokenCount: vi.fn(), }) as unknown as GeminiChat, ); diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 5d9207040..b4bc187cc 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -19,18 +19,22 @@ import { GeminiClient, SendMessageType } from './client.js'; import { findCompressSplitPoint } from '../services/chatCompressionService.js'; import { AuthType, + createContentGenerator, type ContentGenerator, type ContentGeneratorConfig, } from './contentGenerator.js'; +import { buildAgentContentGeneratorConfig } from '../models/content-generator-config.js'; import { type GeminiChat } from './geminiChat.js'; import type { Config } from '../config/config.js'; import { ApprovalMode } from '../config/config.js'; -import { - CompressionStatus, - GeminiEventType, - Turn, - type ChatCompressionInfo, -} from './turn.js'; +import type { ModelsConfig } from '../models/modelsConfig.js'; +import { retryWithBackoff } from '../utils/retry.js'; +import { CompressionStatus, GeminiEventType, Turn } from './turn.js'; + +vi.mock('../utils/retry.js', () => ({ + retryWithBackoff: vi.fn(async (fn) => await fn()), + isUnattendedMode: vi.fn(() => false), +})); import { getCoreSystemPrompt, getCustomSystemPrompt } from './prompts.js'; import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; @@ -90,6 +94,25 @@ vi.mock('./turn', async (importOriginal) => { vi.mock('../config/config.js'); vi.mock('./prompts'); +vi.mock('../models/content-generator-config.js', async (importOriginal) => { + const actual = + await importOriginal< + typeof import('../models/content-generator-config.js') + >(); + return { + ...actual, + buildAgentContentGeneratorConfig: vi + .fn() + .mockImplementation(actual.buildAgentContentGeneratorConfig), + }; +}); +vi.mock('./contentGenerator.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + createContentGenerator: vi.fn(), + }; +}); vi.mock('../utils/getFolderStructure', () => ({ getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'), })); @@ -133,6 +156,8 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({ const mockUiTelemetryService = vi.hoisted(() => ({ setLastPromptTokenCount: vi.fn(), getLastPromptTokenCount: vi.fn(), + reset: vi.fn(), + addEvent: vi.fn(), })); vi.mock('../telemetry/index.js', async (importOriginal) => { @@ -151,6 +176,7 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({ vi.mock('../telemetry/loggers.js', () => ({ logChatCompression: vi.fn(), logNextSpeakerCheck: vi.fn(), + logApiRequest: vi.fn(), })); // Mock RequestTokenizer to use simple character-based estimation @@ -235,15 +261,16 @@ describe('findCompressSplitPoint', () => { expect(findCompressSplitPoint(history, 0.8)).toBe(4); }); - it('should return earlier splitpoint if no valid ones are after threshhold', () => { + it('compresses everything before the trailing in-flight functionCall', () => { const history: Content[] = [ { role: 'user', parts: [{ text: 'This is the first message.' }] }, { role: 'model', parts: [{ text: 'This is the second message.' }] }, { role: 'user', parts: [{ text: 'This is the third message.' }] }, { role: 'model', parts: [{ functionCall: {} }] }, ]; - // Can't return 4 because the previous item has a function call. - expect(findCompressSplitPoint(history, 0.99)).toBe(2); + // Trailing m+fc is in-flight; the in-flight fallback compresses + // everything except the trailing fc (no preceding pair to retain). + expect(findCompressSplitPoint(history, 0.99)).toBe(3); }); it('should handle a history with only one item', () => { @@ -277,6 +304,12 @@ describe('Gemini Client (client.ts)', () => { vi.resetAllMocks(); vi.mocked(uiTelemetryService.setLastPromptTokenCount).mockClear(); + // Default: createContentGenerator rejects (simulates test env without auth). + // Individual tests can override with mockResolvedValue for success path. + vi.mocked(createContentGenerator).mockRejectedValue( + new Error('no auth in test env'), + ); + mockMemoryManager = { scheduleExtract: vi.fn().mockResolvedValue({ touchedTopics: [], @@ -389,6 +422,9 @@ describe('Gemini Client (client.ts)', () => { getArenaAgentClient: vi.fn().mockReturnValue(null), getManagedAutoMemoryEnabled: vi.fn().mockReturnValue(true), getMemoryManager: vi.fn().mockReturnValue(mockMemoryManager), + getModelsConfig: vi.fn().mockReturnValue({ + getResolvedModel: vi.fn().mockReturnValue(undefined), + }), getDisableAllHooks: vi.fn().mockReturnValue(true), getArenaManager: vi.fn().mockReturnValue(null), getMessageBus: vi.fn().mockReturnValue(undefined), @@ -408,12 +444,48 @@ describe('Gemini Client (client.ts)', () => { client = new GeminiClient(mockConfig); await client.initialize(); vi.mocked(mockConfig.getGeminiClient).mockReturnValue(client); + + // GeminiClient.sendMessageStream calls this.tryCompressChat (which now + // delegates to chat.tryCompress) before each turn. Most tests use a + // hand-rolled chat mock that doesn't implement tryCompress; default the + // wrapper to a NOOP so those tests don't crash. Tests that exercise + // compression directly (the delegation tests below, the + // emits-compression-event test) override this spy. + vi.spyOn(client, 'tryCompressChat').mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }); }); afterEach(() => { vi.restoreAllMocks(); }); + describe('initialize', () => { + it('seeds resumed chat with replayed prompt token count', async () => { + vi.mocked(mockConfig.getResumedSessionData).mockReturnValue({ + conversation: { + sessionId: 'resumed-session-id', + projectHash: 'project-hash', + startTime: new Date(0).toISOString(), + lastUpdated: new Date(0).toISOString(), + messages: [], + }, + filePath: '/test/session.jsonl', + lastCompletedUuid: null, + }); + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + 123_456, + ); + + const resumedClient = new GeminiClient(mockConfig); + await resumedClient.initialize(); + + expect(resumedClient.getChat().getLastPromptTokenCount()).toBe(123_456); + }); + }); + describe('addHistory', () => { it('should call chat.addHistory with the provided content', async () => { const mockChat = { @@ -604,6 +676,11 @@ describe('Gemini Client (client.ts)', () => { mockChat = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + tryCompress: vi.fn().mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }), }; client['chat'] = mockChat as GeminiChat; }); @@ -730,66 +807,73 @@ describe('Gemini Client (client.ts)', () => { }); }); - describe('tryCompressChat', () => { - const mockGetHistory = vi.fn(); - + // tryCompressChat is now a thin wrapper around GeminiChat.tryCompress. + // The compression logic itself is exercised in chatCompressionService.test.ts + // (token math, threshold checks, hook firing) and geminiChat.test.ts (history + // mutation, recording, hasFailedCompressionAttempt). The tests below cover + // only what the wrapper itself adds: argument forwarding and the IDE-context + // flag flip. + describe('tryCompressChat (delegation)', () => { beforeEach(() => { - client['chat'] = { - getHistory: mockGetHistory, - addHistory: vi.fn(), - setHistory: vi.fn(), - } as unknown as GeminiChat; + // The top-level beforeEach stubs tryCompressChat to NOOP for unrelated + // tests; restore the real implementation here so we can observe it. + vi.mocked(client.tryCompressChat).mockRestore(); }); - function setup({ - chatHistory = [ - { role: 'user', parts: [{ text: 'Long conversation' }] }, - { role: 'model', parts: [{ text: 'Long response' }] }, - ] as Content[], - originalTokenCount = 1000, - summaryText = 'This is a summary.', - // Token counts returned in usageMetadata to simulate what the API would return - // Default values ensure successful compression: - // newTokenCount = originalTokenCount - (compressionInputTokenCount - 1000) + compressionOutputTokenCount - // = 1000 - (1600 - 1000) + 50 = 1000 - 600 + 50 = 450 (< 1000, success) - compressionInputTokenCount = 1600, - compressionOutputTokenCount = 50, - } = {}) { - const mockOriginalChat: Partial = { - getHistory: vi.fn((_curated?: boolean) => chatHistory), - setHistory: vi.fn(), - }; - client['chat'] = mockOriginalChat as GeminiChat; + it('forwards prompt id, model, force, and signal to chat.tryCompress', async () => { + const tryCompress = vi.fn().mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }); + client['chat'] = { + tryCompress, + getHistory: vi.fn().mockReturnValue([]), + } as unknown as GeminiChat; + vi.mocked(mockConfig.getModel).mockReturnValue('the-model'); + const signal = new AbortController().signal; - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - originalTokenCount, - ); + await client.tryCompressChat('p1', true, signal); - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - role: 'model', - parts: [{ text: summaryText }], - }, - }, - ], - usageMetadata: { - promptTokenCount: compressionInputTokenCount, - candidatesTokenCount: compressionOutputTokenCount, - totalTokenCount: - compressionInputTokenCount + compressionOutputTokenCount, - }, - } as unknown as GenerateContentResponse); + expect(tryCompress).toHaveBeenCalledWith('p1', 'the-model', true, signal); + }); - // Calculate what the new history will be - const splitPoint = findCompressSplitPoint(chatHistory, 0.7); // 1 - 0.3 - const historyToKeep = chatHistory.slice(splitPoint); + it('flips forceFullIdeContext on a successful compression', async () => { + client['chat'] = { + tryCompress: vi.fn().mockResolvedValue({ + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, + }), + getHistory: vi.fn().mockReturnValue([]), + } as unknown as GeminiChat; + client['forceFullIdeContext'] = false; - // This is the history that the new chat will have. - // It includes the default startChat history + the extra history from tryCompressChat - const newCompressedHistory: Content[] = [ - // Mocked envParts + canned response from startChat + await client.tryCompressChat('p2'); + + expect(client['forceFullIdeContext']).toBe(true); + }); + + it('re-prepends startup context and seeds the new chat after compression', async () => { + const compressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'summary' }] }, + { role: 'model', parts: [{ text: 'ok' }] }, + ]; + const originalChat = client.getChat(); + vi.spyOn(originalChat, 'tryCompress').mockImplementation(async () => { + originalChat.setHistory(compressedHistory); + return { + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, + }; + }); + client['forceFullIdeContext'] = false; + + await client.tryCompressChat('p4'); + + expect(client.getChat()).not.toBe(originalChat); + expect(client.getHistory()).toEqual([ { role: 'user', parts: [{ text: 'Mocked env context' }], @@ -798,616 +882,71 @@ describe('Gemini Client (client.ts)', () => { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }], }, - // extraHistory from tryCompressChat - { - role: 'user', - parts: [{ text: summaryText }], - }, - { - role: 'model', - parts: [{ text: 'Got it. Thanks for the additional context!' }], - }, - ...historyToKeep, - ]; + ...compressedHistory, + ]); + expect(client.getChat().getLastPromptTokenCount()).toBe(200); + expect(client['forceFullIdeContext']).toBe(true); + }); - const mockNewChat: Partial = { - getHistory: vi.fn().mockReturnValue(newCompressedHistory), - setHistory: vi.fn(), - }; - - client['startChat'] = vi.fn().mockImplementation(async () => { - client['chat'] = mockNewChat as GeminiChat; - return mockNewChat as GeminiChat; - }); - - // New token count formula: originalTokenCount - (compressionInputTokenCount - 1000) + compressionOutputTokenCount - const estimatedNewTokenCount = Math.max( - 0, - originalTokenCount - - (compressionInputTokenCount - 1000) + - compressionOutputTokenCount, - ); - - return { - client, - mockOriginalChat, - mockNewChat, - estimatedNewTokenCount, - }; - } - - describe('when compression inflates the token count', () => { - it('allows compression to be forced/manual after a failure', async () => { - // Call 1 (Fails): Setup with token counts that will inflate - // newTokenCount = originalTokenCount - (compressionInputTokenCount - 1000) + compressionOutputTokenCount - // = 100 - (1010 - 1000) + 200 = 100 - 10 + 200 = 290 > 100 (inflation) - const longSummary = 'long summary '.repeat(100); - const { client, estimatedNewTokenCount: inflatedTokenCount } = setup({ - originalTokenCount: 100, - summaryText: longSummary, - compressionInputTokenCount: 1010, - compressionOutputTokenCount: 200, - }); - expect(inflatedTokenCount).toBeGreaterThan(100); // Ensure setup is correct - - await client.tryCompressChat('prompt-id-4', false); // Fails - - // Call 2 (Forced): Re-setup with token counts that will compress - // newTokenCount = 100 - (1100 - 1000) + 50 = 100 - 100 + 50 = 50 <= 100 (compression) - const shortSummary = 'short'; - const { estimatedNewTokenCount: compressedTokenCount } = setup({ - originalTokenCount: 100, - summaryText: shortSummary, - compressionInputTokenCount: 1100, - compressionOutputTokenCount: 50, - }); - expect(compressedTokenCount).toBeLessThanOrEqual(100); // Ensure setup is correct - - const result = await client.tryCompressChat('prompt-id-4', true); // Forced - - expect(result.compressionStatus).toBe(CompressionStatus.COMPRESSED); - expect(result.originalTokenCount).toBe(100); - // newTokenCount might be clamped to originalTokenCount due to tolerance logic - expect(result.newTokenCount).toBeLessThanOrEqual(100); - }); - - it('yields the result even if the compression inflated the tokens', async () => { - // newTokenCount = 100 - (1010 - 1000) + 200 = 100 - 10 + 200 = 290 > 100 (inflation) - const longSummary = 'long summary '.repeat(100); - const { client, estimatedNewTokenCount } = setup({ - originalTokenCount: 100, - summaryText: longSummary, - compressionInputTokenCount: 1010, - compressionOutputTokenCount: 200, - }); - expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct - - // Mock contextWindowSize to ensure compression is triggered - vi.spyOn(client['config'], 'getContentGeneratorConfig').mockReturnValue( - { - model: 'test-model', - apiKey: 'test-key', - vertexai: false, - authType: AuthType.USE_GEMINI, - contextWindowSize: 100, // Set to same as originalTokenCount to ensure threshold is exceeded - }, - ); - - const result = await client.tryCompressChat('prompt-id-4', false); - - expect(result.compressionStatus).toBe( - CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, - ); - expect(result.originalTokenCount).toBe(100); - // The newTokenCount should be higher than original since compression failed due to inflation - expect(result.newTokenCount).toBeGreaterThan(100); - // IMPORTANT: The change in client.ts means setLastPromptTokenCount is NOT called on failure - expect( - uiTelemetryService.setLastPromptTokenCount, - ).not.toHaveBeenCalled(); - }); - - it('does not manipulate the source chat', async () => { - // newTokenCount = 100 - (1010 - 1000) + 200 = 100 - 10 + 200 = 290 > 100 (inflation) - const longSummary = 'long summary '.repeat(100); - const { client, mockOriginalChat, estimatedNewTokenCount } = setup({ - originalTokenCount: 100, - summaryText: longSummary, - compressionInputTokenCount: 1010, - compressionOutputTokenCount: 200, - }); - expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct - - await client.tryCompressChat('prompt-id-4', false); - - // On failure, the chat should NOT be replaced - expect(client['chat']).toBe(mockOriginalChat); - }); - - it('will not attempt to compress context after a failure', async () => { - // newTokenCount = 100 - (1010 - 1000) + 200 = 100 - 10 + 200 = 290 > 100 (inflation) - const longSummary = 'long summary '.repeat(100); - const { client, estimatedNewTokenCount } = setup({ - originalTokenCount: 100, - summaryText: longSummary, - compressionInputTokenCount: 1010, - compressionOutputTokenCount: 200, - }); - expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct - - // Mock contextWindowSize to ensure compression is triggered - vi.spyOn(client['config'], 'getContentGeneratorConfig').mockReturnValue( - { - model: 'test-model', - apiKey: 'test-key', - vertexai: false, - authType: AuthType.USE_GEMINI, - contextWindowSize: 100, // Set to same as originalTokenCount to ensure threshold is exceeded - }, - ); - - await client.tryCompressChat('prompt-id-4', false); // This fails and sets hasFailedCompressionAttempt = true - - // This call should now be a NOOP - const result = await client.tryCompressChat('prompt-id-5', false); - - // generateContent (for summary) should only have been called once - expect(mockGenerateContentFn).toHaveBeenCalledTimes(1); - expect(result).toEqual({ - compressionStatus: CompressionStatus.NOOP, - newTokenCount: 0, + it('does not flip forceFullIdeContext when compression NOOPs', async () => { + client['chat'] = { + tryCompress: vi.fn().mockResolvedValue({ originalTokenCount: 0, - }); - }); - }); - - it('should not trigger summarization if token count is below threshold', async () => { - const MOCKED_TOKEN_LIMIT = 1000; - vi.spyOn(client['config'], 'getContentGeneratorConfig').mockReturnValue({ - model: 'test-model', - apiKey: 'test-key', - vertexai: false, - authType: AuthType.USE_GEMINI, - contextWindowSize: MOCKED_TOKEN_LIMIT, - }); - mockGetHistory.mockReturnValue([ - { role: 'user', parts: [{ text: '...history...' }] }, - ]); - const originalTokenCount = MOCKED_TOKEN_LIMIT * 0.699; - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - originalTokenCount, - ); - - const initialChat = client.getChat(); - const result = await client.tryCompressChat('prompt-id-2', false); - const newChat = client.getChat(); - - expect(result).toEqual({ - compressionStatus: CompressionStatus.NOOP, - newTokenCount: originalTokenCount, - originalTokenCount, - }); - expect(newChat).toBe(initialChat); - }); - - it('logs a telemetry event when compressing', async () => { - const { logChatCompression } = await import('../telemetry/loggers.js'); - vi.mocked(logChatCompression).mockClear(); - - const MOCKED_TOKEN_LIMIT = 1000; - const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5; - vi.spyOn(client['config'], 'getContentGeneratorConfig').mockReturnValue({ - model: 'test-model', - apiKey: 'test-key', - vertexai: false, - authType: AuthType.USE_GEMINI, - contextWindowSize: MOCKED_TOKEN_LIMIT, - }); - vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({ - contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD, - }); - // Need multiple history items so there's something to compress - const history = [ - { role: 'user', parts: [{ text: '...history 1...' }] }, - { role: 'model', parts: [{ text: '...history 2...' }] }, - { role: 'user', parts: [{ text: '...history 3...' }] }, - { role: 'model', parts: [{ text: '...history 4...' }] }, - ]; - mockGetHistory.mockReturnValue(history); - - // Token count needs to be ABOVE the threshold to trigger compression - const originalTokenCount = - MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD + 1; - - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - originalTokenCount, - ); - - // Mock the summary response from the chat - // newTokenCount = 501 - (1400 - 1000) + 50 = 501 - 400 + 50 = 151 <= 501 (success) - const summaryText = 'This is a summary.'; - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - role: 'model', - parts: [{ text: summaryText }], - }, - }, - ], - usageMetadata: { - promptTokenCount: 1400, - candidatesTokenCount: 50, - totalTokenCount: 1450, - }, - } as unknown as GenerateContentResponse); - - // Mock startChat to complete the compression flow - const splitPoint = findCompressSplitPoint(history, 0.7); - const historyToKeep = history.slice(splitPoint); - const newCompressedHistory: Content[] = [ - { role: 'user', parts: [{ text: 'Mocked env context' }] }, - { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, - { role: 'user', parts: [{ text: summaryText }] }, - { - role: 'model', - parts: [{ text: 'Got it. Thanks for the additional context!' }], - }, - ...historyToKeep, - ]; - const mockNewChat: Partial = { - getHistory: vi.fn().mockReturnValue(newCompressedHistory), - }; - client['startChat'] = vi - .fn() - .mockResolvedValue(mockNewChat as GeminiChat); - - await client.tryCompressChat('prompt-id-3', false); - - expect(logChatCompression).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - tokens_before: originalTokenCount, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, }), - ); - expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalled(); + getHistory: vi.fn().mockReturnValue([]), + } as unknown as GeminiChat; + client['forceFullIdeContext'] = false; + + await client.tryCompressChat('p3'); + + expect(client['forceFullIdeContext']).toBe(false); }); - it('should trigger summarization if token count is above threshold with contextPercentageThreshold setting', async () => { - const MOCKED_TOKEN_LIMIT = 1000; - const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5; - vi.spyOn(client['config'], 'getContentGeneratorConfig').mockReturnValue({ - model: 'test-model', - apiKey: 'test-key', - vertexai: false, - authType: AuthType.USE_GEMINI, - contextWindowSize: MOCKED_TOKEN_LIMIT, + it('flips forceFullIdeContext when ChatCompressed flows through sendMessageStream', async () => { + // Auto-compaction lives inside chat.sendMessageStream and surfaces via + // the compressed → ChatCompressed bridge in turn.ts. The flip on this + // path is owned by the for-await loop in client.sendMessageStream, not + // by tryCompressChat — so this test feeds the event in directly. + vi.spyOn(client, 'tryCompressChat').mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, }); - vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({ - contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD, - }); - // Need multiple history items so there's something to compress - const history = [ - { role: 'user', parts: [{ text: '...history 1...' }] }, - { role: 'model', parts: [{ text: '...history 2...' }] }, - { role: 'user', parts: [{ text: '...history 3...' }] }, - { role: 'model', parts: [{ text: '...history 4...' }] }, - ]; - mockGetHistory.mockReturnValue(history); - - // Token count needs to be ABOVE the threshold to trigger compression - const originalTokenCount = - MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD + 1; - - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - originalTokenCount, - ); - - // Mock summary and new chat - const summaryText = 'This is a summary.'; - const splitPoint = findCompressSplitPoint(history, 0.7); - const historyToKeep = history.slice(splitPoint); - const newCompressedHistory: Content[] = [ - { role: 'user', parts: [{ text: 'Mocked env context' }] }, - { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, - { role: 'user', parts: [{ text: summaryText }] }, - { - role: 'model', - parts: [{ text: 'Got it. Thanks for the additional context!' }], - }, - ...historyToKeep, - ]; - const mockNewChat: Partial = { - getHistory: vi.fn().mockReturnValue(newCompressedHistory), - }; - client['startChat'] = vi.fn().mockImplementation(async () => { - client['chat'] = mockNewChat as GeminiChat; - return mockNewChat as GeminiChat; - }); - - // Mock the summary response from the chat - // newTokenCount = 501 - (1400 - 1000) + 50 = 501 - 400 + 50 = 151 <= 501 (success) - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - role: 'model', - parts: [{ text: summaryText }], + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { + type: GeminiEventType.ChatCompressed, + value: { + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, }, - }, - ], - usageMetadata: { - promptTokenCount: 1400, - candidatesTokenCount: 50, - totalTokenCount: 1450, - }, - } as unknown as GenerateContentResponse); - - const initialChat = client.getChat(); - const result = await client.tryCompressChat('prompt-id-3', false); - const newChat = client.getChat(); - - expect(mockGenerateContentFn).toHaveBeenCalled(); - - // Assert that summarization happened - expect(result.compressionStatus).toBe(CompressionStatus.COMPRESSED); - expect(result.originalTokenCount).toBe(originalTokenCount); - // newTokenCount might be clamped to originalTokenCount due to tolerance logic - expect(result.newTokenCount).toBeLessThanOrEqual(originalTokenCount); - - // Assert that the chat was reset - expect(newChat).not.toBe(initialChat); - }); - - it('should not compress across a function call response', async () => { - const MOCKED_TOKEN_LIMIT = 1000; - vi.spyOn(client['config'], 'getContentGeneratorConfig').mockReturnValue({ - model: 'test-model', - apiKey: 'test-key', - vertexai: false, - authType: AuthType.USE_GEMINI, - contextWindowSize: MOCKED_TOKEN_LIMIT, - }); - const history: Content[] = [ - { role: 'user', parts: [{ text: '...history 1...' }] }, - { role: 'model', parts: [{ text: '...history 2...' }] }, - { role: 'user', parts: [{ text: '...history 3...' }] }, - { role: 'model', parts: [{ text: '...history 4...' }] }, - { role: 'user', parts: [{ text: '...history 5...' }] }, - { role: 'model', parts: [{ text: '...history 6...' }] }, - { role: 'user', parts: [{ text: '...history 7...' }] }, - { role: 'model', parts: [{ text: '...history 8...' }] }, - // Normally we would break here, but we have a function response. - { - role: 'user', - parts: [{ functionResponse: { name: '...history 8...' } }], - }, - { role: 'model', parts: [{ text: '...history 10...' }] }, - // Instead we will break here. - { role: 'user', parts: [{ text: '...history 10...' }] }, - ]; - mockGetHistory.mockReturnValue(history); - - const originalTokenCount = 1000 * 0.7; - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - originalTokenCount, + }; + })(), ); + client['chat'] = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + } as unknown as GeminiChat; + client['forceFullIdeContext'] = false; - // Mock summary and new chat - const summaryText = 'This is a summary.'; - const splitPoint = findCompressSplitPoint(history, 0.7); // This should be 10 - expect(splitPoint).toBe(10); // Verify split point logic - const historyToKeep = history.slice(splitPoint); // Should keep last user message - expect(historyToKeep).toEqual([ - { role: 'user', parts: [{ text: '...history 10...' }] }, - ]); - - const newCompressedHistory: Content[] = [ - { role: 'user', parts: [{ text: 'Mocked env context' }] }, - { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, - { role: 'user', parts: [{ text: summaryText }] }, - { - role: 'model', - parts: [{ text: 'Got it. Thanks for the additional context!' }], - }, - ...historyToKeep, - ]; - const mockNewChat: Partial = { - getHistory: vi.fn().mockReturnValue(newCompressedHistory), - }; - client['startChat'] = vi.fn().mockImplementation(async () => { - client['chat'] = mockNewChat as GeminiChat; - return mockNewChat as GeminiChat; - }); - - // Mock the summary response from the chat - // newTokenCount = 700 - (1500 - 1000) + 50 = 700 - 500 + 50 = 250 <= 700 (success) - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - role: 'model', - parts: [{ text: summaryText }], - }, - }, - ], - usageMetadata: { - promptTokenCount: 1500, - candidatesTokenCount: 50, - totalTokenCount: 1550, - }, - } as unknown as GenerateContentResponse); - - const initialChat = client.getChat(); - const result = await client.tryCompressChat('prompt-id-3', false); - const newChat = client.getChat(); - - expect(mockGenerateContentFn).toHaveBeenCalled(); - - // Assert that summarization happened - expect(result.compressionStatus).toBe(CompressionStatus.COMPRESSED); - expect(result.originalTokenCount).toBe(originalTokenCount); - // newTokenCount might be clamped to originalTokenCount due to tolerance logic - expect(result.newTokenCount).toBeLessThanOrEqual(originalTokenCount); - // Assert that the chat was reset - expect(newChat).not.toBe(initialChat); - - // 1. standard start context message (env) - // 2. standard canned model response - // 3. compressed summary message (user) - // 4. standard canned model response - // 5. The last user message (historyToKeep) - expect(newChat.getHistory().length).toEqual(5); - }); - - it('should always trigger summarization when force is true, regardless of token count', async () => { - // Need multiple history items so there's something to compress - const history = [ - { role: 'user', parts: [{ text: '...history 1...' }] }, - { role: 'model', parts: [{ text: '...history 2...' }] }, - { role: 'user', parts: [{ text: '...history 3...' }] }, - { role: 'model', parts: [{ text: '...history 4...' }] }, - ]; - mockGetHistory.mockReturnValue(history); - - const originalTokenCount = 100; // Well below threshold, but > estimated new count - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - originalTokenCount, + const stream = client.sendMessageStream( + [{ text: 'hi' }], + new AbortController().signal, + 'prompt-auto-flip', + { type: SendMessageType.UserQuery }, ); + for await (const _ of stream) { + /* drain */ + } - // Mock summary and new chat - const summaryText = 'This is a summary.'; - const splitPoint = findCompressSplitPoint(history, 0.7); - const historyToKeep = history.slice(splitPoint); - const newCompressedHistory: Content[] = [ - { role: 'user', parts: [{ text: 'Mocked env context' }] }, - { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, - { role: 'user', parts: [{ text: summaryText }] }, - { - role: 'model', - parts: [{ text: 'Got it. Thanks for the additional context!' }], - }, - ...historyToKeep, - ]; - const mockNewChat: Partial = { - getHistory: vi.fn().mockReturnValue(newCompressedHistory), - }; - client['startChat'] = vi.fn().mockImplementation(async () => { - client['chat'] = mockNewChat as GeminiChat; - return mockNewChat as GeminiChat; - }); - - // Mock the summary response from the chat - // newTokenCount = 100 - (1060 - 1000) + 20 = 100 - 60 + 20 = 60 <= 100 (success) - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - role: 'model', - parts: [{ text: summaryText }], - }, - }, - ], - usageMetadata: { - promptTokenCount: 1060, - candidatesTokenCount: 20, - totalTokenCount: 1080, - }, - } as unknown as GenerateContentResponse); - - const initialChat = client.getChat(); - const result = await client.tryCompressChat('prompt-id-1', true); // force = true - const newChat = client.getChat(); - - expect(mockGenerateContentFn).toHaveBeenCalled(); - - expect(result.compressionStatus).toBe(CompressionStatus.COMPRESSED); - expect(result.originalTokenCount).toBe(originalTokenCount); - // newTokenCount might be clamped to originalTokenCount due to tolerance logic - expect(result.newTokenCount).toBeLessThanOrEqual(originalTokenCount); - - // Assert that the chat was reset - expect(newChat).not.toBe(initialChat); + expect(client['forceFullIdeContext']).toBe(true); }); }); describe('sendMessageStream', () => { - it('emits a compression event when the context was automatically compressed', async () => { - // Arrange - mockTurnRunFn.mockReturnValue( - (async function* () { - yield { type: 'content', value: 'Hello' }; - })(), - ); - - const compressionInfo: ChatCompressionInfo = { - compressionStatus: CompressionStatus.COMPRESSED, - originalTokenCount: 1000, - newTokenCount: 500, - }; - - vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce( - compressionInfo, - ); - - // Act - const stream = client.sendMessageStream( - [{ text: 'Hi' }], - new AbortController().signal, - 'prompt-id-1', - ); - - const events = await fromAsync(stream); - - // Assert - expect(events).toContainEqual({ - type: GeminiEventType.ChatCompressed, - value: compressionInfo, - }); - }); - - it.each([ - { - compressionStatus: - CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, - }, - { compressionStatus: CompressionStatus.NOOP }, - ])( - 'does not emit a compression event when the status is $compressionStatus', - async ({ compressionStatus }) => { - // Arrange - const mockStream = (async function* () { - yield { type: 'content', value: 'Hello' }; - })(); - mockTurnRunFn.mockReturnValue(mockStream); - - const compressionInfo: ChatCompressionInfo = { - compressionStatus, - originalTokenCount: 1000, - newTokenCount: 500, - }; - - vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce( - compressionInfo, - ); - - // Act - const stream = client.sendMessageStream( - [{ text: 'Hi' }], - new AbortController().signal, - 'prompt-id-1', - ); - - const events = await fromAsync(stream); - - // Assert - expect(events).not.toContainEqual({ - type: GeminiEventType.ChatCompressed, - value: expect.anything(), - }); - }, - ); - it('should include editor context when ideMode is enabled', async () => { // Arrange vi.mocked(ideContextStore.get).mockReturnValue({ @@ -1711,6 +1250,170 @@ hello ); }); + it('should not block the main request when auto-memory recall is slow', async () => { + // Simulate a recall that takes longer than the 2.5s deadline + mockMemoryManager.recall.mockReturnValue( + new Promise((resolve) => + setTimeout( + () => + resolve({ + prompt: '## Relevant memory\n\nSlow memory result.', + selectedDocs: [], + strategy: 'model', + }), + 10_000, + ), + ), + ); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + vi.useFakeTimers(); + try { + const streamPromise = (async () => { + const stream = client.sendMessageStream( + [{ text: 'Quick question' }], + new AbortController().signal, + 'prompt-id-slow-memory', + ); + for await (const _ of stream) { + // consume stream + } + })(); + + // Advance past the 2.5s deadline — the main request should proceed + await vi.advanceTimersByTimeAsync(3_000); + await streamPromise; + + // The main request should have been called without the slow memory + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'test-model', + expect.not.arrayContaining([ + expect.stringContaining('Slow memory result'), + ]), + expect.any(AbortSignal), + ); + } finally { + vi.useRealTimers(); + } + }); + + it('should include auto-memory prompt when recall completes within deadline', async () => { + // Simulate a fast recall that completes well within the deadline + mockMemoryManager.recall.mockResolvedValue({ + prompt: '## Relevant memory\n\nFast memory result.', + selectedDocs: [], + strategy: 'heuristic', + }); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const stream = client.sendMessageStream( + [{ text: 'Quick question' }], + new AbortController().signal, + 'prompt-id-fast-memory', + ); + for await (const _ of stream) { + // consume stream + } + + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'test-model', + expect.arrayContaining(['## Relevant memory\n\nFast memory result.']), + expect.any(AbortSignal), + ); + }); + + it('should proceed without auto-memory when managed auto-memory is disabled', async () => { + // When getManagedAutoMemoryEnabled returns false, no recall is initiated + // and sendMessageStream completes without memory content + vi.mocked(mockConfig.getManagedAutoMemoryEnabled).mockReturnValue(false); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const stream = client.sendMessageStream( + [{ text: 'Quick question' }], + new AbortController().signal, + 'prompt-id-no-memory', + ); + for await (const _ of stream) { + // consume stream + } + + // recall should never have been called + expect(mockMemoryManager.recall).not.toHaveBeenCalled(); + + // The main request should have been called without any memory content + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'test-model', + ['Quick question'], + expect.any(AbortSignal), + ); + + // Restore default + vi.mocked(mockConfig.getManagedAutoMemoryEnabled).mockReturnValue(true); + }); + + it('should proceed normally when recall rejects', async () => { + // Simulate a recall that throws — the .catch() handler should swallow + // the error and the main request should complete without memory content + mockMemoryManager.recall.mockRejectedValue(new Error('recall failed')); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const stream = client.sendMessageStream( + [{ text: 'Quick question' }], + new AbortController().signal, + 'prompt-id-recall-fail', + ); + for await (const _ of stream) { + // consume stream + } + + // The main request should have been called without any memory content + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'test-model', + ['Quick question'], + expect.any(AbortSignal), + ); + }); + it('should run managed auto-memory extraction after a completed user query', async () => { mockMemoryManager.scheduleExtract.mockResolvedValue({ touchedTopics: ['user'], @@ -1998,6 +1701,89 @@ Other open files: expect(mockTurnRunFn).toHaveBeenCalledTimes(MAX_SESSION_TURNS); }); + it('should abort the pending recall when MaxSessionTurns is hit', async () => { + vi.spyOn(client['config'], 'getMaxSessionTurns').mockReturnValue(1); + client['sessionTurnCount'] = 1; // already at limit; next call exceeds it + + const abortHandler = vi.fn(); + mockMemoryManager.recall.mockImplementation((_root, _query, opts) => { + opts.abortSignal?.addEventListener('abort', abortHandler); + return new Promise(() => {}); // never resolves + }); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const stream = client.sendMessageStream( + [{ text: 'over the limit' }], + new AbortController().signal, + 'prompt-id-over-limit', + ); + const events = []; + for await (const event of stream) { + events.push(event); + } + + expect(events).toEqual([{ type: GeminiEventType.MaxSessionTurns }]); + expect(abortHandler).toHaveBeenCalledTimes(1); + }); + + it('should abort the pending recall when SessionTokenLimitExceeded', async () => { + // Use a very low token limit so the (uncompressed) history exceeds it + vi.spyOn(client['config'], 'getSessionTokenLimit').mockReturnValue(1); + + // Force token count to be above the limit + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + 9999, + ); + + const abortHandler = vi.fn(); + mockMemoryManager.recall.mockImplementation((_root, _query, opts) => { + opts.abortSignal?.addEventListener('abort', abortHandler); + return new Promise(() => {}); // never resolves + }); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const stream = client.sendMessageStream( + [{ text: 'token limit test' }], + new AbortController().signal, + 'prompt-id-token-limit', + ); + const events = []; + for await (const event of stream) { + events.push(event); + } + + expect(events).toEqual([ + { + type: GeminiEventType.SessionTokenLimitExceeded, + value: expect.objectContaining({ + currentTokens: 9999, + limit: 1, + }), + }, + ]); + expect(abortHandler).toHaveBeenCalledTimes(1); + }); + it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => { // This test verifies that the infinite loop protection works even when // someone tries to bypass it by calling with a very large turns value @@ -3208,4 +2994,368 @@ Other open files: // Note: there is currently no "fallback mode" model routing; the model used // is always the one explicitly requested by the caller. }); + + describe('generateContent with fast model', () => { + it('should resolve per-model config and fall back when createContentGenerator fails', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortSignal = new AbortController().signal; + + // Set up a resolved model for the fast model, but createContentGenerator + // will fail in the test env (no auth), so it falls back to the main + // content generator. Verify the resolution was attempted. + const mockResolvedModel = { + id: 'fast-model', + authType: 'openai' as const, + name: 'Fast Model', + baseUrl: 'https://fast-api.example.com', + generationConfig: { + extra_body: { enable_thinking: false }, + samplingParams: { temperature: 0.1 }, + }, + capabilities: {}, + }; + + const getResolvedModel = vi.fn().mockReturnValue(mockResolvedModel); + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel, + } as unknown as ModelsConfig); + + await client.generateContent( + contents, + { temperature: 0.5 }, + abortSignal, + 'fast-model', + ); + + // Verify that getResolvedModel was called with the fast model ID + expect(getResolvedModel).toHaveBeenCalledWith( + expect.any(String), + 'fast-model', + ); + + // The main content generator is used as fallback (since creating a new + // one fails in test env without auth). In production, a dedicated + // content generator with the fast model's settings would be created. + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'fast-model', + }), + expect.any(String), + ); + }); + + it('should use a dedicated content generator for the fast model on success', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortSignal = new AbortController().signal; + + // Create a mock dedicated content generator + const mockFastContentGenerator = { + generateContent: vi.fn().mockResolvedValue({ + text: 'fast response', + }), + } as unknown as ContentGenerator; + + // Set up a resolved model for the fast model + const mockResolvedModel = { + id: 'fast-model', + authType: 'openai' as const, + name: 'Fast Model', + baseUrl: 'https://fast-api.example.com', + envKey: 'FAST_API_KEY', + generationConfig: { + extra_body: { enable_thinking: false }, + samplingParams: { temperature: 0.1 }, + }, + capabilities: {}, + }; + + const getResolvedModel = vi.fn().mockReturnValue(mockResolvedModel); + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel, + } as unknown as ModelsConfig); + + // Override createContentGenerator to return our test double (success path) + vi.mocked(createContentGenerator).mockResolvedValue( + mockFastContentGenerator, + ); + + await client.generateContent( + contents, + { temperature: 0.5 }, + abortSignal, + 'fast-model', + ); + + // Verify buildAgentContentGeneratorConfig was called with correct args + expect(buildAgentContentGeneratorConfig).toHaveBeenCalledWith( + mockConfig, + 'fast-model', + expect.objectContaining({ + baseUrl: 'https://fast-api.example.com', + }), + ); + + // The dedicated fast content generator should be used + expect(mockFastContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'fast-model', + }), + expect.any(String), + ); + + // The original main content generator should NOT have been called + expect(mockContentGenerator.generateContent).not.toHaveBeenCalled(); + }); + + it('should use the main content generator when the requested model matches the main model', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortSignal = new AbortController().signal; + + const getResolvedModel = vi.fn(); + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel, + } as unknown as ModelsConfig); + + await client.generateContent( + contents, + {}, + abortSignal, + 'test-model', // same as getModel() return value + ); + + // getResolvedModel should NOT be called when model matches main + expect(getResolvedModel).not.toHaveBeenCalled(); + + // The main content generator should be used directly + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'test-model', + }), + expect.any(String), + ); + }); + + it('should fall back to main generator when model is not in registry', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortSignal = new AbortController().signal; + + // getResolvedModel returns undefined — model not found in registry + const getResolvedModel = vi.fn().mockReturnValue(undefined); + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel, + } as unknown as ModelsConfig); + + // Should not throw — falls back to main generator + await expect( + client.generateContent( + contents, + { temperature: 0.5 }, + abortSignal, + 'unknown-model', + ), + ).resolves.toBeDefined(); + + // getResolvedModel was called to look up the model + expect(getResolvedModel).toHaveBeenCalledWith( + expect.any(String), + 'unknown-model', + ); + + // The main content generator is used as fallback + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'unknown-model', + }), + expect.any(String), + ); + + // buildAgentContentGeneratorConfig must NOT be called when the model is + // not in the registry — the fallback path skips config construction. + expect(buildAgentContentGeneratorConfig).not.toHaveBeenCalled(); + }); + + it('should use fast model authType for retry, not main model authType', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortSignal = new AbortController().signal; + + const mockResolvedModel = { + id: 'fast-model', + authType: 'openai' as const, + name: 'Fast Model', + baseUrl: 'https://fast-api.example.com', + generationConfig: {}, + capabilities: {}, + }; + + const getResolvedModel = vi.fn().mockReturnValue(mockResolvedModel); + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel, + } as unknown as ModelsConfig); + + // Main config uses a different authType + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.QWEN_OAUTH, + apiKey: 'test-key', + apiModel: 'test-model', + } as unknown as ContentGeneratorConfig); + + // Success path for createContentGenerator + vi.mocked(createContentGenerator).mockResolvedValue(mockContentGenerator); + + await client.generateContent( + contents, + { temperature: 0.5 }, + abortSignal, + 'fast-model', + ); + + // VERIFY: retryWithBackoff was called with the fast model's authType ('openai'), + // not the main model's authType ('QWEN_OAUTH'). + expect(retryWithBackoff).toHaveBeenCalledWith( + expect.any(Function), + expect.objectContaining({ + authType: 'openai', + }), + ); + }); + + it('should cache per-model content generators', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortController = new AbortController(); + const mockResolvedModel = { + id: 'fast-model', + authType: 'openai' as const, + name: 'Fast Model', + baseUrl: 'https://fast-api.example.com', + generationConfig: {}, + capabilities: {}, + }; + + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel: vi.fn().mockReturnValue(mockResolvedModel), + } as unknown as ModelsConfig); + + vi.mocked(createContentGenerator).mockResolvedValue(mockContentGenerator); + + // First call + await client.generateContent( + contents, + {}, + abortController.signal, + 'fast-model', + ); + expect(createContentGenerator).toHaveBeenCalledTimes(1); + + // Second call - should use cache + await client.generateContent( + contents, + {}, + abortController.signal, + 'fast-model', + ); + expect(createContentGenerator).toHaveBeenCalledTimes(1); + }); + + it('should resolve model across authTypes when main authType misses', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortSignal = new AbortController().signal; + + const mockResolvedModel = { + id: 'fast-model', + authType: 'openai' as const, + name: 'Fast Model', + baseUrl: 'https://fast-api.example.com', + generationConfig: {}, + capabilities: {}, + envKey: undefined, + }; + + // resolveModelAcrossAuthTypes calls getResolvedModel multiple times: + // 1. main authType (QWEN_OAUTH) → undefined (miss) + // 2. secondary authType (USE_OPENAI) → mockResolvedModel (hit) + // 3. buildAgentContentGeneratorConfig calls getResolvedModel again + // with the resolved authType → mockResolvedModel (hit) + const getResolvedModel = vi + .fn() + .mockReturnValueOnce(undefined) + .mockReturnValue(mockResolvedModel); + + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel, + } as unknown as ModelsConfig); + + // Main config uses QWEN_OAUTH — fast model registered under USE_OPENAI + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.QWEN_OAUTH, + apiKey: 'test-key', + apiModel: 'test-model', + } as unknown as ContentGeneratorConfig); + + // Mock createContentGenerator to succeed so the cross-authType + // resolution path completes without falling back + vi.mocked(createContentGenerator).mockResolvedValue(mockContentGenerator); + + await client.generateContent( + contents, + { temperature: 0.5 }, + abortSignal, + 'fast-model', + ); + + // First call uses main authType (QWEN_OAUTH) — misses + expect(getResolvedModel).toHaveBeenNthCalledWith( + 1, + AuthType.QWEN_OAUTH, + 'fast-model', + ); + // Second call falls through to secondary authType — hits + expect(getResolvedModel).toHaveBeenNthCalledWith( + 2, + AuthType.USE_OPENAI, + 'fast-model', + ); + // Generator was created using the resolved model's config + expect(createContentGenerator).toHaveBeenCalled(); + }); + + it('should clear per-model generator cache on resetChat', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const abortController = new AbortController(); + const mockResolvedModel = { + id: 'fast-model', + authType: 'openai' as const, + name: 'Fast Model', + baseUrl: 'https://fast-api.example.com', + generationConfig: {}, + capabilities: {}, + }; + + vi.mocked(mockConfig.getModelsConfig).mockReturnValue({ + getResolvedModel: vi.fn().mockReturnValue(mockResolvedModel), + } as unknown as ModelsConfig); + + vi.mocked(createContentGenerator).mockResolvedValue(mockContentGenerator); + + // First call — populates cache + await client.generateContent( + contents, + {}, + abortController.signal, + 'fast-model', + ); + expect(createContentGenerator).toHaveBeenCalledTimes(1); + + // Reset chat should clear the cache + await client.resetChat(); + + // Second call after reset — cache should be cleared, generator recreated + await client.generateContent( + contents, + {}, + abortController.signal, + 'fast-model', + ); + expect(createContentGenerator).toHaveBeenCalledTimes(2); + }); + }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 47a9a9b0f..902d29664 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -22,6 +22,8 @@ const debugLogger = createDebugLogger('CLIENT'); // Core modules import type { ContentGenerator } from './contentGenerator.js'; +import type { ResolvedModelConfig } from '../models/types.js'; +import { AuthType, createContentGenerator } from './contentGenerator.js'; import { GeminiChat } from './geminiChat.js'; import { getArenaSystemReminder, @@ -40,13 +42,15 @@ import { // Services import { - ChatCompressionService, COMPRESSION_PRESERVE_THRESHOLD, COMPRESSION_TOKEN_THRESHOLD, } from '../services/chatCompressionService.js'; import { LoopDetectionService } from '../services/loopDetectionService.js'; import { CommitAttributionService } from '../services/commitAttribution.js'; +// Models +import { buildAgentContentGeneratorConfig } from '../models/content-generator-config.js'; + // Tools import type { RelevantAutoMemoryPromptResult } from '../memory/manager.js'; import { ToolNames } from '../tools/tool-names.js'; @@ -126,6 +130,43 @@ const EMPTY_RELEVANT_AUTO_MEMORY_RESULT: RelevantAutoMemoryPromptResult = { strategy: 'none', }; +/** + * Resolve the auto-memory recall promise with a hard deadline. + * If the recall (model-driven selection + heuristic fallback) does not complete + * within the deadline, return an empty result so the main request is not delayed. + * + * The deadline is set slightly above the model-driven selector's own + * AbortSignal.timeout (2s) to give the heuristic fallback time to complete, + * but low enough that the user does not perceive a delay on every turn. + */ +async function resolveAutoMemoryWithDeadline( + promise: Promise | undefined, + onDeadline: () => void, +): Promise { + if (!promise) { + return EMPTY_RELEVANT_AUTO_MEMORY_RESULT; + } + + let timer: ReturnType | undefined; + const deadline = new Promise((resolve) => { + timer = setTimeout(() => { + try { + onDeadline(); + } finally { + resolve(EMPTY_RELEVANT_AUTO_MEMORY_RESULT); + } + }, 2_500); + }); + + try { + return await Promise.race([promise, deadline]); + } finally { + if (timer !== undefined) { + clearTimeout(timer); + } + } +} + export class GeminiClient { private chat?: GeminiChat; private sessionTurnCount = 0; @@ -135,12 +176,15 @@ export class GeminiClient { private lastPromptId: string | undefined = undefined; private lastSentIdeContext: IdeContext | undefined; private forceFullIdeContext = true; + private pendingRecallAbortController: AbortController | undefined; /** - * At any point in this conversation, was compression triggered without - * being forced and did it fail? + * Cache of per-model ContentGenerators keyed by model ID. + * Avoids rebuilding the generator (SDK instantiation, config resolution) + * on every side query (recap, title, tool summary). + * Cleared on session reset (resetChat) to pick up config changes. */ - private hasFailedCompressionAttempt = false; + private perModelGeneratorCache = new Map>(); /** * Promises for pending background memory tasks (dream / extract). @@ -174,6 +218,9 @@ export class GeminiClient { resumedSessionData.conversation, ); await this.startChat(resumedHistory); + this.getChat().setLastPromptTokenCount( + uiTelemetryService.getLastPromptTokenCount(), + ); // Restore attribution state from the last snapshot in the session this.restoreAttributionFromSession(resumedSessionData.conversation); @@ -305,6 +352,13 @@ export class GeminiClient { // pointing at content the model can no longer retrieve. debugLogger.debug('[FILE_READ_CACHE] clear after resetChat'); this.config.getFileReadCache().clear(); + this.perModelGeneratorCache.clear(); + // Abort any in-flight auto-memory recall so the stale controller + // does not leak into the next session. + if (this.pendingRecallAbortController) { + this.pendingRecallAbortController.abort(); + this.pendingRecallAbortController = undefined; + } await this.startChat(); } @@ -345,7 +399,6 @@ export class GeminiClient { async startChat(extraHistory?: Content[]): Promise { this.forceFullIdeContext = true; - this.hasFailedCompressionAttempt = false; // Clear stale cache params on session reset to prevent cross-session leakage clearCacheSafeParams(); @@ -731,19 +784,36 @@ export class GeminiClient { messageType === SendMessageType.Cron ) { if (this.config.getManagedAutoMemoryEnabled()) { - relevantAutoMemoryPromise = this.config + const recallAbortController = new AbortController(); + const rawRecallPromise = this.config .getMemoryManager() .recall(this.config.getProjectRoot(), partToString(request), { config: this.config, excludedFilePaths: this.surfacedRelevantAutoMemoryPaths, + abortSignal: recallAbortController.signal, }) .catch((error: unknown) => { - debugLogger.warn( - 'Managed auto-memory recall prefetch failed.', - error, - ); + if (error instanceof DOMException && error.name === 'AbortError') { + debugLogger.debug( + 'Auto-memory recall aborted by deadline.', + error, + ); + } else { + debugLogger.warn( + 'Managed auto-memory recall prefetch failed.', + error, + ); + } return EMPTY_RELEVANT_AUTO_MEMORY_RESULT; }); + this.pendingRecallAbortController = recallAbortController; + // Race the recall against the deadline at initiation time so the 2.5s + // budget is not consumed by intermediate work (microcompact, compression, + // token checks, IDE context) between initiation and consumption. + relevantAutoMemoryPromise = resolveAutoMemoryWithDeadline( + rawRecallPromise, + () => recallAbortController.abort(), + ); } // Track prompt count for commit attribution. Only the user typing a @@ -814,6 +884,8 @@ export class GeminiClient { this.config.getMaxSessionTurns() > 0 && this.sessionTurnCount > this.config.getMaxSessionTurns() ) { + this.pendingRecallAbortController?.abort(); + this.pendingRecallAbortController = undefined; yield { type: GeminiEventType.MaxSessionTurns }; return new Turn(this.getChat(), prompt_id); } @@ -822,21 +894,21 @@ export class GeminiClient { // Ensure turns never exceeds MAX_TURNS to prevent infinite loops const boundedTurns = Math.min(turns, MAX_TURNS); if (!boundedTurns) { + this.pendingRecallAbortController?.abort(); + this.pendingRecallAbortController = undefined; return new Turn(this.getChat(), prompt_id); } - const compressed = await this.tryCompressChat(prompt_id, false, signal); - - if (compressed.compressionStatus === CompressionStatus.COMPRESSED) { - yield { type: GeminiEventType.ChatCompressed, value: compressed }; - } - - // Check session token limit after compression. - // `lastPromptTokenCount` is treated as authoritative for the (possibly compressed) history; + // Auto-compaction happens inside GeminiChat.sendMessageStream and surfaces + // via the `compressed → ChatCompressed` bridge in turn.ts. Manual /compress + // still calls tryCompressChat directly for the full reset (env refresh + + // forceFullIdeContext flip). const sessionTokenLimit = this.config.getSessionTokenLimit(); if (sessionTokenLimit > 0) { const lastPromptTokenCount = uiTelemetryService.getLastPromptTokenCount(); if (lastPromptTokenCount > sessionTokenLimit) { + this.pendingRecallAbortController?.abort(); + this.pendingRecallAbortController = undefined; yield { type: GeminiEventType.SessionTokenLimitExceeded, value: { @@ -887,6 +959,8 @@ export class GeminiClient { `Arena control signal received: ${controlSignal.type} - ${controlSignal.reason}`, ); await arenaAgentClient.reportCancelled(); + this.pendingRecallAbortController?.abort(); + this.pendingRecallAbortController = undefined; return new Turn(this.getChat(), prompt_id); } } @@ -903,6 +977,9 @@ export class GeminiClient { messageType === SendMessageType.Cron ) { const systemReminders = []; + // The recall promise was already raced against the 2.5s deadline at + // initiation time; this await just collects the result. + this.pendingRecallAbortController = undefined; const relevantAutoMemory = relevantAutoMemoryPromise ? await relevantAutoMemoryPromise : EMPTY_RELEVANT_AUTO_MEMORY_RESULT; @@ -971,6 +1048,13 @@ export class GeminiClient { await arenaAgentClient.updateStatus(); } + // Re-send a full IDE context blob on the next regular message — auto + // compaction inside chat.sendMessageStream may have summarized away + // the previous IDE-context turn. + if (event.type === GeminiEventType.ChatCompressed) { + this.forceFullIdeContext = true; + } + yield event; if (event.type === GeminiEventType.Error) { if (arenaAgentClient) { @@ -1193,10 +1277,31 @@ export class GeminiClient { systemInstruction: finalSystemInstruction, }; + // When the requested model differs from the main model (e.g. fast model + // side queries for session recap / title / summary), resolve the target + // model's own ContentGeneratorConfig so that per-model settings like + // extra_body, samplingParams, and reasoning are not inherited from the + // main model's config. + const mainModel = this.config.getModel() ?? model; + const isPerModel = model !== mainModel; + + // Resolve the authType for retry logic. When using a per-model content + // generator (e.g. fast model side queries), the retry authType must match + // the target model's provider, not the main session's provider. This + // ensures QWEN_OAUTH quota detection checks against the right provider. + const retryAuthType = isPerModel + ? (this.createRetryAuthTypeForModel(model) ?? + this.config.getContentGeneratorConfig()?.authType ?? + AuthType.USE_OPENAI) + : this.config.getContentGeneratorConfig()?.authType; + + const contentGenerator = isPerModel + ? await this.createContentGeneratorForModel(model) + : this.getContentGeneratorOrFail(); const apiCall = () => { currentAttemptModel = model; - return this.getContentGeneratorOrFail().generateContent( + return contentGenerator.generateContent( { model, config: requestConfig, @@ -1206,7 +1311,7 @@ export class GeminiClient { ); }; const result = await retryWithBackoff(apiCall, { - authType: this.config.getContentGeneratorConfig()?.authType, + authType: retryAuthType, persistentMode: isUnattendedMode(), signal: abortSignal, heartbeatFn: (info) => { @@ -1236,58 +1341,141 @@ export class GeminiClient { } } + /** + * Resolve a model across all authTypes. Handles the case where the target + * model is registered under a different authType than the main model + * (e.g. main=QWEN_OAUTH, fast=USE_ANTHROPIC). + * + * TODO: Move cross-authType resolution to ModelRegistry for a cleaner + * data-layer solution. Follow-up PR. + */ + + private resolveModelAcrossAuthTypes( + model: string, + ): ResolvedModelConfig | undefined { + const modelsConfig = this.config.getModelsConfig(); + const allAuthTypes: AuthType[] = [ + AuthType.QWEN_OAUTH, + AuthType.USE_OPENAI, + AuthType.USE_VERTEX_AI, + AuthType.USE_ANTHROPIC, + AuthType.USE_GEMINI, + ]; + + // Try the main authType first for early exit + const mainAuthType = this.config.getContentGeneratorConfig()?.authType; + if (mainAuthType) { + const resolved = modelsConfig.getResolvedModel(mainAuthType, model); + if (resolved) return resolved; + } + + for (const authType of allAuthTypes) { + if (authType === mainAuthType) continue; + const resolved = modelsConfig.getResolvedModel(authType, model); + if (resolved) return resolved; + } + + return undefined; + } + + /** + * Resolve the authType for a given model without creating a full generator. + * Used by retry logic to ensure provider-specific checks (e.g. QWEN_OAUTH + * quota detection) reference the correct provider. + */ + private createRetryAuthTypeForModel(model: string): string | undefined { + return this.resolveModelAcrossAuthTypes(model)?.authType; + } + + /** + * Return a ContentGenerator for a specific model (e.g. the fast model) with + * its own per-model settings from modelProviders. This prevents the main + * model's extra_body / samplingParams / reasoning from leaking into side + * queries that target a different model. + * + * Falls back to the main content generator when the target model is not in + * the registry or when creating a dedicated generator fails (e.g. in test + * environments without full auth setup). + * + * Results are cached by model ID to avoid rebuilding the generator + * (SDK instantiation, config resolution) on every side query. + */ + private async createContentGeneratorForModel( + model: string, + ): Promise { + // Check cache first (Promise coalescing to prevent redundant SDK instantiations) + const cached = this.perModelGeneratorCache.get(model); + if (cached) return cached; + + const generatorPromise = (async () => { + try { + const resolvedModel = this.resolveModelAcrossAuthTypes(model); + + if (!resolvedModel) { + debugLogger.warn( + `Model "${model}" not found in registry across all authTypes, falling back to main generator.`, + ); + return this.getContentGeneratorOrFail(); + } + + const targetConfig = buildAgentContentGeneratorConfig( + this.config, + model, + { + authType: resolvedModel.authType, + apiKey: resolvedModel.envKey + ? (process.env[resolvedModel.envKey] ?? undefined) + : undefined, + baseUrl: resolvedModel.baseUrl, + }, + ); + + return await createContentGenerator(targetConfig, this.config); + } catch (err: unknown) { + debugLogger.warn( + `Failed to create content generator for model "${model}", falling back to main generator.`, + err instanceof Error ? err.message : String(err), + ); + // On failure, delete from cache so subsequent attempts can retry. + this.perModelGeneratorCache.delete(model); + return this.getContentGeneratorOrFail(); + } + })(); + + this.perModelGeneratorCache.set(model, generatorPromise); + return generatorPromise; + } + + /** + * Wrapper around {@link GeminiChat.tryCompress} that restores main-session + * startup context after successful compaction and flips the IDE full-context + * flag for the next regular message. + */ async tryCompressChat( prompt_id: string, force: boolean = false, signal?: AbortSignal, ): Promise { - const compressionService = new ChatCompressionService(); - - const { newHistory, info } = await compressionService.compress( - this.getChat(), + const info = await this.getChat().tryCompress( prompt_id, - force, this.config.getModel(), - this.config, - this.hasFailedCompressionAttempt, + force, signal, ); - - // Handle compression result if (info.compressionStatus === CompressionStatus.COMPRESSED) { - // Success: update chat with new compressed history - if (newHistory) { - const chatRecordingService = this.config.getChatRecordingService(); - chatRecordingService?.recordChatCompression({ - info, - compressedHistory: newHistory, - }); - - await this.startChat(newHistory); - // Compaction rewrites the prompt history: prior full-Read tool - // results may have been summarised away, but the FileReadCache - // still believes those reads are "in this conversation". A - // follow-up Read could then return the file_unchanged - // placeholder pointing at content the model can no longer - // retrieve from its own context. Clear the cache so post- - // compaction Reads re-emit the bytes. - debugLogger.debug('[FILE_READ_CACHE] clear after tryCompressChat'); - this.config.getFileReadCache().clear(); - uiTelemetryService.setLastPromptTokenCount(info.newTokenCount); - this.forceFullIdeContext = true; - } - } else if ( - info.compressionStatus === - CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT || - info.compressionStatus === - CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY - ) { - // Track failed attempts (only mark as failed if not forced) - if (!force) { - this.hasFailedCompressionAttempt = true; - } + const compressedHistory = this.getChat().getHistory(); + await this.startChat(compressedHistory); + // startChat() creates a new GeminiChat without touching FileReadCache, + // so prior read_file results that were summarised away would still + // resolve to the file_unchanged placeholder. Clear so post-compaction + // Reads re-emit bytes the model can no longer see in history. + debugLogger.debug('[FILE_READ_CACHE] clear after tryCompressChat'); + this.config.getFileReadCache().clear(); + this.getChat().setLastPromptTokenCount(info.newTokenCount); + // Re-send a full IDE context blob on the next regular message — + // compression dropped the previous context turn from history. + this.forceFullIdeContext = true; } - return info; } } diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index b11448973..067ec769b 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -4586,16 +4586,19 @@ describe('CoreToolScheduler activation wiring', () => { function buildSchedulerWithSkillManager(opts: { matchAndActivateByPaths: ReturnType; skillToolPresent: boolean; + toolResult?: ToolResult; }): { scheduler: CoreToolScheduler; onAllToolCallsComplete: ReturnType; } { const fsTool = new MockTool({ name: ToolNames.READ_FILE, - execute: vi.fn().mockResolvedValue({ - llmContent: 'file contents', - returnDisplay: 'file contents', - }), + execute: vi.fn().mockResolvedValue( + opts.toolResult ?? { + llmContent: 'file contents', + returnDisplay: 'file contents', + }, + ), }); const mockToolRegistry = { // Return the fs tool when asked by name; for SkillTool, mirror the @@ -4698,6 +4701,174 @@ describe('CoreToolScheduler activation wiring', () => { expect(responseText).toContain('now available via the Skill tool'); }); + it('includes concrete result paths in skill activation candidates', async () => { + const matchAndActivateByPaths = vi.fn().mockResolvedValue(['core-helper']); + const { scheduler } = buildSchedulerWithSkillManager({ + matchAndActivateByPaths, + skillToolPresent: true, + toolResult: { + llmContent: 'glob results', + returnDisplay: 'glob results', + resultFilePaths: [ + '/proj/packages/core/src/skills/target.ts', + '/proj/packages/cli/src/other.ts', + ], + }, + }); + + await scheduler.schedule( + [ + { + callId: '1', + name: ToolNames.GLOB, + args: { pattern: '**/*.ts' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + ], + new AbortController().signal, + ); + + expect(matchAndActivateByPaths).toHaveBeenCalledWith([ + '**/*.ts', + '/proj/packages/core/src/skills/target.ts', + '/proj/packages/cli/src/other.ts', + ]); + }); + + it('deduplicates overlapping input and result paths before activation', async () => { + const matchAndActivateByPaths = vi.fn().mockResolvedValue([]); + const { scheduler } = buildSchedulerWithSkillManager({ + matchAndActivateByPaths, + skillToolPresent: true, + toolResult: { + llmContent: 'file contents', + returnDisplay: 'file contents', + resultFilePaths: ['/proj/src/App.tsx'], + }, + }); + + await scheduler.schedule( + [ + { + callId: '1', + name: ToolNames.READ_FILE, + args: { file_path: '/proj/src/App.tsx' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + ], + new AbortController().signal, + ); + + expect(matchAndActivateByPaths).toHaveBeenCalledWith(['/proj/src/App.tsx']); + }); + + it('does not unescape concrete result paths before activation', async () => { + const matchAndActivateByPaths = vi.fn().mockResolvedValue([]); + const { scheduler } = buildSchedulerWithSkillManager({ + matchAndActivateByPaths, + skillToolPresent: true, + toolResult: { + llmContent: 'glob results', + returnDisplay: 'glob results', + resultFilePaths: ['/proj/src/foo\\ bar.ts'], + }, + }); + + await scheduler.schedule( + [ + { + callId: '1', + name: ToolNames.GLOB, + args: { pattern: '**/*.ts' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + ], + new AbortController().signal, + ); + + expect(matchAndActivateByPaths).toHaveBeenCalledWith([ + '**/*.ts', + '/proj/src/foo\\ bar.ts', + ]); + }); + + it('ignores result path metadata from non-filesystem tools', async () => { + const nonFsTool = new MockTool({ + name: 'web_fetch', + execute: vi.fn().mockResolvedValue({ + llmContent: 'web results', + returnDisplay: 'web results', + resultFilePaths: ['/proj/src/App.tsx'], + }), + }); + const mockToolRegistry = { + getTool: () => nonFsTool, + ensureTool: async () => nonFsTool, + getToolByName: () => nonFsTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => nonFsTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + const matchAndActivateByPaths = vi.fn().mockResolvedValue([]); + const scheduler = new CoreToolScheduler({ + config: { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.YOLO, + getPermissionsAllow: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'gemini', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { getProjectTempDir: () => '/tmp' }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getToolRegistry: () => mockToolRegistry, + getUseModelRouter: () => false, + getGeminiClient: () => null, + getChatRecordingService: () => undefined, + getMessageBus: vi.fn().mockReturnValue(undefined), + getDisableAllHooks: vi.fn().mockReturnValue(true), + getConditionalRulesRegistry: () => undefined, + getSkillManager: () => ({ matchAndActivateByPaths }), + } as unknown as Config, + onAllToolCallsComplete: vi.fn(), + onToolCallsUpdate: vi.fn(), + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + await scheduler.schedule( + [ + { + callId: '1', + name: 'web_fetch', + args: { url: 'https://example.com' }, + isClientInitiated: false, + prompt_id: 'p1', + }, + ], + new AbortController().signal, + ); + + expect(matchAndActivateByPaths).not.toHaveBeenCalled(); + }); + it('suppresses the activation reminder when SkillTool is absent (subagent without skill in toolslist)', async () => { const matchAndActivateByPaths = vi.fn().mockResolvedValue(['tsx-helper']); const { scheduler, onAllToolCallsComplete } = diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 4ddc8c7b1..ba042287d 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -210,6 +210,14 @@ const FS_PATH_TOOL_NAMES: ReadonlySet = new Set([ ToolNames.LSP, ]); +function canonicalToolName(toolName: string): string { + return (ToolNamesMigration as Record)[toolName] ?? toolName; +} + +function isFilesystemPathTool(toolName: string): boolean { + return FS_PATH_TOOL_NAMES.has(canonicalToolName(toolName)); +} + /** * Trim trailing forward / back slashes from a path-shaped string without * a regex. The regex form `s.replace(/[\\/]+$/, '')` is functionally @@ -291,8 +299,7 @@ export function extractToolFilePaths( // The tool registry resolves these at execution time, so a tool call // like `replace({ file_path: 'src/App.tsx' })` actually runs EditTool; // gating only on the canonical name closes the alias-bypass hole. - const canonical = - (ToolNamesMigration as Record)[toolName] ?? toolName; + const canonical = canonicalToolName(toolName); if (!FS_PATH_TOOL_NAMES.has(canonical)) { // Surface allowlist gaps at debug level when a non-FS tool's input // *looks* path-shaped: we silently skip path activation for it, but @@ -1930,8 +1937,14 @@ export class CoreToolScheduler { // FS_PATH_TOOL_NAMES) so MCP / non-FS tools that reuse those // parameter names with different semantics never enter the // activation pipeline. - const candidatePaths = extractToolFilePaths(toolName, toolInput).map( - (p) => unescapePath(p), + const inputPaths = extractToolFilePaths(toolName, toolInput); + const resultPaths = + isFilesystemPathTool(toolName) && + Array.isArray(toolResult.resultFilePaths) + ? toolResult.resultFilePaths + : []; + const candidatePaths = Array.from( + new Set([...inputPaths.map((p) => unescapePath(p)), ...resultPaths]), ); if (candidatePaths.length > 0) { diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 8a45bb5d4..521e37e06 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -22,6 +22,8 @@ import { StreamContentError } from './openaiContentGenerator/pipeline.js'; import type { Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; +import { CompressionStatus, type ChatCompressionInfo } from './turn.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -119,6 +121,13 @@ describe('GeminiChat', async () => { getTool: vi.fn(), }), getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), + getChatCompression: vi.fn().mockReturnValue(undefined), + getHookSystem: vi.fn().mockReturnValue(undefined), + getDebugLogger: vi + .fn() + .mockReturnValue({ debug: vi.fn(), warn: vi.fn(), info: vi.fn() }), + getApprovalMode: vi.fn().mockReturnValue('default'), + getFileReadCache: vi.fn().mockReturnValue({ clear: vi.fn() }), } as unknown as Config; // Disable 429 simulation for tests @@ -1019,6 +1028,223 @@ describe('GeminiChat', async () => { }); }); + describe('auto-compression integration', () => { + function makeStreamResponse(text = 'ok') { + return (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text }], role: 'model' }, + finishReason: 'STOP', + index: 0, + safetyRatings: [], + }, + ], + text: () => text, + } as unknown as GenerateContentResponse; + })(); + } + + it('releases the send-lock when auto-compression throws (no deadlock)', async () => { + const compressSpy = vi + .spyOn(ChatCompressionService.prototype, 'compress') + .mockRejectedValueOnce(new Error('compression API down')); + + // First send: compression rejects, error propagates to caller. The + // streamDoneResolver must run so this.sendPromise resolves; otherwise + // every subsequent send blocks forever. + await expect( + chat.sendMessageStream( + 'test-model', + { message: 'first' }, + 'prompt-id-deadlock-1', + ), + ).rejects.toThrow('compression API down'); + + // Second send: compress returns NOOP, request goes through. If the + // lock leaked, this await would never resolve. + compressSpy.mockResolvedValueOnce({ + newHistory: null, + info: { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }, + }); + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + makeStreamResponse('second response'), + ); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'second' }, + 'prompt-id-deadlock-2', + ); + for await (const _ of stream) { + /* consume */ + } + + expect(compressSpy).toHaveBeenCalledTimes(2); + }); + + it('seeds inherited token count via setLastPromptTokenCount', async () => { + const subagentChat = new GeminiChat(mockConfig, config, [ + { role: 'user', parts: [{ text: 'inherited' }] }, + { role: 'model', parts: [{ text: 'inherited reply' }] }, + ]); + subagentChat.setLastPromptTokenCount(123_456); + expect(subagentChat.getLastPromptTokenCount()).toBe(123_456); + + // The compression service receives the seeded count, so the threshold + // check sees the inherited size — not the constructor default of 0. + const compressSpy = vi + .spyOn(ChatCompressionService.prototype, 'compress') + .mockResolvedValue({ + newHistory: null, + info: { + originalTokenCount: 123_456, + newTokenCount: 123_456, + compressionStatus: CompressionStatus.NOOP, + }, + }); + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + makeStreamResponse(), + ); + + const stream = await subagentChat.sendMessageStream( + 'test-model', + { message: 'go' }, + 'prompt-id-seed', + ); + for await (const _ of stream) { + /* consume */ + } + + expect(compressSpy).toHaveBeenCalledTimes(1); + expect(compressSpy.mock.calls[0][1].originalTokenCount).toBe(123_456); + }); + + it('yields a COMPRESSED stream event as the first event after auto-compression succeeds', async () => { + const compressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'summary' }] }, + { role: 'model', parts: [{ text: 'ok' }] }, + ]; + vi.spyOn( + ChatCompressionService.prototype, + 'compress', + ).mockResolvedValueOnce({ + newHistory: compressedHistory, + info: { + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, + }, + }); + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + makeStreamResponse('answer'), + ); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'go' }, + 'prompt-id-yield-compressed', + ); + const events: Array<{ type: StreamEventType }> = []; + for await (const event of stream) { + events.push(event as { type: StreamEventType }); + } + + expect(events.length).toBeGreaterThan(0); + expect(events[0].type).toBe(StreamEventType.COMPRESSED); + expect( + (events[0] as { type: StreamEventType; info: ChatCompressionInfo }).info + .compressionStatus, + ).toBe(CompressionStatus.COMPRESSED); + expect( + (events[0] as { type: StreamEventType; info: ChatCompressionInfo }).info + .newTokenCount, + ).toBe(200); + }); + + it('clears hasFailedCompressionAttempt after a forced successful compression', async () => { + const compressSpy = vi.spyOn( + ChatCompressionService.prototype, + 'compress', + ); + + // Step 1: auto-compression fails — latch is set on the chat. + compressSpy.mockResolvedValueOnce({ + newHistory: null, + info: { + originalTokenCount: 100_000, + newTokenCount: 100_000, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }); + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + makeStreamResponse(), + ); + const stream1 = await chat.sendMessageStream( + 'test-model', + { message: 'first' }, + 'prompt-latch-1', + ); + for await (const _ of stream1) { + /* consume */ + } + // Latch passed to service was false on this attempt; service marks it + // failed and tryCompress flips the chat's flag to true. + expect(compressSpy.mock.calls[0][1].hasFailedCompressionAttempt).toBe( + false, + ); + + // Step 2: a forced /compress succeeds. After this, the latch must + // be cleared so future auto-compressions are not suppressed. + compressSpy.mockResolvedValueOnce({ + newHistory: [ + { role: 'user', parts: [{ text: 'summary' }] }, + { role: 'model', parts: [{ text: 'ack' }] }, + ], + info: { + originalTokenCount: 100_000, + newTokenCount: 30_000, + compressionStatus: CompressionStatus.COMPRESSED, + }, + }); + await chat.tryCompress('prompt-latch-force', 'test-model', true); + // tryCompress was called with force=true, so the service got latch=true + // (the gate is `hasFailedCompressionAttempt && !force`, force overrides). + expect(compressSpy.mock.calls[1][1].hasFailedCompressionAttempt).toBe( + true, + ); + + // Step 3: next auto-compression sees the cleared latch. + compressSpy.mockResolvedValueOnce({ + newHistory: null, + info: { + originalTokenCount: 30_000, + newTokenCount: 30_000, + compressionStatus: CompressionStatus.NOOP, + }, + }); + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + makeStreamResponse(), + ); + const stream2 = await chat.sendMessageStream( + 'test-model', + { message: 'second' }, + 'prompt-latch-2', + ); + for await (const _ of stream2) { + /* consume */ + } + expect(compressSpy.mock.calls[2][1].hasFailedCompressionAttempt).toBe( + false, + ); + }); + }); + describe('addHistory', () => { it('should add a new content item to the history', () => { const newContent: Content = { @@ -2425,4 +2651,132 @@ describe('GeminiChat', async () => { expect(mergedText).toBe('BCD'); }); }); + + // Compression logic is tested in chatCompressionService.test.ts; this + // suite covers per-chat state on GeminiChat: hasFailedCompressionAttempt + // stickiness, token-count mutation, history replacement, and conditional + // telemetry mirroring. + describe('tryCompress (per-chat state)', () => { + const userMsg = (text: string) => ({ + role: 'user' as const, + parts: [{ text }], + }); + const modelMsg = (text: string) => ({ + role: 'model' as const, + parts: [{ text }], + }); + + /** + * Mock a successful compression: the service returns COMPRESSED with a + * fresh history. We don't go through the real + * `config.getContentGenerator().generateContent` path here — the service + * is mocked at the boundary. + */ + function mockCompressionService( + result: 'compressed' | 'failed-inflated' | 'noop', + ) { + const compressSpy = vi.spyOn( + ChatCompressionService.prototype, + 'compress', + ); + if (result === 'compressed') { + compressSpy.mockResolvedValue({ + newHistory: [userMsg('summary'), modelMsg('ok'), userMsg('latest')], + info: { + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, + }, + }); + } else if (result === 'failed-inflated') { + compressSpy.mockResolvedValue({ + newHistory: null, + info: { + originalTokenCount: 1000, + newTokenCount: 1100, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }); + } else { + compressSpy.mockResolvedValue({ + newHistory: null, + info: { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }, + }); + } + return compressSpy; + } + + it('replaces history and updates per-chat lastPromptTokenCount on COMPRESSED', async () => { + mockCompressionService('compressed'); + chat.setHistory([userMsg('a'), modelMsg('b'), userMsg('c')]); + + const info = await chat.tryCompress('p1', 'm1'); + + expect(info.compressionStatus).toBe(CompressionStatus.COMPRESSED); + expect(chat.getHistory()).toHaveLength(3); + expect(chat.getHistory()[0]).toEqual(userMsg('summary')); + expect(chat.getLastPromptTokenCount()).toBe(200); + }); + + it('mirrors lastPromptTokenCount to the global telemetry only when wired', async () => { + mockCompressionService('compressed'); + // chat under test was constructed with telemetryService=uiTelemetryService. + await chat.tryCompress('p2', 'm1'); + expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledWith( + 200, + ); + + // A subagent-style chat with no telemetryService must NOT touch the + // global singleton (per the constructor docstring; per-chat counter + // still updates). + const subagentChat = new GeminiChat(mockConfig, config, []); + vi.mocked(uiTelemetryService.setLastPromptTokenCount).mockClear(); + mockCompressionService('compressed'); + const info = await subagentChat.tryCompress('p3', 'm1'); + expect(info.compressionStatus).toBe(CompressionStatus.COMPRESSED); + expect(subagentChat.getLastPromptTokenCount()).toBe(200); + expect(uiTelemetryService.setLastPromptTokenCount).not.toHaveBeenCalled(); + }); + + it('marks hasFailedCompressionAttempt and suppresses subsequent unforced auto-compactions', async () => { + const compressSpy = mockCompressionService('failed-inflated'); + + const first = await chat.tryCompress('p1', 'm1'); + expect(first.compressionStatus).toBe( + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + ); + expect(compressSpy).toHaveBeenCalledTimes(1); + + // The next unforced call should reach the service with + // hasFailedCompressionAttempt=true; the service's threshold check then + // returns NOOP. The important thing here is that GeminiChat actually + // forwards the sticky flag. + compressSpy.mockClear(); + compressSpy.mockResolvedValue({ + newHistory: null, + info: { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }, + }); + await chat.tryCompress('p2', 'm1'); + expect(compressSpy).toHaveBeenCalledTimes(1); + expect(compressSpy.mock.calls[0][1].hasFailedCompressionAttempt).toBe( + true, + ); + }); + + it('forwards force=true to the compression service', async () => { + const compressSpy = mockCompressionService('compressed'); + + await chat.tryCompress('p1', 'm1', true); + expect(compressSpy.mock.calls[0][1].force).toBe(true); + }); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 804c143d8..6249a56db 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -31,11 +31,13 @@ import { logContentRetryFailure, } from '../telemetry/loggers.js'; import { type ChatRecordingService } from '../services/chatRecordingService.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; import { ContentRetryEvent, ContentRetryFailureEvent, } from '../telemetry/types.js'; import type { UiTelemetryService } from '../telemetry/uiTelemetry.js'; +import { type ChatCompressionInfo, CompressionStatus } from './turn.js'; const debugLogger = createDebugLogger('QWEN_CODE_CHAT'); @@ -45,6 +47,11 @@ export enum StreamEventType { /** A signal that a retry is about to happen. The UI should discard any partial * content from the attempt that just failed. */ RETRY = 'retry', + /** Emitted once at the start of the stream when an automatic compression + * pass succeeded. Carries the compression result so callers (the main + * agent UI, subagent loop) can surface it without each call site running + * its own compaction step. */ + COMPRESSED = 'compressed', } export type StreamEvent = @@ -56,7 +63,8 @@ export type StreamEvent = * fresh restart (escalation). The UI should keep the accumulated text * buffer so the continuation appends to it. */ isContinuation?: boolean; - }; + } + | { type: StreamEventType.COMPRESSED; info: ChatCompressionInfo }; /** * Options for retrying due to invalid content from the model. @@ -299,6 +307,22 @@ export class GeminiChat { // model. private sendPromise: Promise = Promise.resolve(); + /** + * Per-chat last-prompt-token-count, populated from `usageMetadata` on each + * model response. Used by the compaction threshold check so that subagents + * (which intentionally don't write to the global telemetry singleton) can + * still make compaction decisions based on their *own* context size. + */ + private lastPromptTokenCount = 0; + + /** + * Per-chat sticky flag. After an unforced compression attempt fails (empty + * summary or inflated token count), automatic compaction is suppressed + * for the remainder of this chat to avoid burning compression API calls + * in a loop. Manual `/compress` still works (it passes `force=true`). + */ + private hasFailedCompressionAttempt = false; + /** * Creates a new GeminiChat instance. * @@ -321,6 +345,96 @@ export class GeminiChat { validateHistory(history); } + /** + * Most recent prompt-token count reported by the model for *this* chat, + * mirroring the value in {@link UiTelemetryService} for the main session. + * Subagent chats have no telemetry service wired but still need a per-chat + * count for compaction decisions, so this is always populated regardless + * of whether the global telemetry is updated. + */ + getLastPromptTokenCount(): number { + return this.lastPromptTokenCount; + } + + /** + * Seed the last-prompt-token-count for chats created with inherited + * history (forks, subagents, speculation). Without this, the auto-compress + * threshold check sees `0` and refuses to compress — so the first API call + * can 400 from oversized history. Callers pass the parent chat's + * `getLastPromptTokenCount()` here. + */ + setLastPromptTokenCount(count: number): void { + this.lastPromptTokenCount = count; + } + + /** + * Attempt to compress this chat's history. + * + * Returns the compression info regardless of outcome. On a successful + * compaction (`COMPRESSED`), this method has already mutated the chat's + * history, recorded the event to `chatRecordingService` (if wired), and + * updated both the per-chat token count and (when wired) the global + * telemetry singleton. + */ + async tryCompress( + promptId: string, + model: string, + force = false, + signal?: AbortSignal, + ): Promise { + const service = new ChatCompressionService(); + const { newHistory, info } = await service.compress(this, { + promptId, + force, + model, + config: this.config, + hasFailedCompressionAttempt: this.hasFailedCompressionAttempt, + originalTokenCount: this.lastPromptTokenCount, + signal, + }); + + if (info.compressionStatus === CompressionStatus.COMPRESSED && newHistory) { + this.chatRecordingService?.recordChatCompression({ + info, + compressedHistory: newHistory, + }); + // Auto-compaction replaces history in place — no env-context refresh + // here. Manual /compress goes through GeminiClient.tryCompressChat, + // which calls startChat() to re-prepend a fresh env snapshot. See + // GeminiClient.sendMessageStream for the rationale behind the split. + this.setHistory(newHistory); + // Compaction summarises away prior full-Read tool results, but the + // FileReadCache still treats those reads as "in this conversation". + // A follow-up Read could then return the file_unchanged placeholder + // pointing at content the model can no longer retrieve from history. + debugLogger.debug('[FILE_READ_CACHE] clear after auto tryCompress'); + this.config.getFileReadCache().clear(); + this.lastPromptTokenCount = info.newTokenCount; + // Mirror to the global singleton only when wired (main session). + // Subagents pass `telemetryService=undefined` to keep their context + // usage out of the main agent's UI counters. + this.telemetryService?.setLastPromptTokenCount(info.newTokenCount); + // Re-enable auto-compaction so a forced /compress recovers a chat + // that an earlier auto-attempt latched off. + this.hasFailedCompressionAttempt = false; + } else if ( + info.compressionStatus === + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT || + info.compressionStatus === + CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY || + info.compressionStatus === + CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR + ) { + // Track failed attempts (only mark as failed if not forced) so we + // stop spending compression-API calls on a chat that can't shrink. + if (!force) { + this.hasFailedCompressionAttempt = true; + } + } + + return info; + } + setSystemInstruction(sysInstr: string) { this.generationConfig.systemInstruction = sysInstr; } @@ -360,6 +474,23 @@ export class GeminiChat { }); this.sendPromise = streamDonePromise; + // The send-lock above is held but the generator's `finally` (which + // resolves it) has not run yet — if `tryCompress` throws, we must + // release the lock here or subsequent sends will block forever at + // `await this.sendPromise`. + let compressionInfo: ChatCompressionInfo; + try { + compressionInfo = await this.tryCompress( + prompt_id, + model, + false, + params.config?.abortSignal, + ); + } catch (error) { + streamDoneResolver!(); + throw error; + } + const userContent = createUserContent(params.message); // Add user content to history ONCE before any attempts. @@ -370,6 +501,20 @@ export class GeminiChat { const self = this; return (async function* () { try { + // Surface a successful auto-compression to the caller as the first + // event in the stream. Failed/skipped compaction attempts are silent. + // Must be inside the try so that a consumer abandoning the stream + // immediately after this event still triggers the finally below; + // otherwise `streamDoneResolver` never fires and the next send hangs. + if ( + compressionInfo.compressionStatus === CompressionStatus.COMPRESSED + ) { + yield { + type: StreamEventType.COMPRESSED, + info: compressionInfo, + }; + } + let lastError: unknown = new Error('Request failed after all retries.'); let rateLimitRetryCount = 0; let invalidStreamRetryCount = 0; @@ -890,8 +1035,14 @@ export class GeminiChat { // Some providers omit total_tokens or return 0 in streaming usage chunks. const lastPromptTokenCount = usageMetadata.totalTokenCount || usageMetadata.promptTokenCount; - if (lastPromptTokenCount && this.telemetryService) { - this.telemetryService.setLastPromptTokenCount(lastPromptTokenCount); + if (lastPromptTokenCount) { + // Always update the per-chat counter so this chat (including + // subagents) can make its own compaction decisions. + this.lastPromptTokenCount = lastPromptTokenCount; + // Mirror to the global telemetry only when wired — subagents + // pass `telemetryService=undefined` to keep their context usage + // out of the main session's UI counters. + this.telemetryService?.setLastPromptTokenCount(lastPromptTokenCount); } if (usageMetadata.cachedContentTokenCount && this.telemetryService) { this.telemetryService.setLastCachedContentTokenCount( diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 92c86f9ed..6626f56c5 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -9,7 +9,7 @@ import type { ServerGeminiToolCallRequestEvent, ServerGeminiErrorEvent, } from './turn.js'; -import { Turn, GeminiEventType } from './turn.js'; +import { CompressionStatus, Turn, GeminiEventType } from './turn.js'; import type { GenerateContentResponse, Part, Content } from '@google/genai'; import { reportError } from '../utils/errorReporting.js'; import type { GeminiChat } from './geminiChat.js'; @@ -845,6 +845,38 @@ describe('Turn', () => { { type: GeminiEventType.Content, value: 'Success' }, ]); }); + + it('bridges a compressed stream event to a ChatCompressed event', async () => { + const compressionInfo = { + originalTokenCount: 1000, + newTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, + }; + const mockResponseStream = (async function* () { + yield { type: StreamEventType.COMPRESSED, info: compressionInfo }; + yield { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'after' }] } }], + }, + }; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run( + 'test-model', + [], + new AbortController().signal, + )) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.ChatCompressed, value: compressionInfo }, + { type: GeminiEventType.Content, value: 'after' }, + ]); + }); }); describe('getDebugResponses', () => { diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index b0971d0b6..34f78bcd0 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -306,6 +306,19 @@ export class Turn { continue; // Skip to the next event in the stream } + // Surface auto-compaction that fired inside chat.sendMessageStream + // as the top-level ChatCompressed event so existing UI handlers stay + // connected. This bridge is the primary path for auto-compaction + // events; manual /compress emits its own ChatCompressed in + // GeminiClient.tryCompressChat. + if (streamEvent.type === 'compressed') { + yield { + type: GeminiEventType.ChatCompressed, + value: streamEvent.info, + }; + continue; + } + // Assuming other events are chunks with a `value` property const resp = streamEvent.value as GenerateContentResponse; if (!resp) continue; // Skip if there's no response body diff --git a/packages/core/src/memory/dream.ts b/packages/core/src/memory/dream.ts index 67a4a6af6..a76e944a2 100644 --- a/packages/core/src/memory/dream.ts +++ b/packages/core/src/memory/dream.ts @@ -23,28 +23,16 @@ export interface AutoMemoryDreamResult { systemMessage?: string; } -async function bumpMetadata(projectRoot: string, now: Date): Promise { - const metadataPath = getAutoMemoryMetadataPath(projectRoot); - try { - const content = await fs.readFile(metadataPath, 'utf-8'); - const metadata = JSON.parse(content) as AutoMemoryMetadata; - metadata.updatedAt = now.toISOString(); - metadata.lastDreamAt = now.toISOString(); - await fs.writeFile( - metadataPath, - `${JSON.stringify(metadata, null, 2)}\n`, - 'utf-8', - ); - } catch { - // Best-effort metadata bump. - } -} - async function runDreamByAgent( projectRoot: string, config: Config, + abortSignal?: AbortSignal, ): Promise { - const result = await planManagedAutoMemoryDreamByAgent(config, projectRoot); + const result = await planManagedAutoMemoryDreamByAgent( + config, + projectRoot, + abortSignal, + ); // Infer which topics were touched from the file paths const touchedTopics = new Set(); @@ -72,6 +60,7 @@ export async function runManagedAutoMemoryDream( projectRoot: string, now = new Date(), config?: Config, + abortSignal?: AbortSignal, ): Promise { await ensureAutoMemoryScaffold(projectRoot, now); const t0 = Date.now(); @@ -82,14 +71,26 @@ export async function runManagedAutoMemoryDream( ); } - const agentResult = await runDreamByAgent(projectRoot, config); + const agentResult = await runDreamByAgent(projectRoot, config, abortSignal); + // Cancel-aware ordering: + // 1. If aborted before this point, return the agent's partial result + // WITHOUT rebuilding the index — index rebuild can be expensive + // and re-running a cancelled dream cycle next time will rebuild + // against the latest topic files anyway. + // 2. If still alive, rebuild the index (informational, powers + // recall) — but only when topics actually changed. + // Scheduler-gating metadata (`lastDreamAt`, `lastDreamSessionId`, + // `lastDreamTouchedTopics`, `lastDreamStatus`) is intentionally NOT + // written here — `MemoryManager.runDream` owns the atomic + // status-flip + metadata-write sequence to close the cancel race + // window where a writeFile finishing concurrently with a cancel + // could persist gating metadata for a record the manager is about + // to mark `'cancelled'`. + if (abortSignal?.aborted) return agentResult; if (agentResult.touchedTopics.length > 0) { - await bumpMetadata(projectRoot, now); await rebuildManagedAutoMemoryIndex(projectRoot); } - await updateDreamMetadataResult(projectRoot, now, agentResult.touchedTopics); - logMemoryDream( config, new MemoryDreamEvent({ diff --git a/packages/core/src/memory/dreamAgentPlanner.test.ts b/packages/core/src/memory/dreamAgentPlanner.test.ts index a84aac180..4a0f48583 100644 --- a/packages/core/src/memory/dreamAgentPlanner.test.ts +++ b/packages/core/src/memory/dreamAgentPlanner.test.ts @@ -136,16 +136,25 @@ describe('dreamAgentPlanner', () => { ).rejects.toThrow('Model timed out'); }); - it('returns cancelled result without throwing', async () => { + it('throws when the agent terminates as cancelled', async () => { + // runForkedAgent maps AgentTerminateMode.CANCELLED to a resolved + // `{status: 'cancelled'}` rather than a rejection. Without + // re-throwing here, `runDreamByAgent` and downstream callers would + // treat an aborted run as a normal completion — bumping + // `lastDreamAt` metadata and overwriting a user-cancelled task + // record with `'completed'`. The throw lets the manager's existing + // catch path (which checks `signal.aborted && status === 'cancelled'`) + // do the right thing. const mockResult: ForkedAgentResult = { status: 'cancelled', + terminateReason: 'CANCELLED', filesTouched: [], }; vi.mocked(runForkedAgent).mockResolvedValue(mockResult); - const result = await planManagedAutoMemoryDreamByAgent(config, projectRoot); - expect(result.status).toBe('cancelled'); - expect(result.filesTouched).toHaveLength(0); + await expect( + planManagedAutoMemoryDreamByAgent(config, projectRoot), + ).rejects.toThrow(/cancelled/i); }); }); diff --git a/packages/core/src/memory/dreamAgentPlanner.ts b/packages/core/src/memory/dreamAgentPlanner.ts index 790d91cc4..5496eb1e6 100644 --- a/packages/core/src/memory/dreamAgentPlanner.ts +++ b/packages/core/src/memory/dreamAgentPlanner.ts @@ -227,6 +227,7 @@ export function buildConsolidationTaskPrompt( export async function planManagedAutoMemoryDreamByAgent( config: Config, projectRoot: string, + abortSignal?: AbortSignal, ): Promise { const memoryRoot = getAutoMemoryRoot(projectRoot); const transcriptDir = getTranscriptDir(projectRoot); @@ -247,11 +248,23 @@ export async function planManagedAutoMemoryDreamByAgent( ToolNames.WRITE_FILE, ToolNames.EDIT, ], + abortSignal, }); if (result.status === 'failed') { throw new Error(result.terminateReason || 'Dream agent failed'); } + if (result.status === 'cancelled') { + // runForkedAgent maps AgentTerminateMode.CANCELLED → status 'cancelled' + // (resolves rather than rejects). Throw here so callers up the stack + // unwind via their catch paths instead of silently treating an + // aborted dream as a normal completion (which would overwrite the + // user-cancelled record with 'completed' + bump dream metadata). + throw new Error( + result.terminateReason || 'Dream agent cancelled before completion', + ); + } + return result; } diff --git a/packages/core/src/memory/manager.test.ts b/packages/core/src/memory/manager.test.ts index 4860bdaae..74a20a361 100644 --- a/packages/core/src/memory/manager.test.ts +++ b/packages/core/src/memory/manager.test.ts @@ -260,6 +260,77 @@ describe('MemoryManager', () => { }); }); + // ─── subscribe() filter ────────────────────────────────────────────────── + + describe('subscribe() taskType filter', () => { + // The filter exists so high-frequency consumers (the bg-tasks UI + // hook, only rendering dream entries) can skip the per-extract + // notify entirely. Pin the routing both ways: filtered subscribers + // must NOT fire on unrelated transitions, and unfiltered + // subscribers must continue to fire on everything. + it('routes notifies to type-filtered subscribers only when taskType matches', async () => { + vi.mocked(runAutoMemoryExtract).mockResolvedValue({ + touchedTopics: [], + cursor: { sessionId: 'sess', updatedAt: new Date().toISOString() }, + }); + const mgr = new MemoryManager(); + const dreamFilteredFires = vi.fn(); + const extractFilteredFires = vi.fn(); + const unfilteredFires = vi.fn(); + mgr.subscribe(dreamFilteredFires, { taskType: 'dream' }); + mgr.subscribe(extractFilteredFires, { taskType: 'extract' }); + mgr.subscribe(unfilteredFires); + + await mgr.scheduleExtract({ + projectRoot: '/p', + sessionId: 'sess', + history: [{ role: 'user', parts: [{ text: 'hi' }] }], + }); + await mgr.drain(); + + // Extract scheduling fires storeWith (1) + completion update (1) = 2 notifies. + // Dream-filtered subscriber must NOT see them. + expect(dreamFilteredFires).not.toHaveBeenCalled(); + // Both extract-filtered and unfiltered subscribers must see them. + expect(extractFilteredFires.mock.calls.length).toBeGreaterThanOrEqual(1); + expect(unfilteredFires.mock.calls.length).toBeGreaterThanOrEqual(1); + }); + + it('returns an unsubscribe function that drops the filtered listener even when later notifies fire', async () => { + // Verify the unsubscribe actually severs the listener — the + // earlier version of this test only asserted "not called yet" + // without ever firing a notify, so the listener could have + // remained attached and the test would still pass. + vi.mocked(runAutoMemoryExtract).mockResolvedValue({ + touchedTopics: [], + cursor: { sessionId: 'sess', updatedAt: new Date().toISOString() }, + }); + const mgr = new MemoryManager(); + const fires = vi.fn(); + const unsubscribe = mgr.subscribe(fires, { taskType: 'extract' }); + + // First extract should fire the listener (storeWith + completion update). + await mgr.scheduleExtract({ + projectRoot: '/p', + sessionId: 'sess', + history: [{ role: 'user', parts: [{ text: 'hi' }] }], + }); + await mgr.drain(); + const firesBeforeUnsubscribe = fires.mock.calls.length; + expect(firesBeforeUnsubscribe).toBeGreaterThanOrEqual(1); + + // After unsubscribe, a second extract must not increment the count. + unsubscribe(); + await mgr.scheduleExtract({ + projectRoot: '/p', + sessionId: 'sess-2', + history: [{ role: 'user', parts: [{ text: 'hi again' }] }], + }); + await mgr.drain(); + expect(fires.mock.calls.length).toBe(firesBeforeUnsubscribe); + }); + }); + // ─── scheduleDream() ───────────────────────────────────────────────────── describe('scheduleDream()', () => { @@ -314,15 +385,35 @@ describe('MemoryManager', () => { expect(result).toEqual({ status: 'skipped', skippedReason: 'disabled' }); }); + it('skips when params.config is omitted entirely', async () => { + // Without config, runManagedAutoMemoryDream throws — surfacing + // a noisy failed entry in the bg-tasks dialog. The early skip + // converts the omitted-config case to the same disabled-skip + // path so callers can't accidentally produce visible failures + // by leaving config out (the type allows it for test ergonomics). + const mgr = new MemoryManager(); + const result = await mgr.scheduleDream({ + projectRoot, + sessionId: 'sess-no-config', + // config intentionally omitted + now: new Date('2026-04-02T10:00:00.000Z'), + }); + expect(result).toEqual({ status: 'skipped', skippedReason: 'disabled' }); + // Crucially — no record was stored for this skip. + expect(mgr.listTasksByType('dream', projectRoot)).toEqual([]); + }); + it('skips when called again in the same session', async () => { const scanner = vi .fn() .mockResolvedValue(['sess-0', 'sess-1', 'sess-2', 'sess-3', 'sess-4']); const mgr = new MemoryManager(scanner); + const config = makeMockConfig(); const first = await mgr.scheduleDream({ projectRoot, sessionId: 'sess-x', + config, now: new Date('2026-04-01T10:00:00.000Z'), minHoursBetweenDreams: 0, minSessionsBetweenDreams: 1, @@ -333,6 +424,7 @@ describe('MemoryManager', () => { const second = await mgr.scheduleDream({ projectRoot, sessionId: 'sess-x', + config, now: new Date('2026-04-01T11:00:00.000Z'), minHoursBetweenDreams: 0, minSessionsBetweenDreams: 1, @@ -365,6 +457,7 @@ describe('MemoryManager', () => { const result = await mgr.scheduleDream({ projectRoot, sessionId: 'sess-new', + config: makeMockConfig(), now: new Date('2026-04-01T10:00:00.000Z'), minHoursBetweenDreams: 24, minSessionsBetweenDreams: 1, @@ -380,6 +473,7 @@ describe('MemoryManager', () => { const result = await mgr.scheduleDream({ projectRoot, sessionId: 'sess-new', + config: makeMockConfig(), now: new Date('2026-04-01T10:00:00.000Z'), minHoursBetweenDreams: 0, minSessionsBetweenDreams: 5, @@ -401,6 +495,7 @@ describe('MemoryManager', () => { const result = await mgr.scheduleDream({ projectRoot, sessionId: 'sess-x', + config: makeMockConfig(), now: new Date('2026-04-01T10:00:00.000Z'), minHoursBetweenDreams: 0, minSessionsBetweenDreams: 3, @@ -425,6 +520,198 @@ describe('MemoryManager', () => { }); }); + // ─── cancelTask() ──────────────────────────────────────────────────────── + + describe('cancelTask()', () => { + let tempDir: string; + let projectRoot: string; + + beforeEach(async () => { + vi.resetAllMocks(); + process.env['QWEN_CODE_MEMORY_LOCAL'] = '1'; + clearAutoMemoryRootCache(); + tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'mgr-cancel-')); + projectRoot = path.join(tempDir, 'project'); + await fs.mkdir(projectRoot, { recursive: true }); + await ensureAutoMemoryScaffold( + projectRoot, + new Date('2026-04-01T00:00:00.000Z'), + ); + }); + + afterEach(async () => { + delete process.env['QWEN_CODE_MEMORY_LOCAL']; + clearAutoMemoryRootCache(); + await fs.rm(tempDir, { recursive: true, force: true }); + }); + + it('aborts the dream fork agent and marks the record cancelled', async () => { + // The fork's abort signal is captured here so the test can assert + // both the status flip AND the actual signal propagation — only + // the latter guarantees runForkedAgent will unwind. + let capturedSignal: AbortSignal | undefined; + let resolveDreamStarted!: () => void; + const dreamStarted = new Promise((r) => { + resolveDreamStarted = r; + }); + vi.mocked(runManagedAutoMemoryDream).mockImplementation( + async (_root, _now, _config, signal) => { + capturedSignal = signal; + resolveDreamStarted(); + // Simulate a long-running dream that respects the signal. + await new Promise((_, reject) => { + signal?.addEventListener('abort', () => + reject(new Error('aborted')), + ); + }); + return { + touchedTopics: [], + dedupedEntries: 0, + systemMessage: undefined, + }; + }, + ); + + const mgr = new MemoryManager(async () => [ + 'sess-0', + 'sess-1', + 'sess-2', + 'sess-3', + 'sess-4', + ]); + const config = makeMockConfig(); + const result = await mgr.scheduleDream({ + projectRoot, + sessionId: 'sess-x', + config, + now: new Date('2026-04-02T10:00:00.000Z'), + }); + expect(result.status).toBe('scheduled'); + const taskId = result.taskId!; + + // Wait for the fork to actually enter — scheduleDream returns + // before lock acquisition + the fork-agent invocation actually + // run. Cancelling before the fork enters would race the abort + // signal capture and produce a flaky undefined. + await dreamStarted; + + // Cancel must succeed and synchronously flip status; the fork's + // unwind happens later via the abort signal. + const cancelled = mgr.cancelTask(taskId); + expect(cancelled).toBe(true); + expect(mgr.getTask(taskId)?.status).toBe('cancelled'); + expect(capturedSignal?.aborted).toBe(true); + + // Drain so the fork-agent rejection lands and runDream's catch + // path runs — the user-cancel guard must NOT overwrite to + // 'failed'. (Without the guard, the rejected promise sets the + // record to failed with error="aborted".) + await mgr.drain({ timeoutMs: 1000 }); + expect(mgr.getTask(taskId)?.status).toBe('cancelled'); + }); + + it('keeps the record cancelled even when runManagedAutoMemoryDream resolves successfully after abort', async () => { + // The realistic abort path: runForkedAgent maps + // AgentTerminateMode.CANCELLED to a resolved `{status: 'cancelled'}` + // rather than a rejection. dreamAgentPlanner is supposed to + // rethrow that case, but the manager carries an additional + // signal.aborted check after the await as defense in depth. + // This test simulates the "resolved despite cancel" scenario by + // having the mock RESOLVE on abort instead of rejecting — without + // the guard, runDream's success path would overwrite the + // user-cancelled record to 'completed' and bump dream metadata + // for an aborted run. + let resolveStarted!: () => void; + const started = new Promise((r) => { + resolveStarted = r; + }); + vi.mocked(runManagedAutoMemoryDream).mockImplementation( + async (_root, _now, _config, signal) => { + resolveStarted(); + await new Promise((resolve) => { + signal?.addEventListener('abort', () => resolve()); + }); + return { + touchedTopics: ['user', 'project'], + dedupedEntries: 0, + systemMessage: 'Managed auto-memory dream completed.', + }; + }, + ); + + const mgr = new MemoryManager(async () => [ + 'sess-0', + 'sess-1', + 'sess-2', + 'sess-3', + 'sess-4', + ]); + const config = makeMockConfig(); + const result = await mgr.scheduleDream({ + projectRoot, + sessionId: 'sess-x', + config, + now: new Date('2026-04-02T10:00:00.000Z'), + }); + const taskId = result.taskId!; + await started; + mgr.cancelTask(taskId); + await mgr.drain({ timeoutMs: 1000 }); + + expect(mgr.getTask(taskId)?.status).toBe('cancelled'); + // Metadata write must NOT have happened — lastDreamAt should + // still be the scaffold's initial value, not the cancelled-run's + // `now`. (Bumping it would suppress the next legitimate dream.) + const metaRaw = await fs.readFile( + getAutoMemoryMetadataPath(projectRoot), + 'utf-8', + ); + const meta = JSON.parse(metaRaw) as { + lastDreamAt?: string; + lastDreamSessionId?: string; + }; + expect(meta.lastDreamAt).not.toBe('2026-04-02T10:00:00.000Z'); + expect(meta.lastDreamSessionId).not.toBe('sess-x'); + }); + + it('returns false for unknown task ids', async () => { + const mgr = new MemoryManager(); + expect(mgr.cancelTask('does-not-exist')).toBe(false); + }); + + it('returns false for an already-completed dream', async () => { + // The dream's natural completion path runs first, marks the + // record terminal; a subsequent cancel attempt must no-op rather + // than overwrite the recorded outcome (would erase touchedTopics + // metadata the user just saw via memory_saved toast). + vi.mocked(runManagedAutoMemoryDream).mockResolvedValue({ + touchedTopics: [], + dedupedEntries: 0, + systemMessage: undefined, + }); + const mgr = new MemoryManager(async () => [ + 'sess-0', + 'sess-1', + 'sess-2', + 'sess-3', + 'sess-4', + ]); + const config = makeMockConfig(); + const result = await mgr.scheduleDream({ + projectRoot, + sessionId: 'sess-x', + config, + now: new Date('2026-04-02T10:00:00.000Z'), + }); + const taskId = result.taskId!; + // Drain so the dream completes naturally. + await mgr.drain({ timeoutMs: 1000 }); + expect(mgr.getTask(taskId)?.status).toBe('completed'); + expect(mgr.cancelTask(taskId)).toBe(false); + expect(mgr.getTask(taskId)?.status).toBe('completed'); + }); + }); + // ─── resetExtractStateForTests() ───────────────────────────────────────── describe('resetExtractStateForTests()', () => { diff --git a/packages/core/src/memory/manager.ts b/packages/core/src/memory/manager.ts index 86924893c..091edf77e 100644 --- a/packages/core/src/memory/manager.ts +++ b/packages/core/src/memory/manager.ts @@ -39,7 +39,13 @@ import { randomUUID } from 'node:crypto'; import type { Content, Part } from '@google/genai'; import type { Config } from '../config/config.js'; import { Storage } from '../config/storage.js'; -import { logMemoryExtract, MemoryExtractEvent } from '../telemetry/index.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; +import { + logMemoryDream, + logMemoryExtract, + MemoryDreamEvent, + MemoryExtractEvent, +} from '../telemetry/index.js'; import { isAutoMemPath } from './paths.js'; import { getAutoMemoryConsolidationLockPath, @@ -67,6 +73,8 @@ import { writeDreamManualRunToMetadata } from './dream.js'; import { buildConsolidationTaskPrompt } from './dreamAgentPlanner.js'; import type { AutoMemoryMetadata } from './types.js'; +const debugLogger = createDebugLogger('AUTO_MEMORY_MANAGER'); + // ─── Re-export public types consumed by callers ─────────────────────────────── export type { @@ -87,6 +95,7 @@ export type MemoryTaskStatus = | 'running' | 'completed' | 'failed' + | 'cancelled' | 'skipped'; export interface MemoryTaskRecord { @@ -346,7 +355,17 @@ export class MemoryManager { // ── Task records ──────────────────────────────────────────────────────────── private readonly tasks = new Map(); // ── Subscribers (useSyncExternalStore / custom listeners) ──────────────── + // Subscribers without a taskType filter receive every notify; those + // with a filter receive only notifies whose changed record matches + // (extract OR dream). Filtered subscribers exist so high-frequency + // consumers (e.g. the bg-tasks UI hook, which only cares about + // dream) can skip the per-extract O(n) work that would otherwise + // run on every UserQuery. private readonly subscribers = new Set<() => void>(); + private readonly subscribersByType = new Map< + 'extract' | 'dream', + Set<() => void> + >(); // ── In-flight promises (for drain) ────────────────────────────────────────── private readonly inFlight = new Map>(); @@ -361,6 +380,22 @@ export class MemoryManager { // ── Dream scheduling state ─────────────────────────────────────────────────── private readonly dreamInFlightByKey = new Map(); private readonly dreamLastSessionScanAt = new Map(); + // AbortControllers for in-flight dream tasks, keyed by record id. + // cancelTask() looks up the controller, aborts it (the abort signal + // propagates into runForkedAgent), and marks the record cancelled. + // The runDream finally block clears the entry on settle. + private readonly dreamAbortControllers = new Map(); + // Set to true when releaseDreamLock() throws (e.g., Windows EPERM, + // ENOENT race, disk full). The lock file is then left on disk and + // dreamLockExists() sees a fresh-mtime lock owned by a still-alive + // PID (us!), suppressing every subsequent scheduleDream() call as + // `{status: 'skipped', skippedReason: 'locked'}` — invisible to the + // user once the surfacing UI just shows "Lock release failed" without + // re-firing. Setting this flag tells the next scheduleDream() to + // force-clean the leaked lock file before the existence check, so + // scheduling resumes within the same session instead of waiting for + // next session start's staleness sweep. + private dreamLockReleaseFailed = false; private readonly sessionScanner: SessionScannerFn; constructor(sessionScanner: SessionScannerFn = defaultSessionScanner) { @@ -372,14 +407,49 @@ export class MemoryManager { * Register a listener that is called whenever any task record changes. * Compatible with React’s `useSyncExternalStore`. * Returns an unsubscribe function. + * + * Pass `{ taskType: 'dream' }` (or `'extract'`) to receive only + * notifies whose changed record matches that type. Filtered + * subscribers skip the wakeup entirely for unrelated transitions — + * the dream-only UI hook uses this to avoid doing O(n) signature + * work on every per-UserQuery extract notify. */ - subscribe(listener: () => void): () => void { + subscribe( + listener: () => void, + opts?: { taskType?: 'extract' | 'dream' }, + ): () => void { + if (opts?.taskType) { + const type = opts.taskType; + let set = this.subscribersByType.get(type); + if (!set) { + set = new Set(); + this.subscribersByType.set(type, set); + } + set.add(listener); + return () => { + set!.delete(listener); + // Drop the Map entry when the per-type bucket is empty so the + // long-lived MemoryManager doesn't accumulate empty Sets across + // repeated subscribe/unsubscribe cycles (e.g. React mount / + // unmount in the bg-tasks UI hook). + if (set!.size === 0) this.subscribersByType.delete(type); + }; + } this.subscribers.add(listener); return () => this.subscribers.delete(listener); } - private notify(): void { + /** + * Notify subscribers. Pass the changed task's type so type-filtered + * subscribers can be reached too; the unfiltered subscriber set + * always receives the wakeup either way. + */ + private notify(taskType?: 'extract' | 'dream'): void { for (const fn of this.subscribers) fn(); + if (taskType) { + const typed = this.subscribersByType.get(taskType); + if (typed) for (const fn of typed) fn(); + } } /** Update a record and notify subscribers. */ @@ -390,7 +460,7 @@ export class MemoryManager { >, ): void { updateRecord(record, patch); - this.notify(); + this.notify(record.taskType); } /** @@ -399,7 +469,7 @@ export class MemoryManager { */ private store(record: MemoryTaskRecord): void { this.tasks.set(record.id, record); - this.notify(); + this.notify(record.taskType); } /** @@ -414,7 +484,7 @@ export class MemoryManager { ): void { updateRecord(record, patch); this.tasks.set(record.id, record); - this.notify(); + this.notify(record.taskType); } // ─── Task record query ──────────────────────────────────────────────────────── @@ -656,7 +726,12 @@ export class MemoryManager { async scheduleDream( params: ScheduleDreamParams, ): Promise { - if (params.config && !params.config.getManagedAutoDreamEnabled()) { + // `params.config` is optional only because some test paths omit it; + // production callers always pass it. Without a config the + // fork-agent execution can't start (`runManagedAutoMemoryDream` + // throws). Skip early so a missing-config call doesn't surface a + // failed dream entry in the bg-tasks dialog. + if (!params.config || !params.config.getManagedAutoDreamEnabled()) { return { status: 'skipped', skippedReason: 'disabled' }; } @@ -700,6 +775,22 @@ export class MemoryManager { return { status: 'skipped', skippedReason: 'min_sessions' }; } + // If the previous dream's release failed (lockReleaseError surfaced + // on the dialog), the lock file is still on disk and dreamLockExists() + // would silently suppress every subsequent dream until next process + // start. Force-clean it here so the same session recovers. + if (this.dreamLockReleaseFailed) { + await fs + .rm(getAutoMemoryConsolidationLockPath(params.projectRoot), { + force: true, + }) + .catch(() => { + // Best-effort recovery — if even the forced rm fails (truly + // unrecoverable filesystem state), fall through and let the + // existence check below report 'locked' as before. + }); + this.dreamLockReleaseFailed = false; + } if (await dreamLockExists(params.projectRoot)) { return { status: 'skipped', skippedReason: 'locked' }; } @@ -720,26 +811,103 @@ export class MemoryManager { params.projectRoot, params.sessionId, ); + // Register the AbortController BEFORE storeWith. storeWith fires + // a notify which can synchronously call cancelTask via subscribers + // (e.g. a UI listener). If the controller isn't in + // `dreamAbortControllers` by then, cancelTask falls into the + // missing-controller defensive warn-and-return-false path and the + // model gets a phantom failure on a brand-new dream. Registering + // first means any reentrant cancel sees a complete state. + const abortController = new AbortController(); + this.dreamAbortControllers.set(record.id, abortController); + this.dreamInFlightByKey.set(dedupeKey, record.id); this.storeWith(record, { status: 'running', + // Set the initial progressText so the dialog's Progress section + // has something to show during the in-flight window — fork-agent + // execution exposes no per-turn callback today, so without this + // the section stays empty until completion. + progressText: 'Scheduled managed auto-memory dream.', metadata: { sessionCount: sessionIds.length }, }); - this.dreamInFlightByKey.set(dedupeKey, record.id); const promise = this.track( record.id, - this.runDream(record, dedupeKey, params, now), + this.runDream(record, dedupeKey, params, now, abortController.signal), ); return { status: 'scheduled', taskId: record.id, promise }; } + /** + * Look up a single task record by id. Used by `task_stop` and other + * cross-cutting consumers that have a task id but no project root. + */ + getTask(taskId: string): MemoryTaskRecord | undefined { + return this.tasks.get(taskId); + } + + /** + * Cancel a running dream task. Aborts the dream's fork agent (the + * abort signal threads through `runForkedAgent`), marks the record + * cancelled immediately so the UI reflects user intent, and lets the + * existing `runDream` finally block release the consolidation lock + * via the natural error propagation path. + * + * Returns true if a running task was aborted, false if the task is + * unknown / already terminal / not a dream. Currently only dream + * tasks support cancellation — extract is short-lived and runs + * synchronously through the request loop; cancelling it would + * interfere with the user's own turn. + */ + cancelTask(taskId: string): boolean { + const record = this.tasks.get(taskId); + if (!record) return false; + if (record.taskType !== 'dream') return false; + if (record.status !== 'running') return false; + + // The AbortController is registered synchronously alongside the + // status='running' transition in scheduleDream and only cleared in + // runDream's finally block (which only runs after a terminal + // status transition has already happened). So under normal flow + // an entry that is `running` MUST have a controller. Treat the + // missing-controller case as a contract violation: don't flip + // status (a cancelled record without an aborted fork would leak + // the consolidation lock until the agent finishes naturally) and + // return false so the caller knows the abort didn't take. Log at + // warn level so the inconsistency is observable in debug bundles + // — silent failure here would leave a runaway dream burning tokens + // with no signal to the user or to telemetry. + const ac = this.dreamAbortControllers.get(taskId); + if (!ac) { + debugLogger.warn( + `cancelTask: AbortController missing for running dream task ${taskId}; ` + + `not flipping status. This indicates a logic bug — the controller ` + + `should have been registered in scheduleDream and only cleared ` + + `after a terminal status transition.`, + ); + return false; + } + + // Mark cancelled BEFORE aborting so the runDream catch path can + // detect the user-cancel intent (signal.aborted + status already + // 'cancelled') and avoid overwriting with a generic 'failed'. + this.update(record, { + status: 'cancelled', + progressText: 'Cancelled by user.', + }); + ac.abort(); + return true; + } + private async runDream( record: MemoryTaskRecord, dedupeKey: string, params: ScheduleDreamParams, now: Date, + abortSignal: AbortSignal, ): Promise { + const dreamStartMs = Date.now(); try { try { await acquireDreamLock(params.projectRoot); @@ -761,13 +929,26 @@ export class MemoryManager { params.projectRoot, now, params.config, + abortSignal, ); - const nextMetadata = await readDreamMetadata(params.projectRoot); - nextMetadata.lastDreamAt = now.toISOString(); - nextMetadata.lastDreamSessionId = params.sessionId; - nextMetadata.updatedAt = now.toISOString(); - await writeDreamMetadata(params.projectRoot, nextMetadata); + // Defense-in-depth: runForkedAgent maps cancelled fork-agents + // to a resolved `{status: 'cancelled'}` rather than a rejection. + // dreamAgentPlanner now rethrows that case so the catch path + // below handles it, but if anything in the call chain ever + // forgets to propagate, this guard prevents the success path + // from clobbering the user-cancelled record with 'completed' + // and bumping dream metadata for an aborted run. + if (abortSignal.aborted) { + return record; + } + // Atomic-from-cancel sequence: flip status='completed' BEFORE + // any scheduler-gating metadata write. Once status is no + // longer 'running', cancelTask refuses, so the writeFile that + // follows can't race a flip-to-cancelled. The cancel-raced- + // status-update branch below covers the remaining window + // (cancel landed between the pre-update check and the + // synchronous update). this.update(record, { status: 'completed', progressText: @@ -778,16 +959,120 @@ export class MemoryManager { lastDreamAt: now.toISOString(), }, }); + if (abortSignal.aborted) { + // Defense-in-depth: unreachable today (no `await` between + // the pre-update check and the synchronous update above, + // so JS's single-threaded execution prevents + // `signal.aborted` from transitioning between them — a + // cancelTask landing inside the storeWith notify would + // already have flipped status, and our update would have + // raced ahead of it to 'completed'). Kept against a future + // refactor that introduces an `await` between the two + // checks. Preserves the touched-topic metadata on the + // restored cancelled record so the user can still tell + // memory files were modified before the abort took. + this.update(record, { + status: 'cancelled', + progressText: 'Cancelled after memory changes.', + metadata: { + touchedTopics: result.touchedTopics, + dedupedEntries: result.dedupedEntries, + }, + }); + return record; + } + // Status is now 'completed'; cancelTask will refuse from + // here on out. Safe to write scheduler-gating metadata + // without a race window. + // + // Wrap the read/write in a try/catch — pre-PR `bumpMetadata` + // in dream.ts swallowed errors as best-effort; without this + // wrap a transient ENOENT / EPERM on the metadata file would + // propagate to the outer catch and overwrite a + // legitimately-completed dream with `'failed'`. The dream + // already did its work (touched files are on disk and + // visible). Trade-off: the next dream cycle won't see a + // bumped lastDreamAt and may re-fire — same trade as the + // original best-effort behavior. + try { + const nextMetadata = await readDreamMetadata(params.projectRoot); + nextMetadata.lastDreamAt = now.toISOString(); + nextMetadata.lastDreamSessionId = params.sessionId; + nextMetadata.updatedAt = now.toISOString(); + nextMetadata.lastDreamTouchedTopics = result.touchedTopics; + nextMetadata.lastDreamStatus = + result.touchedTopics.length > 0 ? 'updated' : 'noop'; + // Mirror the manual /dream path's reset so the two write + // sites don't drift. The field is currently dead code on + // main (only ever written, never read) but keeping the two + // paths in sync avoids surprises if a future change starts + // reading it. + nextMetadata.recentSessionIdsSinceDream = []; + await writeDreamMetadata(params.projectRoot, nextMetadata); + } catch (metaError) { + const message = + metaError instanceof Error ? metaError.message : String(metaError); + debugLogger.warn( + `Failed to persist dream gating metadata for ${record.id}: ${message}`, + ); + this.update(record, { + metadata: { metadataWriteError: message }, + }); + } } finally { - await releaseDreamLock(params.projectRoot); + // Lock release errors are logged AND surfaced on the record's + // metadata so the user can see why subsequent dreams may be + // skipped as 'locked'. If releasing throws (e.g., EPERM on + // Windows, ENOENT race), letting it propagate to the outer + // catch would overwrite a successfully-completed dream with + // 'failed'. The on-disk lock will be cleaned up on the next + // session start via the staleness sweep, so swallowing the + // error here doesn't risk a permanently-stuck lock. + try { + await releaseDreamLock(params.projectRoot); + } catch (lockError) { + const message = + lockError instanceof Error ? lockError.message : String(lockError); + debugLogger.warn( + `Failed to release dream lock for task ${record.id}: ${message}. ` + + `Next scheduleDream() will force-clean the leaked lock.`, + ); + this.dreamLockReleaseFailed = true; + this.update(record, { + metadata: { lockReleaseError: message }, + }); + } } } catch (error) { + // User-cancel path: cancelTask already aborted the signal AND + // marked the record cancelled. The fork agent throws an abort + // error which lands here; don't overwrite with 'failed'. + if (abortSignal.aborted && record.status === 'cancelled') { + if (params.config) { + logMemoryDream( + params.config, + new MemoryDreamEvent({ + trigger: 'auto', + status: 'cancelled', + deduped_entries: 0, + touched_topics: [], + // Real elapsed time the cancelled dream consumed before + // the user stopped it — without this, latency histograms + // / p95 metrics would silently treat cancelled dreams as + // 0ms and skew toward the success path. + duration_ms: Date.now() - dreamStartMs, + }), + ); + } + return record; + } this.update(record, { status: 'failed', error: error instanceof Error ? error.message : String(error), }); } finally { this.dreamInFlightByKey.delete(dedupeKey); + this.dreamAbortControllers.delete(record.id); } return record; } diff --git a/packages/core/src/memory/recall.ts b/packages/core/src/memory/recall.ts index 5169385e9..8697496f9 100644 --- a/packages/core/src/memory/recall.ts +++ b/packages/core/src/memory/recall.ts @@ -131,6 +131,8 @@ export interface ResolveRelevantAutoMemoryPromptOptions { excludedFilePaths?: Iterable; limit?: number; recentTools?: readonly string[]; + /** When provided and aborted, suppresses logMemoryRecall telemetry for discarded results. */ + abortSignal?: AbortSignal; } export interface RelevantAutoMemoryPromptResult { @@ -168,7 +170,7 @@ export async function resolveRelevantAutoMemoryPromptForQuery( const limit = options.limit ?? MAX_RELEVANT_DOCS; if (query.trim().length === 0 || docs.length === 0 || limit <= 0) { - if (options.config) { + if (options.config && !options.abortSignal?.aborted) { logMemoryRecall( options.config, new MemoryRecallEvent({ @@ -195,36 +197,57 @@ export async function resolveRelevantAutoMemoryPromptForQuery( docs, limit, options.recentTools ?? [], + options.abortSignal, ); const strategy: RelevantAutoMemoryPromptResult['strategy'] = selectedDocs.length > 0 ? 'model' : 'none'; - logMemoryRecall( - options.config, - new MemoryRecallEvent({ - query_length: query.length, - docs_scanned: docs.length, - docs_selected: selectedDocs.length, - strategy, - duration_ms: Date.now() - t0, - }), - ); + if (!options.abortSignal?.aborted) { + logMemoryRecall( + options.config, + new MemoryRecallEvent({ + query_length: query.length, + docs_scanned: docs.length, + docs_selected: selectedDocs.length, + strategy, + duration_ms: Date.now() - t0, + }), + ); + } return { prompt: buildRelevantAutoMemoryPrompt(selectedDocs), selectedDocs, strategy, }; } catch (error) { - debugLogger.warn( - 'Model-driven auto-memory recall failed; falling back to heuristic selection.', - error, - ); + // Distinguish deadline-triggered cancellation from real model errors + // so oncall debugging is not misled by the fallback log. + if (error instanceof DOMException && error.name === 'AbortError') { + debugLogger.debug( + 'Model-driven auto-memory recall cancelled by deadline; heuristic result discarded.', + ); + } else { + debugLogger.warn( + 'Model-driven auto-memory recall failed; falling back to heuristic selection.', + error, + ); + } } } + // If the caller's abort signal is already set (e.g. deadline fired), skip the + // heuristic fallback — the result would be discarded anyway. + if (options.abortSignal?.aborted) { + return { + prompt: '', + selectedDocs: [], + strategy: 'none', + }; + } + const selectedDocs = selectRelevantAutoMemoryDocuments(query, docs, limit); const strategy: RelevantAutoMemoryPromptResult['strategy'] = selectedDocs.length > 0 ? 'heuristic' : 'none'; - if (options.config) { + if (options.config && !options.abortSignal?.aborted) { logMemoryRecall( options.config, new MemoryRecallEvent({ diff --git a/packages/core/src/memory/relevanceSelector.test.ts b/packages/core/src/memory/relevanceSelector.test.ts index 7b86aebc9..1dcc6a1fc 100644 --- a/packages/core/src/memory/relevanceSelector.test.ts +++ b/packages/core/src/memory/relevanceSelector.test.ts @@ -5,7 +5,6 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import type { Config } from '../config/config.js'; import { runSideQuery } from '../utils/sideQuery.js'; import type { ScannedAutoMemoryDocument } from './scan.js'; import { selectRelevantAutoMemoryDocumentsByModel } from './relevanceSelector.js'; @@ -38,7 +37,9 @@ const docs: ScannedAutoMemoryDocument[] = [ ]; describe('selectRelevantAutoMemoryDocumentsByModel', () => { - const mockConfig = {} as Config; + const mockConfig = {} as Parameters< + typeof selectRelevantAutoMemoryDocumentsByModel + >[0]; beforeEach(() => { vi.clearAllMocks(); @@ -76,6 +77,55 @@ describe('selectRelevantAutoMemoryDocumentsByModel', () => { expect(runSideQuery).not.toHaveBeenCalled(); }); + it('forwards caller abort signal to runSideQuery combined with timeout', async () => { + const callerController = new AbortController(); + let capturedSignal: AbortSignal | undefined; + + vi.mocked(runSideQuery).mockImplementation(async (_config, opts) => { + capturedSignal = opts.abortSignal; + return { selected_memories: [] }; + }); + + await selectRelevantAutoMemoryDocumentsByModel( + mockConfig, + 'check preferences', + docs, + 2, + [], + callerController.signal, + ); + + expect(runSideQuery).toHaveBeenCalledTimes(1); + expect(capturedSignal).toBeDefined(); + expect(capturedSignal!.aborted).toBe(false); + + callerController.abort(); + + await vi.waitFor(() => { + expect(capturedSignal!.aborted).toBe(true); + }); + }); + + it('uses timeout-only abort signal when no caller signal provided', async () => { + vi.mocked(runSideQuery).mockResolvedValue({ + selected_memories: [], + }); + + await selectRelevantAutoMemoryDocumentsByModel( + mockConfig, + 'check preferences', + docs, + 2, + ); + + expect(runSideQuery).toHaveBeenCalledWith( + mockConfig, + expect.objectContaining({ + abortSignal: expect.any(AbortSignal), + }), + ); + }); + it('throws when selector returns unknown relative paths', async () => { vi.mocked(runSideQuery).mockImplementation(async (_config, options) => { const error = options.validate?.({ diff --git a/packages/core/src/memory/relevanceSelector.ts b/packages/core/src/memory/relevanceSelector.ts index b457a965e..69f6fe719 100644 --- a/packages/core/src/memory/relevanceSelector.ts +++ b/packages/core/src/memory/relevanceSelector.ts @@ -58,6 +58,7 @@ export async function selectRelevantAutoMemoryDocumentsByModel( docs: ScannedAutoMemoryDocument[], limit: number, recentTools: readonly string[] = [], + callerAbortSignal?: AbortSignal, ): Promise { if (docs.length === 0 || limit <= 0 || query.trim().length === 0) { return []; @@ -90,7 +91,9 @@ export async function selectRelevantAutoMemoryDocumentsByModel( purpose: 'auto-memory-recall', contents, schema: RESPONSE_SCHEMA, - abortSignal: AbortSignal.timeout(5_000), + abortSignal: callerAbortSignal + ? AbortSignal.any([AbortSignal.timeout(2_000), callerAbortSignal]) + : AbortSignal.timeout(2_000), systemInstruction: SELECT_MEMORIES_SYSTEM_PROMPT, config: { temperature: 0, diff --git a/packages/core/src/services/chatCompressionService.test.ts b/packages/core/src/services/chatCompressionService.test.ts index 417988331..3a3520cce 100644 --- a/packages/core/src/services/chatCompressionService.test.ts +++ b/packages/core/src/services/chatCompressionService.test.ts @@ -75,15 +75,17 @@ describe('findCompressSplitPoint', () => { expect(findCompressSplitPoint(history, 0.8)).toBe(4); }); - it('should return earlier splitpoint if no valid ones are after threshhold', () => { + it('compresses everything before the trailing in-flight functionCall', () => { const history: Content[] = [ { role: 'user', parts: [{ text: 'This is the first message.' }] }, { role: 'model', parts: [{ text: 'This is the second message.' }] }, { role: 'user', parts: [{ text: 'This is the third message.' }] }, { role: 'model', parts: [{ functionCall: { name: 'foo', args: {} } }] }, ]; - // Can't return 4 because the previous item has a function call. - expect(findCompressSplitPoint(history, 0.99)).toBe(2); + // Trailing m+fc is in-flight; no preceding (m+fc, u+fr) pair to retain, + // so the in-flight fallback compresses everything except the trailing fc. + // The kept slice starts with m+fc; callers bridge with a synthetic user. + expect(findCompressSplitPoint(history, 0.99)).toBe(3); }); it('should handle a history with only one item', () => { @@ -143,7 +145,7 @@ describe('findCompressSplitPoint', () => { expect(findCompressSplitPoint(history, 0.7)).toBe(5); }); - it('should return primary split point when tool completions have no subsequent regular user message', () => { + it('retains last K complete tool rounds when no fresh user splits past target', () => { const history: Content[] = [ { role: 'user', parts: [{ text: 'Fix this' }] }, { @@ -181,14 +183,13 @@ describe('findCompressSplitPoint', () => { parts: [{ functionCall: { name: 'write1', args: {} } }], }, ]; - // Only one non-functionResponse user message (index 0) -> lastSplitPoint=0 - // Last message has functionCall -> can't compress everything - // historyToKeep must start with a regular user message, so split at 0 - // (compress nothing) is the only valid option. - expect(findCompressSplitPoint(history, 0.7)).toBe(0); + // 2 complete (m+fc, u+fr) pairs precede the trailing fc → retain both + // pairs + trailing fc = last 5 entries; compress index 0 (the task). + // Pre-refactor this returned 0 (NOOP); now it compresses-most. + expect(findCompressSplitPoint(history, 0.7)).toBe(history.length - 5); }); - it('should prefer primary split point when tool completions yield no valid user-starting split', () => { + it('prefers compress-most over lastSplitPoint when scan finds no clean split past target', () => { const longContent = 'a'.repeat(10000); const history: Content[] = [ { role: 'user', parts: [{ text: 'Fix bug A' }] }, @@ -229,13 +230,12 @@ describe('findCompressSplitPoint', () => { parts: [{ functionCall: { name: 'write1', args: {} } }], }, ]; - // Primary split points at 0 and 2 (regular user messages before the bulky tool outputs) - // Last message has functionCall -> can't compress everything - // Should return lastSplitPoint=2 (last valid primary split point) - expect(findCompressSplitPoint(history, 0.7)).toBe(2); + // 2 complete pairs before the trailing fc → retain both + trailing = 5 + // entries kept. Pre-refactor returned lastSplitPoint=2 (compress less). + expect(findCompressSplitPoint(history, 0.7)).toBe(history.length - 5); }); - it('should still prefer primary split point when it is better', () => { + it('compresses-most via in-flight fallback when scan never crosses the target', () => { const history: Content[] = [ { role: 'user', parts: [{ text: 'msg1' }] }, { role: 'model', parts: [{ text: 'resp1' }] }, @@ -266,10 +266,93 @@ describe('findCompressSplitPoint', () => { parts: [{ functionCall: { name: 'tool2', args: {} } }], }, ]; - // Primary split points: 0, 2, 5, 7 - // Last message has functionCall -> can't compress everything - // At 0.99 fraction, lastSplitPoint should be 7 - expect(findCompressSplitPoint(history, 0.99)).toBe(7); + // The entry before the trailing fc is a fresh user (msg4), not a u+fr, + // so the pair walk stops with 0 pairs found → retain only the trailing + // fc, compress everything else. Pre-refactor returned lastSplitPoint=7. + expect(findCompressSplitPoint(history, 0.99)).toBe(history.length - 1); + }); +}); + +describe('findCompressSplitPoint — in-flight fallback', () => { + const userTask = (text: string): Content => ({ + role: 'user', + parts: [{ text }], + }); + const modelText = (text: string): Content => ({ + role: 'model', + parts: [{ text }], + }); + const modelFc = (name: string): Content => ({ + role: 'model', + parts: [{ functionCall: { name, args: {} } }], + }); + const userFr = (name: string): Content => ({ + role: 'user', + parts: [{ functionResponse: { name, response: { result: 'x' } } }], + }); + + // Subagent-shaped history at compression check time: env bootstrap, task, + // alternating tool rounds, ending in a trailing in-flight model+fc whose + // functionResponse hasn't been pushed yet. The scan finds no clean split + // past the target fraction, so the in-flight fallback decides the index. + it('compresses everything except trailing fc + most recent retainCount pairs', () => { + const history = [ + userTask('env'), + modelText('env-ack'), + userTask('task'), + modelFc('a'), + userFr('a'), + modelFc('b'), + userFr('b'), + modelFc('c'), + userFr('c'), + modelFc('d'), + userFr('d'), + modelFc('trailing'), + ]; + // Default retainCount = 2 → keep last 5 (2 pairs + trailing). + expect(findCompressSplitPoint(history, 0.7)).toBe(history.length - 5); + }); + + it('retains all pairs when fewer than retainCount exist', () => { + const history = [ + userTask('env'), + modelText('env-ack'), + userTask('task'), + modelFc('a'), + userFr('a'), + modelFc('trailing'), + ]; + // Only 1 complete pair → keep last 3 (1 pair + trailing). + expect(findCompressSplitPoint(history, 0.7)).toBe(history.length - 3); + }); + + it('retains just the trailing fc when no complete pairs precede it', () => { + const history = [ + userTask('env'), + modelText('env-ack'), + userTask('task'), + modelFc('trailing'), + ]; + // No complete pairs → keep only the trailing fc. + expect(findCompressSplitPoint(history, 0.7)).toBe(history.length - 1); + }); + + it('respects an explicit retainCount override', () => { + const history = [ + userTask('env'), + modelText('env-ack'), + userTask('task'), + modelFc('a'), + userFr('a'), + modelFc('b'), + userFr('b'), + modelFc('c'), + userFr('c'), + modelFc('trailing'), + ]; + // Override retainCount to 1 → keep last 3 (1 pair + trailing). + expect(findCompressSplitPoint(history, 0.7, 1)).toBe(history.length - 3); }); }); @@ -313,14 +396,14 @@ describe('ChatCompressionService', () => { it('should return NOOP if history is empty', async () => { vi.mocked(mockChat.getHistory).mockReturnValue([]); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(result.newHistory).toBeNull(); }); @@ -329,14 +412,14 @@ describe('ChatCompressionService', () => { vi.mocked(mockChat.getHistory).mockReturnValue([ { role: 'user', parts: [{ text: 'hi' }] }, ]); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - true, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: true, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(result.newHistory).toBeNull(); }); @@ -349,14 +432,14 @@ describe('ChatCompressionService', () => { vi.mocked(tokenLimit).mockReturnValue(1000); // Threshold is 0.7 * 1000 = 700. 600 < 700, so NOOP. - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(result.newHistory).toBeNull(); }); @@ -377,14 +460,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info).toMatchObject({ compressionStatus: CompressionStatus.NOOP, @@ -394,14 +477,14 @@ describe('ChatCompressionService', () => { expect(mockGenerateContent).not.toHaveBeenCalled(); expect(tokenLimit).not.toHaveBeenCalled(); - const forcedResult = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const forcedResult = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(forcedResult.info).toMatchObject({ compressionStatus: CompressionStatus.NOOP, originalTokenCount: 0, @@ -438,14 +521,14 @@ describe('ChatCompressionService', () => { } as unknown as ContentGenerator); // force=true bypasses the token threshold gate so we exercise the 5% guard - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(result.newHistory).toBeNull(); @@ -485,14 +568,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); expect(result.info.newTokenCount).toBe(250); // 800 - (1600 - 1000) + 50 @@ -539,14 +622,15 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, // forced - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + // forced + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); expect(result.newHistory).not.toBeNull(); @@ -586,14 +670,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe( CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, @@ -629,14 +713,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe( CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, @@ -668,14 +752,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe( CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY, @@ -707,14 +791,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe( CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY, @@ -749,14 +833,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe( CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, @@ -801,14 +885,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Should still complete compression despite hook error expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); @@ -860,14 +944,15 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - await service.compress( - mockChat, - mockPromptId, - true, // force = true -> Manual trigger - mockModel, - mockConfig, - false, - ); + await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + // force = true -> Manual trigger + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(mockFirePreCompactEvent).toHaveBeenCalledWith( PreCompactTrigger.Manual, @@ -910,14 +995,15 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - await service.compress( - mockChat, - mockPromptId, - false, // force = false -> Auto trigger - mockModel, - mockConfig, - false, - ); + await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + // force = false -> Auto trigger + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(mockFirePreCompactEvent).toHaveBeenCalledWith( PreCompactTrigger.Auto, @@ -929,14 +1015,14 @@ describe('ChatCompressionService', () => { it('should not fire PreCompact hook when history is empty', async () => { vi.mocked(mockChat.getHistory).mockReturnValue([]); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(mockFirePreCompactEvent).not.toHaveBeenCalled(); @@ -952,14 +1038,14 @@ describe('ChatCompressionService', () => { contextPercentageThreshold: 0, }); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(mockFirePreCompactEvent).not.toHaveBeenCalled(); @@ -976,14 +1062,14 @@ describe('ChatCompressionService', () => { ); vi.mocked(tokenLimit).mockReturnValue(1000); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); expect(mockFirePreCompactEvent).not.toHaveBeenCalled(); @@ -1027,14 +1113,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Should still complete compression despite hook error expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); @@ -1084,14 +1170,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // PreCompact should be called before SessionStart expect(callOrder).toEqual(['PreCompact', 'SessionStart']); @@ -1133,14 +1219,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Should still complete compression without hook expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); @@ -1195,14 +1281,15 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - await service.compress( - mockChat, - mockPromptId, - true, // force = true -> Manual trigger - mockModel, - mockConfig, - false, - ); + await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + // force = true -> Manual trigger + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(mockFirePostCompactEvent).toHaveBeenCalledWith( PostCompactTrigger.Manual, @@ -1245,14 +1332,15 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - await service.compress( - mockChat, - mockPromptId, - false, // force = false -> Auto trigger - mockModel, - mockConfig, - false, - ); + await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + // force = false -> Auto trigger + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(mockFirePostCompactEvent).toHaveBeenCalledWith( PostCompactTrigger.Auto, @@ -1292,14 +1380,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); expect(result.info.compressionStatus).toBe( CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY, @@ -1345,14 +1433,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Should still complete compression despite hook error expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); @@ -1405,14 +1493,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Hooks should be called in order: PreCompact -> SessionStart -> PostCompact expect(callOrder).toEqual(['PreCompact', 'SessionStart', 'PostCompact']); @@ -1454,14 +1542,14 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Should still complete compression without hook expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); @@ -1535,14 +1623,15 @@ describe('ChatCompressionService', () => { generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - true, // force=true (manual /compress) - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: true, + // force=true (manual /compress) + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); // Should compress successfully — orphaned funcCall is stripped first, then // normal compression runs on the remaining history, historyToKeep is empty @@ -1558,10 +1647,12 @@ describe('ChatCompressionService', () => { expect(callArg.contents.length).toBe(history.length); // (history.length - 1) messages + 1 instruction }); - it('should NOT compress orphaned funcCall when force=false (auto-compress)', async () => { - // Auto-compress fires BEFORE the matching funcResponse is sent back to the - // model. Compressing the funcCall away would orphan the upcoming funcResponse - // and cause an API error. So force=false must NOT take this path. + it('compresses-most without orphaning when last entry is in-flight funcCall (auto-compress)', async () => { + // Auto-compress fires BEFORE the matching funcResponse is sent back to + // the model. The trailing funcCall must be retained (its response is + // coming); the in-flight fallback compresses everything safely before + // it. Pre-refactor this returned NOOP, leaving the chat to grow until + // it 400'd. const history: Content[] = [ { role: 'user', parts: [{ text: 'Fix all TypeScript errors.' }] }, { @@ -1586,7 +1677,6 @@ describe('ChatCompressionService', () => { }, ]; vi.mocked(mockChat.getHistory).mockReturnValue(history); - // Use a token count above threshold to ensure auto-compress isn't skipped vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( 800, ); @@ -1595,23 +1685,195 @@ describe('ChatCompressionService', () => { contextWindowSize: 1000, } as unknown as ReturnType); + const mockGenerateContent = vi.fn().mockResolvedValue({ + candidates: [ + { content: { parts: [{ text: 'state snapshot summary' }] } }, + ], + usageMetadata: { + promptTokenCount: 2000, + candidatesTokenCount: 50, + totalTokenCount: 2050, + }, + } as unknown as GenerateContentResponse); + vi.mocked(mockConfig.getContentGenerator).mockReturnValue({ + generateContent: mockGenerateContent, + } as unknown as ContentGenerator); + + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); + + expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); + expect(mockGenerateContent).toHaveBeenCalledTimes(1); + // Trailing in-flight functionCall is preserved last in the kept slice + // so the upcoming functionResponse pairs with it. + const newHistory = result.newHistory!; + const last = newHistory[newHistory.length - 1]; + expect(last.role).toBe('model'); + expect(last.parts?.some((p) => p.functionCall)).toBe(true); + // Strict role alternation throughout. + for (let i = 1; i < newHistory.length; i++) { + expect(newHistory[i].role).not.toBe(newHistory[i - 1].role); + } + }); + }); + + describe('tool-loop subagent absorption', () => { + // The fresh-user split heuristic produces a tiny compress slice when the + // history is dominated by tool rounds (every user past the task is a + // functionResponse). Without absorption, MIN_COMPRESSION_FRACTION would + // NOOP every send and the subagent eventually hits the 400 it was meant + // to avoid. + it('compresses by absorbing older tool rounds when fresh-user split is too small', async () => { + const FILLER = 'A'.repeat(20_000); + // Auto-compress fires BEFORE the next functionResponse is pushed, so + // the trailing entry is always a model+functionCall with no match yet. + // Build a history with N complete pairs followed by one trailing fc. + const buildHistory = (completePairs: number): Content[] => { + const h: Content[] = [ + { role: 'user', parts: [{ text: 'env-bootstrap' }] }, + { role: 'model', parts: [{ text: 'env-ack' }] }, + { role: 'user', parts: [{ text: 'task: explore' }] }, + ]; + for (let r = 0; r < completePairs; r++) { + h.push({ + role: 'model', + parts: [ + { text: `round ${r}: ${FILLER}` }, + { functionCall: { name: 'glob', args: { pattern: '**/*.md' } } }, + ], + }); + h.push({ + role: 'user', + parts: [ + { + functionResponse: { name: 'glob', response: { result: 'x' } }, + }, + ], + }); + } + // Trailing model+fc whose response is about to be sent. + h.push({ + role: 'model', + parts: [ + { text: `round ${completePairs}: ${FILLER}` }, + { functionCall: { name: 'glob', args: { pattern: '**/*.md' } } }, + ], + }); + return h; + }; + + // Five complete tool rounds + 1 trailing fc → 5 pairs in keep; absorbs + // 3 older pairs and retains the 2 most recent (plus the trailing fc). + vi.mocked(mockChat.getHistory).mockReturnValue(buildHistory(5)); + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + 80_000, + ); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + model: 'gemini-pro', + contextWindowSize: 100_000, + } as unknown as ReturnType); + + const mockGenerateContent = vi.fn().mockResolvedValue({ + candidates: [ + { content: { parts: [{ text: 'state snapshot summary' }] } }, + ], + usageMetadata: { + promptTokenCount: 60_000, + candidatesTokenCount: 200, + totalTokenCount: 60_200, + }, + } as unknown as GenerateContentResponse); + vi.mocked(mockConfig.getContentGenerator).mockReturnValue({ + generateContent: mockGenerateContent, + } as unknown as ContentGenerator); + + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); + + expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED); + expect(result.newHistory).not.toBeNull(); + expect(mockGenerateContent).toHaveBeenCalledTimes(1); + + const newHistory = result.newHistory!; + // [summary_user, summary_ack_model, continuation_bridge_user, ...keep] + // where keep starts with the retained model+functionCall. + expect(newHistory[0].role).toBe('user'); + expect(newHistory[0].parts?.[0].text).toBe('state snapshot summary'); + expect(newHistory[1].role).toBe('model'); + expect(newHistory[2].role).toBe('user'); + expect(newHistory[2].parts?.[0].text).toMatch(/Continue/); + // Retained two complete pairs (4 entries) + trailing model+fc = 5. + expect(newHistory.slice(3)).toHaveLength(5); + expect(newHistory[3].role).toBe('model'); + expect(newHistory[3].parts?.some((p) => p.functionCall)).toBe(true); + expect(newHistory[4].role).toBe('user'); + expect(newHistory[4].parts?.some((p) => p.functionResponse)).toBe(true); + // Trailing model+fc remains last so the upcoming functionResponse pushed + // by sendMessageStream pairs with it correctly. + const last = newHistory[newHistory.length - 1]; + expect(last.role).toBe('model'); + expect(last.parts?.some((p) => p.functionCall)).toBe(true); + + // Strict role alternation throughout the new history. + for (let i = 1; i < newHistory.length; i++) { + expect(newHistory[i].role).not.toBe(newHistory[i - 1].role); + } + }); + + it('NOOPs when the keep slice has too few tool rounds to absorb', async () => { + const FILLER = 'A'.repeat(20_000); + const history: Content[] = [ + { role: 'user', parts: [{ text: 'env-bootstrap' }] }, + { role: 'model', parts: [{ text: 'env-ack' }] }, + { role: 'user', parts: [{ text: 'task' }] }, + { + role: 'model', + parts: [ + { text: FILLER }, + { functionCall: { name: 'glob', args: {} } }, + ], + }, + ]; + vi.mocked(mockChat.getHistory).mockReturnValue(history); + // Set originalTokenCount above the threshold gate (0.7 * 30000 = 21000) + // so the test actually exercises findCompressSplitPoint and the + // MIN_COMPRESSION_FRACTION decision rather than short-circuiting at + // the cheap-gate. + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + 22_000, + ); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + model: 'gemini-pro', + contextWindowSize: 30_000, + } as unknown as ReturnType); + const mockGenerateContent = vi.fn(); vi.mocked(mockConfig.getContentGenerator).mockReturnValue({ generateContent: mockGenerateContent, } as unknown as ContentGenerator); - const result = await service.compress( - mockChat, - mockPromptId, - false, // force=false (auto-compress) - mockModel, - mockConfig, - false, - ); + const result = await service.compress(mockChat, { + promptId: mockPromptId, + force: false, + model: mockModel, + config: mockConfig, + hasFailedCompressionAttempt: false, + originalTokenCount: uiTelemetryService.getLastPromptTokenCount(), + }); - // Must return NOOP — compressing would orphan the upcoming funcResponse expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP); - expect(result.newHistory).toBeNull(); expect(mockGenerateContent).not.toHaveBeenCalled(); }); }); diff --git a/packages/core/src/services/chatCompressionService.ts b/packages/core/src/services/chatCompressionService.ts index 876d051ac..4e6b19bd3 100644 --- a/packages/core/src/services/chatCompressionService.ts +++ b/packages/core/src/services/chatCompressionService.ts @@ -8,7 +8,6 @@ import type { Content } from '@google/genai'; import type { Config } from '../config/config.js'; import type { GeminiChat } from '../core/geminiChat.js'; import { type ChatCompressionInfo, CompressionStatus } from '../core/turn.js'; -import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { DEFAULT_TOKEN_LIMIT } from '../core/tokenLimits.js'; import { getCompressionPrompt } from '../core/prompts.js'; import { getResponseText } from '../utils/partUtils.js'; @@ -40,15 +39,79 @@ export const COMPRESSION_PRESERVE_THRESHOLD = 0.3; */ export const MIN_COMPRESSION_FRACTION = 0.05; +/** + * When the trailing entry is an in-flight `model+functionCall` and the regular + * scan finds no clean split past the target fraction, the splitter falls back + * to compressing everything except the last few entries. This constant sets + * how many most-recent complete `(model+functionCall, user+functionResponse)` + * tool rounds are retained as working context (the trailing in-flight call is + * always retained on top of these). + */ +export const TOOL_ROUND_RETAIN_COUNT = 2; + +const hasFunctionCall = (content: Content | undefined): boolean => + !!content?.parts?.some((part) => !!part.functionCall); + +const hasFunctionResponse = (content: Content | undefined): boolean => + !!content?.parts?.some((part) => !!part.functionResponse); + +/** + * Walk backward from the trailing in-flight `model+functionCall` and return + * the index after which the most-recent `retainCount` complete tool-round + * pairs sit (plus the trailing fc itself). Used by the splitter's in-flight + * fallback path. Stops counting at the first non-pair encountered, so the + * retain count is best-effort: if there are fewer complete pairs than + * requested, all of them are retained. + */ +function splitPointRetainingTrailingPairs( + contents: Content[], + retainCount: number, +): number { + let pairsFound = 0; + let i = contents.length - 2; + while (i >= 1 && pairsFound < retainCount) { + if (hasFunctionCall(contents[i - 1]) && hasFunctionResponse(contents[i])) { + pairsFound += 1; + i -= 2; + } else { + break; + } + } + return contents.length - (2 * pairsFound + 1); +} + /** * Returns the index of the oldest item to keep when compressing. May return * contents.length which indicates that everything should be compressed. * + * The algorithm has two phases: + * + * 1. **Scan:** walk left-to-right looking for the first non-functionResponse + * user message that lands past `fraction` of total chars. That's the + * "clean" split — the kept slice starts with a fresh user prompt. + * + * 2. **Fallbacks** (no clean split found): the gate that gets us here has + * already decided we need to compress, so all three fallbacks bias toward + * *more* compression rather than less: + * + * - last entry is `model` without functionCall → compress everything. + * - last entry is `user` with functionResponse → compress everything (the + * trailing tool round is complete; no orphans). + * - last entry is `model` with functionCall (in-flight) → compress + * everything except the trailing call plus the last `retainCount` + * complete tool rounds. The kept slice may start with `model+fc`; + * callers must inject a synthetic continuation user message between + * `summary_ack_model` and the kept slice to preserve role alternation. + * + * The pre-fallback returns of `lastSplitPoint` (compress less) only happen + * for malformed histories that don't end in user/model. + * * Exported for testing purposes. */ export function findCompressSplitPoint( contents: Content[], fraction: number, + retainCount = TOOL_ROUND_RETAIN_COUNT, ): number { if (fraction <= 0 || fraction >= 1) { throw new Error('Fraction must be between 0 and 1'); @@ -58,14 +121,11 @@ export function findCompressSplitPoint( const totalCharCount = charCounts.reduce((a, b) => a + b, 0); const targetCharCount = totalCharCount * fraction; - let lastSplitPoint = 0; // 0 is always valid (compress nothing) + let lastSplitPoint = 0; let cumulativeCharCount = 0; for (let i = 0; i < contents.length; i++) { const content = contents[i]; - if ( - content.role === 'user' && - !content.parts?.some((part) => !!part.functionResponse) - ) { + if (content.role === 'user' && !hasFunctionResponse(content)) { if (cumulativeCharCount >= targetCharCount) { return i; } @@ -74,48 +134,57 @@ export function findCompressSplitPoint( cumulativeCharCount += charCounts[i]; } - // We found no split points after targetCharCount. - // Check if it's safe to compress everything. const lastContent = contents[contents.length - 1]; - if ( - lastContent?.role === 'model' && - !lastContent?.parts?.some((part) => part.functionCall) - ) { + if (lastContent?.role === 'model') { + if (!hasFunctionCall(lastContent)) return contents.length; + return splitPointRetainingTrailingPairs(contents, retainCount); + } + if (lastContent?.role === 'user' && hasFunctionResponse(lastContent)) { return contents.length; } - // Also safe to compress everything if the last message completes a tool call - // sequence (all function calls have matching responses). - if ( - lastContent?.role === 'user' && - lastContent?.parts?.some((part) => !!part.functionResponse) - ) { - return contents.length; - } - return lastSplitPoint; } +export interface CompressOptions { + promptId: string; + force: boolean; + model: string; + config: Config; + /** + * Whether a previous unforced compression attempt failed for this chat. + * Suppresses auto-compaction; manual `/compress` (force=true) overrides. + */ + hasFailedCompressionAttempt: boolean; + /** + * Most recent prompt token count for this chat. Compared against + * `threshold * contextWindowSize` for the auto-compaction gate. Callers + * source this from the per-chat counter (main session, subagents alike) — + * the service does not read or write any global telemetry. + */ + originalTokenCount: number; + signal?: AbortSignal; +} + export class ChatCompressionService { async compress( chat: GeminiChat, - promptId: string, - force: boolean, - model: string, - config: Config, - hasFailedCompressionAttempt: boolean, - signal?: AbortSignal, + opts: CompressOptions, ): Promise<{ newHistory: Content[] | null; info: ChatCompressionInfo }> { - const curatedHistory = chat.getHistory(true); + const { + promptId, + force, + model, + config, + hasFailedCompressionAttempt, + originalTokenCount, + signal, + } = opts; const threshold = config.getChatCompression()?.contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD; - // Regardless of `force`, don't do anything if the history is empty. - if ( - curatedHistory.length === 0 || - threshold <= 0 || - (hasFailedCompressionAttempt && !force) - ) { + // Cheap gates first — these don't need the curated history. + if (threshold <= 0 || (hasFailedCompressionAttempt && !force)) { return { newHistory: null, info: { @@ -126,9 +195,9 @@ export class ChatCompressionService { }; } - const originalTokenCount = uiTelemetryService.getLastPromptTokenCount(); - - // Don't compress if not forced and we are under the limit. + // Don't compress if not forced and we are under the limit. This is the + // steady-state path on every send; we want to exit before paying for the + // full `getHistory(true)` clone below. if (!force) { const contextLimit = config.getContentGeneratorConfig()?.contextWindowSize ?? @@ -145,6 +214,18 @@ export class ChatCompressionService { } } + const curatedHistory = chat.getHistory(true); + if (curatedHistory.length === 0) { + return { + newHistory: null, + info: { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }, + }; + } + // Fire PreCompact hook before compression begins const hookSystem = config.getHookSystem(); if (hookSystem) { @@ -181,6 +262,10 @@ export class ChatCompressionService { const historyToCompress = historyForSplit.slice(0, splitPoint); const historyToKeep = historyForSplit.slice(splitPoint); + // The in-flight fallback path may produce a kept slice starting with + // model+functionCall; the post-summary history needs a synthetic user + // between the summary's model_ack and the kept entries. + const keepNeedsContinuationBridge = historyToKeep[0]?.role === 'model'; if (historyToCompress.length === 0) { return { @@ -196,10 +281,6 @@ export class ChatCompressionService { // Guard: if historyToCompress is too small relative to the total history, // skip compression. This prevents futile API calls where the model receives // almost no context and generates a useless "summary" that inflates tokens. - // - // Note: findCompressSplitPoint already computes charCounts internally but - // returns only the split index. We intentionally recompute here to keep - // the function signature simple; this is a minor, acceptable duplication. const compressCharCount = historyToCompress.reduce( (sum, c) => sum + JSON.stringify(c).length, 0, @@ -274,6 +355,22 @@ export class ChatCompressionService { role: 'model', parts: [{ text: 'Got it. Thanks for the additional context!' }], }, + // When the kept slice starts with model+functionCall (because + // tool-round absorption pulled the only fresh user message into + // compress), inject a synthetic continuation prompt so the joined + // history alternates correctly. + ...(keepNeedsContinuationBridge + ? [ + { + role: 'user' as const, + parts: [ + { + text: 'Continue with the prior task using the context above.', + }, + ], + }, + ] + : []), ...historyToKeep, ]; @@ -339,8 +436,6 @@ export class ChatCompressionService { }, }; } else { - uiTelemetryService.setLastPromptTokenCount(newTokenCount); - // Fire SessionStart event after successful compression try { const permissionMode = String( diff --git a/packages/core/src/telemetry/metrics.ts b/packages/core/src/telemetry/metrics.ts index bcd577a28..727178049 100644 --- a/packages/core/src/telemetry/metrics.ts +++ b/packages/core/src/telemetry/metrics.ts @@ -959,7 +959,7 @@ export function recordMemoryDreamMetrics( durationMs: number, attrs: { trigger: 'auto' | 'manual'; - status: 'updated' | 'noop' | 'failed'; + status: 'updated' | 'noop' | 'failed' | 'cancelled'; deduped_entries: number; }, ): void { diff --git a/packages/core/src/telemetry/types.ts b/packages/core/src/telemetry/types.ts index 764147297..fb604d49a 100644 --- a/packages/core/src/telemetry/types.ts +++ b/packages/core/src/telemetry/types.ts @@ -1204,7 +1204,7 @@ export class MemoryDreamEvent implements BaseTelemetryEvent { 'event.timestamp': string; /** 'auto' = scheduler-triggered; 'manual' = user ran /dream */ trigger: 'auto' | 'manual'; - status: 'updated' | 'noop' | 'failed'; + status: 'updated' | 'noop' | 'failed' | 'cancelled'; deduped_entries: number; touched_topics_count: number; touched_topics: string; @@ -1212,7 +1212,7 @@ export class MemoryDreamEvent implements BaseTelemetryEvent { constructor(params: { trigger: 'auto' | 'manual'; - status: 'updated' | 'noop' | 'failed'; + status: 'updated' | 'noop' | 'failed' | 'cancelled'; deduped_entries: number; touched_topics: string[]; duration_ms: number; diff --git a/packages/core/src/tools/agent/agent.test.ts b/packages/core/src/tools/agent/agent.test.ts index d385fcdfb..150b0f03c 100644 --- a/packages/core/src/tools/agent/agent.test.ts +++ b/packages/core/src/tools/agent/agent.test.ts @@ -93,7 +93,24 @@ describe('AgentTool', () => { // Setup fake timers vi.useFakeTimers(); - // Create mock config + // Create mock config. The outer describe covers foreground execution + // paths, which now register/unregister in the BackgroundTaskRegistry + // to surface the run in the pill+dialog. A no-op stub registry is + // enough for these tests — they don't assert on registry behavior. + const stubRegistry = { + register: vi.fn(), + unregisterForeground: vi.fn(), + complete: vi.fn(), + fail: vi.fn(), + finalizeCancelled: vi.fn(), + finalizeCancellationIfPending: vi.fn(), + cancel: vi.fn(), + get: vi.fn(), + getAll: vi.fn().mockReturnValue([]), + drainMessages: vi.fn().mockReturnValue([]), + queueMessage: vi.fn(), + appendActivity: vi.fn(), + }; config = { getProjectRoot: vi.fn().mockReturnValue('/test/project'), getSessionId: vi.fn().mockReturnValue('test-session-id'), @@ -104,6 +121,7 @@ describe('AgentTool', () => { getTranscriptPath: vi.fn().mockReturnValue('/test/transcript'), getApprovalMode: vi.fn().mockReturnValue('default'), isTrustedFolder: vi.fn().mockReturnValue(true), + getBackgroundTaskRegistry: vi.fn().mockReturnValue(stubRegistry), } as unknown as Config; changeListeners = []; @@ -412,9 +430,13 @@ describe('AgentTool', () => { expect.any(Object), // config (may be approval-mode override) expect.any(Object), // eventEmitter parameter ); + // Foreground subagents now run with a composed AbortSignal so the + // dialog's per-agent cancel can abort just this child without aborting + // the parent turn. The signal received by the subagent is the + // controller's signal, not whatever the caller passed in. expect(mockAgent.execute).toHaveBeenCalledWith( mockContextState, - undefined, // signal parameter (undefined when not provided) + expect.any(AbortSignal), ); const llmText = partToString(result.llmContent); @@ -738,7 +760,10 @@ describe('AgentTool', () => { expect.stringContaining('file-search-'), 'file-search', PermissionMode.AutoEdit, - undefined, + // Foreground subagents now run with a composed signal (so the + // dialog can cancel just this child) — the hook receives the + // composed signal, not the caller-supplied one. + expect.any(AbortSignal), ); }); @@ -920,7 +945,8 @@ describe('AgentTool', () => { 'Task completed successfully', false, PermissionMode.AutoEdit, - undefined, + // Foreground subagents now run with a composed signal. + expect.any(AbortSignal), ); }); @@ -965,7 +991,8 @@ describe('AgentTool', () => { 'Task completed successfully', true, PermissionMode.AutoEdit, - undefined, + // Foreground subagents now run with a composed signal. + expect.any(AbortSignal), ); }); @@ -1419,10 +1446,12 @@ describe('AgentTool', () => { let mockContextState: ContextState; let mockRegistry: { register: ReturnType; + unregisterForeground: ReturnType; complete: ReturnType; fail: ReturnType; finalizeCancelled: ReturnType; drainMessages: ReturnType; + appendActivity: ReturnType; }; const bgSubagent: SubagentConfig = { @@ -1454,10 +1483,12 @@ describe('AgentTool', () => { mockRegistry = { register: vi.fn(), + unregisterForeground: vi.fn(), complete: vi.fn(), fail: vi.fn(), finalizeCancelled: vi.fn(), drainMessages: vi.fn().mockReturnValue([]), + appendActivity: vi.fn(), }; vi.mocked(config.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT); @@ -1597,7 +1628,57 @@ describe('AgentTool', () => { const llmText = partToString(result.llmContent); expect(llmText).not.toContain('Background agent launched'); - expect(mockRegistry.register).not.toHaveBeenCalled(); + // Foreground subagents register in the same registry with + // flavor: 'foreground' so the pill+dialog can surface them while + // the parent's tool-call awaits, then unregister in the finally + // path once the call returns. (The tool-result is the durable + // record — the entry does not persist.) + expect(mockRegistry.register).toHaveBeenCalledWith( + expect.objectContaining({ + flavor: 'foreground', + description: 'Search files', + subagentType: 'file-search', + status: 'running', + }), + ); + expect(mockRegistry.unregisterForeground).toHaveBeenCalledWith( + expect.stringContaining('file-search-'), + ); + }); + + it('foreground CANCELLED prefixes the partial result so the parent sees the cancel', async () => { + // Without this prefix, a user-cancelled foreground subagent returns + // the same `{ llmContent: [{ text: finalText }] }` shape as a + // successful run, leaving the parent model unable to tell that the + // partial result is incomplete. The background path surfaces this + // through the registry's `cancelled` XML envelope; + // the foreground path has no equivalent envelope, so the marker + // rides the llmContent payload itself. + const fgSubagent: SubagentConfig = { + ...bgSubagent, + name: 'file-search', + background: undefined, + }; + vi.mocked(mockSubagentManager.loadSubagent).mockResolvedValue(fgSubagent); + vi.mocked(mockAgent.getFinalText).mockReturnValue('halfway through'); + vi.mocked(mockAgent.getTerminateMode).mockReturnValue( + AgentTerminateMode.CANCELLED, + ); + + const params: AgentParams = { + description: 'Search files', + prompt: 'Find all TypeScript files', + subagent_type: 'file-search', + }; + + const invocation = ( + agentTool as AgentToolWithProtectedMethods + ).createInvocation(params); + const result = await invocation.execute(); + + const llmText = partToString(result.llmContent); + expect(llmText).toContain('Agent was cancelled by the user.'); + expect(llmText).toContain('halfway through'); }); it('should allow background in non-interactive mode (headless support)', async () => { diff --git a/packages/core/src/tools/agent/agent.ts b/packages/core/src/tools/agent/agent.ts index 9d8a31345..d66376d8d 100644 --- a/packages/core/src/tools/agent/agent.ts +++ b/packages/core/src/tools/agent/agent.ts @@ -1156,6 +1156,7 @@ class AgentToolInvocation extends BaseToolInvocation { agentId: hookOpts.agentId, description: this.params.description, subagentType: subagentConfig.name, + flavor: 'background', status: 'running', startTime: Date.now(), abortController: bgAbortController, @@ -1334,20 +1335,98 @@ class AgentToolInvocation extends BaseToolInvocation { // Same agent-identity frame as the background path: a foreground // subagent can also launch nested agents, and those nested launches // need to see this subagent's id as their `parentAgentId`. - const runFramed = () => - runWithAgentContext({ agentId: hookOpts.agentId }, () => - this.runSubagentWithHooks(subagent, contextState, hookOpts), - ); if (isFork) { // Background fork execution. Run under an AsyncLocalStorage frame so // nested `agent` tool calls by the fork's model can be detected. - void runInForkContext(runFramed); + // Forks run async (return a placeholder); skip foreground registration. + const runFramedFork = () => + runWithAgentContext({ agentId: hookOpts.agentId }, () => + this.runSubagentWithHooks(subagent, contextState, hookOpts), + ); + void runInForkContext(runFramedFork); return { llmContent: [{ text: FORK_PLACEHOLDER_RESULT }], returnDisplay: this.currentDisplay!, }; + } + + // ── Foreground (synchronous) execution path ──────────────── + // Compose a child AbortController so the dialog's per-agent cancel + // can abort just this subagent without aborting the parent turn. + // Parent abort still propagates down (so ESC at the parent kills + // the subagent), but child abort does NOT propagate up. + const fgAbortController = new AbortController(); + const onParentAbort = () => fgAbortController.abort(); + if (signal?.aborted) { + fgAbortController.abort(); } else { + signal?.addEventListener('abort', onParentAbort, { once: true }); + } + + const fgHookOpts = { ...hookOpts, signal: fgAbortController.signal }; + const runFramed = () => + runWithAgentContext({ agentId: hookOpts.agentId }, () => + this.runSubagentWithHooks(subagent, contextState, fgHookOpts), + ); + + // Register in BackgroundTaskRegistry with flavor:'foreground' so the + // pill counts the run and the dialog can drill in. Foreground entries + // skip XML notification and headless-holdback (see the registry for + // the gating logic). + const registry = this.config.getBackgroundTaskRegistry(); + registry.register({ + agentId: hookOpts.agentId, + description: this.params.description, + subagentType: hookOpts.agentType, + flavor: 'foreground', + status: 'running', + startTime: Date.now(), + abortController: fgAbortController, + prompt: this.params.prompt, + toolUseId: this.callId, + }); + + // Mirror the background path's progress wiring so the dialog detail + // body has live tool-call activity AND a current `entry.stats` + // subtitle (`N tools · X tokens · Ys`). Without this, foreground + // entries collapse to elapsed-only in the dialog while background + // entries show full stats — strictly less information for the same + // runtime events. + // + // This is a separate listener from setupEventListeners' TOOL_CALL + // handler (which feeds `currentDisplay.toolCalls` for the committed + // inline frame). They consume different state — committed inline UI + // vs. live registry stats — and setupEventListeners runs before we + // know the flavor or the registry id, so folding them is awkward. + let fgLiveToolCallCount = 0; + const refreshFgLiveStats = () => { + const entry = registry.get(hookOpts.agentId); + if (!entry || entry.status !== 'running') return; + const summary = subagent.getExecutionSummary(); + entry.stats = { + totalTokens: summary.totalTokens, + toolUses: fgLiveToolCallCount, + durationMs: summary.totalDurationMs, + }; + }; + const onFgToolCall = (...args: unknown[]) => { + const event = args[0] as AgentToolCallEvent; + fgLiveToolCallCount += 1; + refreshFgLiveStats(); + registry.appendActivity(hookOpts.agentId, { + name: event.name, + description: event.description, + at: event.timestamp, + }); + }; + const onFgUsageMetadata = () => { + refreshFgLiveStats(); + }; + this.eventEmitter.on(AgentEventType.TOOL_CALL, onFgToolCall); + this.eventEmitter.on(AgentEventType.USAGE_METADATA, onFgUsageMetadata); + + try { await runFramed(); const finalText = subagent.getFinalText(); const terminateMode = subagent.getTerminateMode(); @@ -1357,10 +1436,39 @@ class AgentToolInvocation extends BaseToolInvocation { returnDisplay: this.currentDisplay!, }; } + if (terminateMode === AgentTerminateMode.CANCELLED) { + // Distinguish a user-cancelled run from a successful complete in + // the parent model's tool result. Without this prefix, a cancel + // collapses into the same `{ llmContent: [{ text: finalText }] }` + // shape as a successful run — the parent can't tell that the + // partial result is incomplete and may act on it as if the agent + // had finished. The background path surfaces this via the + // `cancelled` XML envelope; the foreground path + // has no equivalent envelope, so the marker has to ride the + // llmContent payload itself. + const partial = finalText || '(no partial result captured)'; + return { + llmContent: [ + { + text: `Agent was cancelled by the user. Partial result follows:\n\n${partial}`, + }, + ], + returnDisplay: this.currentDisplay!, + }; + } return { llmContent: [{ text: finalText }], returnDisplay: this.currentDisplay!, }; + } finally { + this.eventEmitter.off(AgentEventType.TOOL_CALL, onFgToolCall); + this.eventEmitter.off(AgentEventType.USAGE_METADATA, onFgUsageMetadata); + signal?.removeEventListener('abort', onParentAbort); + // Foreground entries leave the registry as soon as the tool-call + // returns — the parent's tool-result is the durable record. Doing + // this in finally guarantees we clean up on success, failure, + // cancel, AND any unexpected throw inside runFramed. + registry.unregisterForeground(hookOpts.agentId); } } catch (error) { const errorMessage = diff --git a/packages/core/src/tools/glob.test.ts b/packages/core/src/tools/glob.test.ts index 448953ee1..a8debb9a2 100644 --- a/packages/core/src/tools/glob.test.ts +++ b/packages/core/src/tools/glob.test.ts @@ -105,6 +105,13 @@ describe('GlobTool', () => { expect(result.llmContent).toContain(path.join(tempRootDir, 'fileA.txt')); expect(result.llmContent).toContain(path.join(tempRootDir, 'FileB.TXT')); expect(result.returnDisplay).toBe('Found 2 matching file(s)'); + expect(result.resultFilePaths).toHaveLength(2); + expect(result.resultFilePaths).toContain( + path.join(tempRootDir, 'fileA.txt'), + ); + expect(result.resultFilePaths).toContain( + path.join(tempRootDir, 'FileB.TXT'), + ); }); it('should find files case-insensitively by default (pattern: *.TXT)', async () => { diff --git a/packages/core/src/tools/glob.ts b/packages/core/src/tools/glob.ts index a352e8706..24cb661f8 100644 --- a/packages/core/src/tools/glob.ts +++ b/packages/core/src/tools/glob.ts @@ -276,6 +276,7 @@ class GlobToolInvocation extends BaseToolInvocation< return { llmContent: resultMessage, returnDisplay: `Found ${totalFileCount} matching file(s)${truncated ? ' (truncated)' : ''}`, + resultFilePaths: sortedAbsolutePaths, }; } catch (error) { const errorMessage = diff --git a/packages/core/src/tools/grep.test.ts b/packages/core/src/tools/grep.test.ts index 14da1223a..1c6278401 100644 --- a/packages/core/src/tools/grep.test.ts +++ b/packages/core/src/tools/grep.test.ts @@ -97,6 +97,9 @@ describe('GrepTool', () => { } as unknown as Config; beforeEach(async () => { + Object.assign(mockConfig, { + getTruncateToolOutputThreshold: () => 25000, + }); tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'grep-tool-root-')); grepTool = new GrepTool(mockConfig); @@ -207,6 +210,73 @@ describe('GrepTool', () => { ); expect(result.llmContent).toContain('L1: another world in sub dir'); expect(result.returnDisplay).toBe('Found 3 matches'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, 'fileA.txt'), + path.join(tempRootDir, 'sub', 'fileC.txt'), + ]); + }); + + it('normalizes CRLF fallback grep output without dropping result paths', () => { + const invocationForPrivateMethod = grepTool.build({ + pattern: 'world', + }) as unknown as { + parseGrepOutput: ( + output: string, + basePath: string, + ) => Array<{ absoluteFilePath: string; line: string }>; + }; + const filePath = path.join(tempRootDir, 'crlf.txt'); + + const matches = invocationForPrivateMethod.parseGrepOutput( + `crlf.txt:1:hello world\r${os.EOL}`, + tempRootDir, + ); + + expect(matches[0]).toMatchObject({ + absoluteFilePath: filePath, + line: 'hello world', + }); + }); + + it('includes result paths for partially rendered match lines', async () => { + Object.assign(mockConfig, { + getTruncateToolOutputThreshold: () => 22, + }); + await fs.writeFile( + path.join(tempRootDir, 'partial.ts'), + 'partial marker', + ); + + const invocation = grepTool.build({ pattern: 'marker', glob: '*.ts' }); + const result = await invocation.execute(abortSignal); + + expect(result.returnDisplay).toContain('truncated'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, 'partial.ts'), + ]); + }); + + it('only reports result paths for matches visible before character truncation', async () => { + Object.assign(mockConfig, { + getTruncateToolOutputThreshold: () => 30, + }); + await fs.writeFile(path.join(tempRootDir, 'a.ts'), 'visible marker'); + await fs.writeFile(path.join(tempRootDir, 'z.ts'), 'hidden marker'); + + const invocation = grepTool.build({ pattern: 'marker', glob: '*.ts' }); + const result = await invocation.execute(abortSignal); + + const allResultPaths = [ + path.join(tempRootDir, 'a.ts'), + path.join(tempRootDir, 'z.ts'), + ]; + expect(result.returnDisplay).toContain('truncated'); + expect(result.resultFilePaths?.length).toBeLessThan( + allResultPaths.length, + ); + for (const resultPath of result.resultFilePaths ?? []) { + expect(allResultPaths).toContain(resultPath); + } }); it('should find matches in a specific path', async () => { diff --git a/packages/core/src/tools/grep.ts b/packages/core/src/tools/grep.ts index f61f66bf0..2347c6ebb 100644 --- a/packages/core/src/tools/grep.ts +++ b/packages/core/src/tools/grep.ts @@ -62,6 +62,7 @@ export interface GrepToolParams { */ interface GrepMatch { filePath: string; + absoluteFilePath: string; lineNumber: number; line: string; } @@ -209,20 +210,38 @@ class GrepToolInvocation extends BaseToolInvocation< // Build grep output let grepOutput = ''; - for (const filePath in matchesByFile) { - grepOutput += `File: ${filePath}\n`; - matchesByFile[filePath].forEach((match) => { - const trimmedLine = match.line.trim(); - grepOutput += `L${match.lineNumber}: ${trimmedLine}\n`; - }); - grepOutput += '---\n'; - } - - // Apply character limit as safety net + const visibleMatches: GrepMatch[] = []; let truncatedByCharLimit = false; - if (Number.isFinite(charLimit) && grepOutput.length > charLimit) { - grepOutput = grepOutput.slice(0, charLimit) + '...'; - truncatedByCharLimit = true; + const appendChunk = (chunk: string, match?: GrepMatch): boolean => { + if ( + Number.isFinite(charLimit) && + grepOutput.length + chunk.length > charLimit + ) { + grepOutput += chunk.slice( + 0, + Math.max(charLimit - grepOutput.length, 0), + ); + grepOutput += '...'; + if (match) visibleMatches.push(match); + truncatedByCharLimit = true; + return false; + } + grepOutput += chunk; + if (match) visibleMatches.push(match); + return true; + }; + + for (const filePath in matchesByFile) { + if (!appendChunk(`File: ${filePath}\n`)) break; + let stopRendering = false; + for (const match of matchesByFile[filePath]) { + const trimmedLine = match.line.trim(); + if (!appendChunk(`L${match.lineNumber}: ${trimmedLine}\n`, match)) { + stopRendering = true; + break; + } + } + if (stopRendering || !appendChunk('---\n')) break; } // Count how many lines we actually included after character truncation @@ -252,6 +271,13 @@ class GrepToolInvocation extends BaseToolInvocation< return { llmContent: llmContent.trim(), returnDisplay: displayMessage, + resultFilePaths: Array.from( + new Set( + visibleMatches + .map((match) => match.absoluteFilePath) + .filter((filePath) => filePath !== ''), + ), + ), }; } catch (error) { debugLogger.error(`Error during GrepLogic execution: ${error}`); @@ -308,8 +334,9 @@ class GrepToolInvocation extends BaseToolInvocation< results.push({ filePath: relativeFilePath || path.basename(absoluteFilePath), + absoluteFilePath, lineNumber, - line: lineContent, + line: lineContent.replace(/\r$/, ''), }); } } @@ -531,6 +558,7 @@ class GrepToolInvocation extends BaseToolInvocation< filePath: path.relative(absolutePath, fileAbsolutePath) || path.basename(fileAbsolutePath), + absoluteFilePath: fileAbsolutePath, lineNumber: index + 1, line, }); diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index 140b78324..cbc927699 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -227,6 +227,214 @@ describe('McpClientManager', () => { expect(secondClient.disconnect).toHaveBeenCalledOnce(); }); + it('should coalesce concurrent discovery for the same server', async () => { + let resolveDisconnect!: () => void; + const disconnectPromise = new Promise((resolve) => { + resolveDisconnect = resolve; + }); + const firstClient = { + connect: vi.fn().mockResolvedValue(undefined), + discover: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn(() => disconnectPromise), + getStatus: vi.fn(), + }; + const replacementClients: Array<{ + connect: ReturnType; + discover: ReturnType; + disconnect: ReturnType; + getStatus: ReturnType; + }> = []; + + vi.mocked(McpClient).mockImplementation(() => { + if (vi.mocked(McpClient).mock.calls.length === 1) { + return firstClient as unknown as McpClient; + } + + const replacementClient = { + connect: vi.fn().mockResolvedValue(undefined), + discover: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + replacementClients.push(replacementClient); + return replacementClient as unknown as McpClient; + }); + + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + + const firstRediscovery = manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + await Promise.resolve(); + + const secondRediscovery = manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + const disconnectCallsBeforeResolve = + firstClient.disconnect.mock.calls.length; + + resolveDisconnect(); + await Promise.all([firstRediscovery, secondRediscovery]); + + expect(disconnectCallsBeforeResolve).toBe(1); + expect(vi.mocked(McpClient)).toHaveBeenCalledTimes(2); + expect(replacementClients).toHaveLength(1); + expect(replacementClients[0].connect).toHaveBeenCalledOnce(); + expect(replacementClients[0].discover).toHaveBeenCalledOnce(); + + // Verify map was cleaned up: a third call should do real work, + // not get coalesced into a stale promise. + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + + expect(vi.mocked(McpClient)).toHaveBeenCalledTimes(3); + expect(replacementClients).toHaveLength(2); + expect(replacementClients[1].connect).toHaveBeenCalledOnce(); + expect(replacementClients[1].discover).toHaveBeenCalledOnce(); + }); + + it('should restore health checks after failed server rediscovery', async () => { + vi.useFakeTimers(); + + const firstClient = { + connect: vi.fn().mockResolvedValue(undefined), + discover: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + const failedClient = { + connect: vi.fn().mockRejectedValue(new Error('transient failure')), + discover: vi.fn(), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + vi.mocked(McpClient) + .mockReturnValueOnce(firstClient as unknown as McpClient) + .mockReturnValueOnce(failedClient as unknown as McpClient); + + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager( + mockConfig, + {} as ToolRegistry, + undefined, + undefined, + { + autoReconnect: true, + checkIntervalMs: 10, + maxConsecutiveFailures: 1, + reconnectDelayMs: 10, + }, + ); + + try { + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + expect( + ( + manager as unknown as { + healthCheckTimers: Map; + } + ).healthCheckTimers.has('test-server'), + ).toBe(true); + + await manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + + expect(failedClient.connect).toHaveBeenCalledOnce(); + expect( + ( + manager as unknown as { + healthCheckTimers: Map; + } + ).healthCheckTimers.has('test-server'), + ).toBe(true); + } finally { + await manager.stop(); + vi.useRealTimers(); + } + }); + + it('should clear in-flight discovery tracking when stopping', async () => { + let resolveConnect!: () => void; + const connectPromise = new Promise((resolve) => { + resolveConnect = resolve; + }); + const mockedMcpClient = { + connect: vi.fn(() => connectPromise), + discover: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockResolvedValue(undefined), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + + const mockConfig = { + isTrustedFolder: () => true, + getMcpServers: () => ({ 'test-server': {} }), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({}) as PromptRegistry, + getWorkspaceContext: () => ({}) as WorkspaceContext, + getDebugMode: () => false, + } as unknown as Config; + const manager = new McpClientManager(mockConfig, {} as ToolRegistry); + + const discovery = manager.discoverMcpToolsForServer( + 'test-server', + {} as unknown as Config, + ); + await Promise.resolve(); + + expect( + ( + manager as unknown as { + serverDiscoveryPromises: Map>; + } + ).serverDiscoveryPromises.has('test-server'), + ).toBe(true); + + await manager.stop(); + + expect( + ( + manager as unknown as { + serverDiscoveryPromises: Map>; + } + ).serverDiscoveryPromises.has('test-server'), + ).toBe(false); + + resolveConnect(); + await discovery; + }); + it('should no-op when discovering an unknown server', async () => { const mockedMcpClient = { connect: vi.fn(), diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index ecc700739..885700abb 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -58,6 +58,7 @@ export class McpClientManager { private healthCheckTimers: Map = new Map(); private consecutiveFailures: Map = new Map(); private isReconnecting: Map = new Map(); + private serverDiscoveryPromises: Map> = new Map(); constructor( config: Config, @@ -147,6 +148,31 @@ export class McpClientManager { async discoverMcpToolsForServer( serverName: string, cliConfig: Config, + ): Promise { + const inProgressDiscovery = this.serverDiscoveryPromises.get(serverName); + if (inProgressDiscovery) { + await inProgressDiscovery; + return; + } + + const discoveryPromise = this.discoverMcpToolsForServerInternal( + serverName, + cliConfig, + ); + this.serverDiscoveryPromises.set(serverName, discoveryPromise); + + try { + await discoveryPromise; + } finally { + if (this.serverDiscoveryPromises.get(serverName) === discoveryPromise) { + this.serverDiscoveryPromises.delete(serverName); + } + } + } + + private async discoverMcpToolsForServerInternal( + serverName: string, + cliConfig: Config, ): Promise { const servers = populateMcpServerCommand( this.cliConfig.getMcpServers() || {}, @@ -157,6 +183,8 @@ export class McpClientManager { return; } + this.stopHealthCheck(serverName); + // Ensure we don't leak an existing connection for this server. const existingClient = this.clients.get(serverName); if (existingClient) { @@ -193,8 +221,6 @@ export class McpClientManager { try { await client.connect(); await client.discover(cliConfig); - // Start health check for this server after successful discovery - this.startHealthCheck(serverName); } catch (error) { // Log the error but don't throw: callers expect best-effort discovery. debugLogger.error( @@ -203,6 +229,7 @@ export class McpClientManager { )}`, ); } finally { + this.startHealthCheck(serverName); this.eventEmitter?.emit('mcp-client-update', this.clients); } } @@ -231,6 +258,7 @@ export class McpClientManager { this.clients.clear(); this.consecutiveFailures.clear(); this.isReconnecting.clear(); + this.serverDiscoveryPromises.clear(); } /** @@ -253,6 +281,7 @@ export class McpClientManager { this.clients.delete(serverName); this.consecutiveFailures.delete(serverName); this.isReconnecting.delete(serverName); + this.serverDiscoveryPromises.delete(serverName); this.eventEmitter?.emit('mcp-client-update', this.clients); } } diff --git a/packages/core/src/tools/ripGrep.test.ts b/packages/core/src/tools/ripGrep.test.ts index c02ccdee7..f3e9888a7 100644 --- a/packages/core/src/tools/ripGrep.test.ts +++ b/packages/core/src/tools/ripGrep.test.ts @@ -41,6 +41,7 @@ describe('RipGrepTool', () => { let grepTool: RipGrepTool; let fileExclusionsMock: { getGlobExcludes: () => string[] }; const abortSignal = new AbortController().signal; + const sep = '\x1f'; const mockConfig = { getTargetDir: () => tempRootDir, @@ -55,6 +56,9 @@ describe('RipGrepTool', () => { beforeEach(async () => { vi.clearAllMocks(); mockSpawn.mockReset(); + Object.assign(mockConfig, { + getTruncateToolOutputThreshold: () => 25000, + }); tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'grep-tool-root-')); fileExclusionsMock = { getGlobExcludes: vi.fn().mockReturnValue([]), @@ -160,7 +164,7 @@ describe('RipGrepTool', () => { describe('execute', () => { it('should find matches for a simple pattern in all files', async () => { (runRipgrep as Mock).mockResolvedValue({ - stdout: `fileA.txt:1:hello world${EOL}fileA.txt:2:second line with world${EOL}sub/fileC.txt:1:another world in sub dir${EOL}`, + stdout: `fileA.txt${sep}1${sep}hello world${EOL}fileA.txt${sep}2${sep}second line with world${EOL}sub/fileC.txt${sep}1${sep}another world in sub dir${EOL}`, truncated: false, error: undefined, }); @@ -177,6 +181,171 @@ describe('RipGrepTool', () => { 'sub/fileC.txt:1:another world in sub dir', ); expect(result.returnDisplay).toBe('Found 3 matches'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, 'fileA.txt'), + path.join(tempRootDir, 'sub/fileC.txt'), + ]); + }); + + it('should treat summary-only JSON output as no matches', async () => { + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${JSON.stringify({ type: 'summary', data: { stats: { matches: 0 } } })}${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'missing' }); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toBe( + 'No matches found for pattern "missing" in the workspace directory.', + ); + expect(result.returnDisplay).toBe('No matches found'); + }); + + it('parses JSON match events and records result paths', async () => { + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${JSON.stringify({ type: 'match', data: { path: { text: 'src/foo.ts' }, lines: { text: 'content\n' }, line_number: 5 } })}${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'content' }); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('src/foo.ts:5:content'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, 'src/foo.ts'), + ]); + }); + + it('parses JSON match events with byte-encoded paths', async () => { + const bytePath = 'src/byte-path.ts'; + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${JSON.stringify({ type: 'match', data: { path: { bytes: Buffer.from(bytePath, 'utf8').toString('base64') }, lines: { text: 'content\n' }, line_number: 3 } })}${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'content' }); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('src/byte-path.ts:3:content'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, bytePath), + ]); + }); + + it('handles JSON match events without a lines field', async () => { + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${JSON.stringify({ type: 'match', data: { path: { text: 'fileA.txt' }, line_number: 1 } })}${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'hello' }); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('fileA.txt:1:'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, 'fileA.txt'), + ]); + }); + + it('surfaces ripgrep system-level truncation in display metadata', async () => { + (runRipgrep as Mock).mockResolvedValue({ + stdout: `fileA.txt${sep}1${sep}hello world${EOL}`, + truncated: true, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'hello' }); + const result = await invocation.execute(abortSignal); + + expect(result.returnDisplay).toBe('Found 1 match (truncated)'); + expect(result.llmContent).toContain('[0 lines truncated] ...'); + }); + + it('should preserve absolute result paths reported by ripgrep', async () => { + const absoluteMatchPath = path.join( + tempRootDir, + 'packages/core/src/skills/target.ts', + ); + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${absoluteMatchPath}${sep}1${sep}CORE_HELPER_TARGET_MARKER${EOL}`, + truncated: false, + error: undefined, + }); + + const params: RipGrepToolParams = { + pattern: 'CORE_HELPER_TARGET_MARKER', + glob: '**/*.ts', + }; + const invocation = grepTool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.resultFilePaths).toEqual([absoluteMatchPath]); + }); + + it('should parse Windows-style absolute result paths reported by ripgrep', async () => { + const absoluteMatchPath = + 'C:\\repo\\packages\\core\\src\\skills\\target.ts'; + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${absoluteMatchPath}${sep}12${sep}CORE_HELPER_TARGET_MARKER${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ + pattern: 'CORE_HELPER_TARGET_MARKER', + glob: '**/*.ts', + }); + const result = await invocation.execute(abortSignal); + + expect(result.resultFilePaths).toEqual([absoluteMatchPath]); + }); + + it('includes result paths for partially rendered long file paths', async () => { + Object.assign(mockConfig, { + getTruncateToolOutputThreshold: () => 30, + }); + const longPath = 'packages/core/src/skills/very-long-named-file.ts'; + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${longPath}${sep}1${sep}visible marker${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'marker', glob: '**/*.ts' }); + const result = await invocation.execute(abortSignal); + + expect(result.returnDisplay).toContain('truncated'); + expect(result.llmContent).toContain('packages/core/src/skills/very'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, longPath), + ]); + }); + + it('only reports result paths for lines reached before character truncation', async () => { + Object.assign(mockConfig, { + getTruncateToolOutputThreshold: () => 25, + }); + const visiblePath = 'a.ts'; + const hiddenPath = 'hidden-file-with-long-name.ts'; + (runRipgrep as Mock).mockResolvedValue({ + stdout: `${visiblePath}${sep}1${sep}visible marker${EOL}${hiddenPath}${sep}1${sep}hidden marker${EOL}`, + truncated: false, + error: undefined, + }); + + const invocation = grepTool.build({ pattern: 'marker', glob: '**/*.ts' }); + const result = await invocation.execute(abortSignal); + + expect(result.returnDisplay).toContain('truncated'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, visiblePath), + path.join(tempRootDir, hiddenPath), + ]); }); it('should find matches in a specific path', async () => { @@ -470,7 +639,7 @@ describe('RipGrepTool', () => { const multiDirGrepTool = new RipGrepTool(multiDirConfig); (runRipgrep as Mock).mockResolvedValue({ - stdout: `fileA.txt:1:hello world${EOL}${secondDir}/extra.txt:1:hello from second dir${EOL}`, + stdout: `fileA.txt${sep}1${sep}hello world${EOL}${secondDir}${path.sep}extra.txt${sep}1${sep}hello from second dir${EOL}`, truncated: false, error: undefined, }); @@ -481,10 +650,19 @@ describe('RipGrepTool', () => { expect(result.llmContent).toContain('across 2 workspace directories'); expect(result.llmContent).toContain('Found 2 matches'); + expect(result.resultFilePaths).toEqual([ + path.join(tempRootDir, 'fileA.txt'), + path.join(secondDir, 'extra.txt'), + ]); // Verify both paths were passed to runRipgrep expect(runRipgrep).toHaveBeenCalledWith( - expect.arrayContaining([tempRootDir, secondDir]), + expect.arrayContaining([ + '--json', + '--no-messages', + tempRootDir, + secondDir, + ]), expect.anything(), ); @@ -575,7 +753,7 @@ describe('RipGrepTool', () => { const multiDirGrepTool = new RipGrepTool(multiDirConfig); // Simulate ripgrep returning the same file:line twice (once from each search root) - const dupLine = `${path.join(subDir, 'fileC.txt')}:1:hello world`; + const dupLine = `${path.join(subDir, 'fileC.txt')}${sep}1${sep}hello world`; (runRipgrep as Mock).mockResolvedValue({ stdout: `${dupLine}${EOL}${dupLine}${EOL}`, truncated: false, diff --git a/packages/core/src/tools/ripGrep.ts b/packages/core/src/tools/ripGrep.ts index 1c218b019..eddb46688 100644 --- a/packages/core/src/tools/ripGrep.ts +++ b/packages/core/src/tools/ripGrep.ts @@ -20,6 +20,44 @@ import { createDebugLogger } from '../utils/debugLogger.js'; import type { PermissionDecision } from '../permissions/types.js'; const debugLogger = createDebugLogger('RIPGREP'); +const RIPGREP_FIELD_SEPARATOR = ''; + +interface RipgrepJsonMatch { + type: 'match'; + data: { + path: { text?: string; bytes?: string }; + lines?: { text?: string }; + line_number: number; + }; +} + +function isRipgrepJsonMatch(value: unknown): value is RipgrepJsonMatch { + if (typeof value !== 'object' || value === null) return false; + const candidate = value as { + type?: unknown; + data?: { + path?: { text?: unknown; bytes?: unknown }; + lines?: { text?: unknown }; + line_number?: unknown; + }; + }; + return ( + candidate.type === 'match' && + (typeof candidate.data?.path?.text === 'string' || + typeof candidate.data?.path?.bytes === 'string') && + typeof candidate.data?.line_number === 'number' + ); +} + +function getRipgrepJsonPath(match: RipgrepJsonMatch): string | undefined { + if (match.data.path.text !== undefined) { + return match.data.path.text; + } + if (match.data.path.bytes !== undefined) { + return Buffer.from(match.data.path.bytes, 'base64').toString('utf8'); + } + return undefined; +} /** * Per-process cache for `.qwenignore` discovery. The same directories show @@ -44,6 +82,19 @@ function trimCache(m: Map): void { if (oldest !== undefined) m.delete(oldest as K); } +function toAbsoluteResultPath(filePath: string, searchPaths: string[]): string { + if (path.isAbsolute(filePath) || path.win32.isAbsolute(filePath)) { + return filePath; + } + for (const searchPath of searchPaths) { + const candidate = path.resolve(searchPath, filePath); + if (fs.existsSync(candidate)) { + return candidate; + } + } + return path.resolve(searchPaths[0], filePath); +} + /** * Test-only: clear ripGrep's module-level discovery caches between cases. */ @@ -132,12 +183,13 @@ class GrepToolInvocation extends BaseToolInvocation< } // Get raw ripgrep output - const rawOutput = await this.performRipgrepSearch({ - pattern: this.params.pattern, - paths: searchPaths, - glob: this.params.glob, - signal, - }); + const { stdout: rawOutput, truncated: truncatedBySystemLimit } = + await this.performRipgrepSearch({ + pattern: this.params.pattern, + paths: searchPaths, + glob: this.params.glob, + signal, + }); // Build search description const searchLocationDescription = this.params.path @@ -156,29 +208,81 @@ class GrepToolInvocation extends BaseToolInvocation< return { llmContent: noMatchMsg, returnDisplay: `No matches found` }; } - // Split into lines and count total matches - let allLines = rawOutput.split('\n').filter((line) => line.trim()); + interface RipgrepMatchLine { + rawLine: string; + filePath: string; + key: string; + } + + let allLines = rawOutput + .split('\n') + .filter((line) => line.trim()) + .flatMap((line): RipgrepMatchLine[] => { + if (line.startsWith('{')) { + if (!line.startsWith('{"type":"match"')) return []; + try { + const parsed = JSON.parse(line) as unknown; + if (!isRipgrepJsonMatch(parsed)) return []; + const filePath = getRipgrepJsonPath(parsed); + if (filePath === undefined) return []; + const lineNumber = String(parsed.data.line_number); + const content = parsed.data.lines?.text ?? ''; + return [ + { + rawLine: `${filePath}:${lineNumber}:${content.replace(/\r?\n$/, '')}`, + filePath, + key: `${filePath}:${lineNumber}`, + }, + ]; + } catch { + return []; + } + } + + const fields = line.split(RIPGREP_FIELD_SEPARATOR); + if (fields.length === 1) { + const firstColon = line.indexOf(':'); + const secondColon = + firstColon === -1 ? -1 : line.indexOf(':', firstColon + 1); + if (firstColon === -1 || secondColon === -1) return []; + const filePath = line.substring(0, firstColon); + const lineNumber = line.substring(firstColon + 1, secondColon); + if (!/^[0-9]+$/.test(lineNumber)) return []; + return [ + { + rawLine: line, + filePath, + key: `${filePath}:${lineNumber}`, + }, + ]; + } + if (fields.length !== 3) return []; + const [filePath, lineNumber, content] = fields; + return [ + { + rawLine: `${filePath}:${lineNumber}:${content}`, + filePath, + key: `${filePath}:${lineNumber}`, + }, + ]; + }); // Deduplicate lines from potentially overlapping workspace directories. // ripgrep reports the same file twice when given paths like /a and /a/sub. if (searchPaths.length > 1) { const seen = new Set(); allLines = allLines.filter((line) => { - // ripgrep output format: filepath:linenum:content - const firstColon = line.indexOf(':'); - if (firstColon !== -1) { - const secondColon = line.indexOf(':', firstColon + 1); - if (secondColon !== -1) { - const key = line.substring(0, secondColon); - if (seen.has(key)) return false; - seen.add(key); - } - } + if (seen.has(line.key)) return false; + seen.add(line.key); return true; }); } const totalMatches = allLines.length; + if (totalMatches === 0) { + const noMatchMsg = `No matches found for pattern "${this.params.pattern}" ${searchLocationDescription}${filterDescription}.`; + return { llmContent: noMatchMsg, returnDisplay: `No matches found` }; + } const matchTerm = totalMatches === 1 ? 'match' : 'matches'; // Build header early to calculate available space @@ -202,21 +306,24 @@ class GrepToolInvocation extends BaseToolInvocation< let grepOutput = ''; let truncatedByCharLimit = false; let includedLines = 0; + const visibleLines: RipgrepMatchLine[] = []; if (Number.isFinite(charLimit)) { const parts: string[] = []; let currentLength = 0; for (const line of linesToInclude) { const sep = includedLines > 0 ? 1 : 0; - includedLines++; - - const projectedLength = currentLength + line.length + sep; + const projectedLength = currentLength + line.rawLine.length + sep; if (projectedLength <= charLimit) { - parts.push(line); + parts.push(line.rawLine); + visibleLines.push(line); + includedLines++; currentLength = projectedLength; } else { const remaining = Math.max(charLimit - currentLength - sep, 10); - parts.push(line.slice(0, remaining) + '...'); + const partialLine = line.rawLine.slice(0, remaining); + parts.push(partialLine + '...'); + visibleLines.push(line); truncatedByCharLimit = true; break; } @@ -224,7 +331,8 @@ class GrepToolInvocation extends BaseToolInvocation< grepOutput = parts.join('\n'); } else { - grepOutput = linesToInclude.join('\n'); + grepOutput = linesToInclude.map((line) => line.rawLine).join('\n'); + visibleLines.push(...linesToInclude); includedLines = linesToInclude.length; } @@ -232,20 +340,37 @@ class GrepToolInvocation extends BaseToolInvocation< let llmContent = header + grepOutput; // Add truncation notice if needed - if (truncatedByLineLimit || truncatedByCharLimit) { + if ( + truncatedByLineLimit || + truncatedByCharLimit || + truncatedBySystemLimit + ) { const omittedMatches = totalMatches - includedLines; llmContent += `\n---\n[${omittedMatches} ${omittedMatches === 1 ? 'line' : 'lines'} truncated] ...`; } // Build display message (show real count, not truncated) let displayMessage = `Found ${totalMatches} ${matchTerm}`; - if (truncatedByLineLimit || truncatedByCharLimit) { + if ( + truncatedByLineLimit || + truncatedByCharLimit || + truncatedBySystemLimit + ) { displayMessage += ` (truncated)`; } + const resultFilePaths = Array.from( + new Set( + visibleLines.map((line) => + toAbsoluteResultPath(line.filePath, searchPaths), + ), + ), + ); + return { llmContent: llmContent.trim(), returnDisplay: displayMessage, + resultFilePaths, }; } catch (error) { debugLogger.error('Error during ripgrep search operation:', error); @@ -262,13 +387,14 @@ class GrepToolInvocation extends BaseToolInvocation< paths: string[]; // Can be files or directories glob?: string; signal: AbortSignal; - }): Promise { + }): Promise<{ stdout: string; truncated: boolean }> { const { pattern, paths, glob } = options; const rgArgs: string[] = [ - '--line-number', - '--no-heading', - '--with-filename', + '--json', + '--no-messages', + '--path-separator', + '/', '--ignore-case', '--regexp', pattern, @@ -323,7 +449,7 @@ class GrepToolInvocation extends BaseToolInvocation< throw result.error; } - return result.stdout; + return { stdout: result.stdout, truncated: result.truncated }; } private getFileFilteringOptions(): FileFilteringOptions { diff --git a/packages/core/src/tools/task-stop.test.ts b/packages/core/src/tools/task-stop.test.ts index d6d8fc4d8..ed2c500d5 100644 --- a/packages/core/src/tools/task-stop.test.ts +++ b/packages/core/src/tools/task-stop.test.ts @@ -25,11 +25,19 @@ describe('TaskStopTool', () => { abandonBackgroundAgent = vi.fn(); shellRegistry = new BackgroundShellRegistry(); monitorRegistry = new MonitorRegistry(); + // Default fake MemoryManager — every test that doesn't care about + // dream gets an empty stub so the 4th-route lookup falls through to + // the not-found branch instead of crashing on undefined. + const memoryManager = { + getTask: vi.fn(() => undefined), + cancelTask: vi.fn(() => false), + }; config = { getBackgroundTaskRegistry: () => registry, abandonBackgroundAgent, getBackgroundShellRegistry: () => shellRegistry, getMonitorRegistry: () => monitorRegistry, + getMemoryManager: () => memoryManager, } as unknown as Config; tool = new TaskStopTool(config); }); @@ -268,4 +276,156 @@ describe('TaskStopTool', () => { expect(result.llmContent).toContain('completed'); }); }); + + describe('dream task support', () => { + it('cancels a running dream by routing through MemoryManager.cancelTask', async () => { + const cancelTask = vi.fn(() => true); + const dreamRecord = { + id: 'dream-running-1', + taskType: 'dream' as const, + projectRoot: '/p', + status: 'running' as const, + createdAt: '2026-05-04T12:00:00.000Z', + updatedAt: '2026-05-04T12:00:00.000Z', + }; + const memoryManager = { + getTask: vi.fn((id: string) => + id === 'dream-running-1' ? dreamRecord : undefined, + ), + cancelTask, + }; + const localConfig = { + getBackgroundTaskRegistry: () => registry, + abandonBackgroundAgent, + getBackgroundShellRegistry: () => shellRegistry, + getMonitorRegistry: () => monitorRegistry, + getMemoryManager: () => memoryManager, + } as unknown as Config; + const localTool = new TaskStopTool(localConfig); + + const result = await localTool.validateBuildAndExecute( + { task_id: 'dream-running-1' }, + new AbortController().signal, + ); + + expect(cancelTask).toHaveBeenCalledWith('dream-running-1'); + expect(result.error).toBeUndefined(); + expect(result.llmContent).toContain('Cancellation requested'); + expect(result.llmContent).toContain('dream task "dream-running-1"'); + }); + + it('returns NOT_RUNNING when the dream is already terminal', async () => { + // Mirrors the agent / shell / monitor not-running guards so a + // model retry against an already-finished dream surfaces the + // distinct error type instead of "not found". + const dreamRecord = { + id: 'dream-done-1', + taskType: 'dream' as const, + projectRoot: '/p', + status: 'completed' as const, + createdAt: '2026-05-04T12:00:00.000Z', + updatedAt: '2026-05-04T12:01:00.000Z', + }; + const cancelTask = vi.fn(() => false); + const memoryManager = { + getTask: vi.fn(() => dreamRecord), + cancelTask, + }; + const localConfig = { + getBackgroundTaskRegistry: () => registry, + abandonBackgroundAgent, + getBackgroundShellRegistry: () => shellRegistry, + getMonitorRegistry: () => monitorRegistry, + getMemoryManager: () => memoryManager, + } as unknown as Config; + const localTool = new TaskStopTool(localConfig); + + const result = await localTool.validateBuildAndExecute( + { task_id: 'dream-done-1' }, + new AbortController().signal, + ); + + expect(cancelTask).not.toHaveBeenCalled(); + expect(result.error?.type).toBe(ToolErrorType.TASK_STOP_NOT_RUNNING); + expect(result.llmContent).toContain('Background dream "dream-done-1"'); + expect(result.llmContent).toContain('completed'); + }); + + it('returns NOT_CANCELLABLE when the task id resolves to an extract record', async () => { + // Extract is short-lived and runs on the request path; cancelling + // it would interfere with the user's own turn. The dispatch must + // distinguish "task exists but isn't cancellable" from "task + // doesn't exist" — without the distinct error type, a model + // retrying against an extract id would incorrectly conclude the + // id was never valid. + const extractRecord = { + id: 'extract-running-1', + taskType: 'extract' as const, + projectRoot: '/p', + status: 'running' as const, + createdAt: '2026-05-04T12:00:00.000Z', + updatedAt: '2026-05-04T12:00:00.000Z', + }; + const cancelTask = vi.fn(); + const memoryManager = { + getTask: vi.fn(() => extractRecord), + cancelTask, + }; + const localConfig = { + getBackgroundTaskRegistry: () => registry, + abandonBackgroundAgent, + getBackgroundShellRegistry: () => shellRegistry, + getMonitorRegistry: () => monitorRegistry, + getMemoryManager: () => memoryManager, + } as unknown as Config; + const localTool = new TaskStopTool(localConfig); + + const result = await localTool.validateBuildAndExecute( + { task_id: 'extract-running-1' }, + new AbortController().signal, + ); + + expect(cancelTask).not.toHaveBeenCalled(); + expect(result.error?.type).toBe(ToolErrorType.TASK_STOP_NOT_CANCELLABLE); + expect(result.llmContent).toContain('extract'); + expect(result.llmContent).toContain('not cancellable'); + }); + + it('returns an error when cancelTask returns false (missing AbortController)', async () => { + // The MemoryManager.cancelTask contract returns false when the + // AbortController is missing for a running record — a logic- + // level invariant violation. task_stop must surface the failure + // rather than report a phantom success, otherwise the model + // believes the dream is being aborted while it actually keeps + // burning tokens. + const dreamRecord = { + id: 'dream-broken-1', + taskType: 'dream' as const, + projectRoot: '/p', + status: 'running' as const, + createdAt: '2026-05-04T12:00:00.000Z', + updatedAt: '2026-05-04T12:00:00.000Z', + }; + const memoryManager = { + getTask: vi.fn(() => dreamRecord), + cancelTask: vi.fn(() => false), + }; + const localConfig = { + getBackgroundTaskRegistry: () => registry, + abandonBackgroundAgent, + getBackgroundShellRegistry: () => shellRegistry, + getMonitorRegistry: () => monitorRegistry, + getMemoryManager: () => memoryManager, + } as unknown as Config; + const localTool = new TaskStopTool(localConfig); + + const result = await localTool.validateBuildAndExecute( + { task_id: 'dream-broken-1' }, + new AbortController().signal, + ); + + expect(result.error?.type).toBe(ToolErrorType.TASK_STOP_INTERNAL_ERROR); + expect(result.llmContent).toContain('could not be cancelled'); + }); + }); }); diff --git a/packages/core/src/tools/task-stop.ts b/packages/core/src/tools/task-stop.ts index da9df034b..bbe758062 100644 --- a/packages/core/src/tools/task-stop.ts +++ b/packages/core/src/tools/task-stop.ts @@ -136,6 +136,67 @@ class TaskStopInvocation extends BaseToolInvocation< }; } + // MemoryManager memory tasks (dream + extract). Memory tasks live + // outside the registry trio (MemoryManager owns its own task map). + // Only `dream` is cancellable — extract is short-lived and runs on + // the request loop, so cancelling it would interfere with the + // user's own turn. Surface a distinct error for known-but-not- + // cancellable records so the model doesn't conclude the id was + // never valid (which would happen if we fell through to NOT_FOUND). + const memoryManager = this.config.getMemoryManager(); + const memoryRecord = memoryManager.getTask(taskId); + if (memoryRecord) { + if (memoryRecord.taskType !== 'dream') { + return { + llmContent: + `Error: Memory task "${taskId}" (${memoryRecord.taskType}) is ` + + `not cancellable. Only dream consolidation tasks support ` + + `cancellation; extract tasks run on the request loop and ` + + `complete in milliseconds.`, + returnDisplay: `Task not cancellable (${memoryRecord.taskType}).`, + error: { + message: `task is not cancellable: ${taskId} (${memoryRecord.taskType})`, + type: ToolErrorType.TASK_STOP_NOT_CANCELLABLE, + }, + }; + } + if (memoryRecord.status !== 'running') { + return notRunningError('dream', taskId, memoryRecord.status); + } + // cancelTask returns false if the AbortController is missing for + // a running record (logic-level invariant violation; see + // MemoryManager.cancelTask). Surface that explicitly so the model + // sees the cancel didn't take and doesn't claim success. + const cancelled = memoryManager.cancelTask(taskId); + if (!cancelled) { + // Distinct from TASK_STOP_NOT_RUNNING (the task IS running) + // and TASK_STOP_NOT_CANCELLABLE (the kind supports cancel, + // we just couldn't deliver it). INTERNAL_ERROR signals that + // this is unexpected and worth filing — the abort controller + // should have been registered alongside status='running' in + // scheduleDream. + return { + llmContent: + `Error: Dream task "${taskId}" could not be cancelled ` + + `(internal state inconsistency — abort controller missing).`, + returnDisplay: 'Dream cancellation failed (internal state).', + error: { + message: `dream cancel failed: ${taskId}`, + type: ToolErrorType.TASK_STOP_INTERNAL_ERROR, + }, + }; + } + return { + llmContent: + `Cancellation requested for dream task "${taskId}". ` + + `The fork agent is being aborted; the consolidation lock will ` + + `be released as the agent unwinds. Status is visible via the ` + + `interactive Background tasks dialog (focus the footer Background ` + + `tasks pill, then Enter).`, + returnDisplay: `Cancelled dream: ${taskId}`, + }; + } + return { llmContent: `Error: No background task found with ID "${taskId}".`, returnDisplay: 'Task not found.', @@ -148,7 +209,7 @@ class TaskStopInvocation extends BaseToolInvocation< } function notRunningError( - kind: 'agent' | 'shell' | 'monitor', + kind: 'agent' | 'shell' | 'monitor' | 'dream', taskId: string, status: string, ): ToolResult { diff --git a/packages/core/src/tools/tool-error.ts b/packages/core/src/tools/tool-error.ts index 6a29472c6..10023d502 100644 --- a/packages/core/src/tools/tool-error.ts +++ b/packages/core/src/tools/tool-error.ts @@ -70,6 +70,8 @@ export enum ToolErrorType { // TaskStop-specific Errors TASK_STOP_NOT_FOUND = 'task_stop_not_found', TASK_STOP_NOT_RUNNING = 'task_stop_not_running', + TASK_STOP_NOT_CANCELLABLE = 'task_stop_not_cancellable', + TASK_STOP_INTERNAL_ERROR = 'task_stop_internal_error', // SendMessage-specific Errors SEND_MESSAGE_NOT_FOUND = 'send_message_not_found', diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 760c38daf..65ec2fcef 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -388,6 +388,12 @@ export interface ToolResult { */ returnDisplay: ToolResultDisplay; + /** + * Concrete filesystem paths discovered or touched during successful execution. + * Scheduler-side path activation consumes these in addition to input fields. + */ + resultFilePaths?: string[]; + /** * If this property is present, the tool call is considered a failure. */ diff --git a/packages/sdk-python/scripts/get-release-version.js b/packages/sdk-python/scripts/get-release-version.js index 28c32d5ec..ca5b977da 100644 --- a/packages/sdk-python/scripts/get-release-version.js +++ b/packages/sdk-python/scripts/get-release-version.js @@ -20,7 +20,9 @@ const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); const PACKAGE_NAME = 'qwen-code-sdk'; -const TAG_PREFIX = 'sdk-python-'; +const TAG_PREFIX = 'sdk-python-v'; +const NETWORK_COMMAND_TIMEOUT_MS = 30_000; +const LOCAL_COMMAND_TIMEOUT_MS = 10_000; function readPyprojectVersion() { const pyprojectPath = join(__dirname, '..', 'pyproject.toml'); @@ -117,7 +119,7 @@ function toBaseVersion(version) { async function getAllVersionsFromPyPI() { const response = await fetch(`https://pypi.org/pypi/${PACKAGE_NAME}/json`, { headers: { Accept: 'application/json' }, - signal: AbortSignal.timeout(30_000), + signal: AbortSignal.timeout(NETWORK_COMMAND_TIMEOUT_MS), }); if (response.status === 404) { @@ -241,28 +243,67 @@ function getUtcTimestamp() { } function getGitShortHash() { - return execSync('git rev-parse --short HEAD').toString().trim(); + try { + return execSync('git rev-parse --short HEAD', { + timeout: LOCAL_COMMAND_TIMEOUT_MS, + }) + .toString() + .trim(); + } catch (error) { + if (isTimeoutError(error)) { + throw new Error( + `git rev-parse timed out after ${LOCAL_COMMAND_TIMEOUT_MS / 1000}s — local git may be unresponsive`, + ); + } + throw error; + } } -async function getReleaseState({ packageVersion, releaseTag }, allVersions) { +function isTimeoutError(error) { + // Node.js execSync timeout: `code` is 'ETIMEDOUT' on POSIX; on some + // versions/platforms `killed` is true with signal 'SIGTERM' or null. + // Match the pattern used in packages/core/src/utils/pdf.ts. + return ( + error.code === 'ETIMEDOUT' || + (error.killed === true && + (error.signal === 'SIGTERM' || + error.signal === undefined || + error.signal === null)) + ); +} + +async function getReleaseState( + { packageVersion, releaseVersion }, + allVersions, +) { const state = { packageVersionExistsOnPyPI: allVersions.includes(packageVersion), gitTagExists: false, githubReleaseExists: false, }; - const fullTag = `${TAG_PREFIX}${releaseTag}`; + const fullTag = `${TAG_PREFIX}${releaseVersion}`; try { - const tagOutput = execSync(`git tag -l '${fullTag}'`).toString().trim(); + const tagOutput = execSync(`git tag -l '${fullTag}'`, { + timeout: LOCAL_COMMAND_TIMEOUT_MS, + }) + .toString() + .trim(); if (tagOutput === fullTag) { state.gitTagExists = true; } } catch (error) { + if (isTimeoutError(error)) { + throw new Error( + `git tag -l timed out after ${LOCAL_COMMAND_TIMEOUT_MS / 1000}s — local git may be unresponsive`, + ); + } throw new Error(`Failed to check git tags for conflicts: ${error.message}`); } try { const output = execSync( `gh release view "${fullTag}" --json tagName --jq .tagName`, + { timeout: NETWORK_COMMAND_TIMEOUT_MS }, ) .toString() .trim(); @@ -270,6 +311,13 @@ async function getReleaseState({ packageVersion, releaseTag }, allVersions) { state.githubReleaseExists = true; } } catch (error) { + // Timeout check must precede isExpectedMissingGitHubRelease — a timed-out + // process may emit partial stderr matching "release not found". + if (isTimeoutError(error)) { + throw new Error( + `gh release view timed out after ${NETWORK_COMMAND_TIMEOUT_MS / 1000}s checking "${fullTag}" — GitHub API may be unavailable`, + ); + } if (!isExpectedMissingGitHubRelease(error)) { throw new Error( `Failed to check GitHub releases for conflicts: ${error.message}`, @@ -435,7 +483,7 @@ async function getVersion(options = {}) { const releaseState = await getReleaseState( { packageVersion: versionData.packageVersion, - releaseTag: `v${versionData.releaseVersion}`, + releaseVersion: versionData.releaseVersion, }, allVersions, ); @@ -481,11 +529,11 @@ async function getVersion(options = {}) { if (releaseState.githubReleaseExists) { console.error( - `GitHub release ${TAG_PREFIX}v${versionData.releaseVersion} already exists.`, + `GitHub release ${TAG_PREFIX}${versionData.releaseVersion} already exists.`, ); } else if (releaseState.gitTagExists) { console.error( - `::warning::Orphan git tag ${TAG_PREFIX}v${versionData.releaseVersion} exists without a PyPI version or GitHub release. Skipping to next version slot.`, + `::warning::Orphan git tag ${TAG_PREFIX}${versionData.releaseVersion} exists without a PyPI version or GitHub release. Skipping to next version slot.`, ); } else if (releaseState.packageVersionExistsOnPyPI) { console.error( diff --git a/scripts/tests/get-release-version-python-sdk.test.js b/scripts/tests/get-release-version-python-sdk.test.js index c8b170adf..9c7976bd5 100644 --- a/scripts/tests/get-release-version-python-sdk.test.js +++ b/scripts/tests/get-release-version-python-sdk.test.js @@ -50,6 +50,16 @@ function makeExecError(message, { stderr = '', stdout = '', status } = {}) { return error; } +function makeTimeoutError(command) { + const error = new Error(`Command failed: ${command}\nSIGTERM`); + // Real Node.js execSync timeout shape (verified on Node 20+): + // killed=undefined, signal='SIGTERM', code='ETIMEDOUT' + error.code = 'ETIMEDOUT'; + error.signal = 'SIGTERM'; + error.status = null; + return error; +} + function makeExecSyncMock({ tags = {}, tagError = null, @@ -927,4 +937,51 @@ describe('python sdk get-release-version', () => { resumeExistingRelease: true, }); }); + + it('throws a timeout error when gh release view times out', async () => { + execSyncMock.mockImplementation( + makeExecSyncMock({ + releases: { + 'sdk-python-v0.1.0-preview.0': makeTimeoutError( + 'gh release view "sdk-python-v0.1.0-preview.0"', + ), + }, + }), + ); + + const getVersion = await loadGetVersion(); + + await expect(getVersion({ type: 'preview' })).rejects.toThrow( + 'gh release view timed out after 30s checking "sdk-python-v0.1.0-preview.0" — GitHub API may be unavailable', + ); + }); + + it('throws a timeout error when git tag -l times out', async () => { + execSyncMock.mockImplementation( + makeExecSyncMock({ + tagError: makeTimeoutError('git tag -l'), + }), + ); + + const getVersion = await loadGetVersion(); + + await expect(getVersion({ type: 'preview' })).rejects.toThrow( + 'git tag -l timed out after 10s — local git may be unresponsive', + ); + }); + + it('throws a timeout error when git rev-parse times out', async () => { + execSyncMock.mockImplementation((command) => { + if (command === 'git rev-parse --short HEAD') { + throw makeTimeoutError('git rev-parse --short HEAD'); + } + return makeExecSyncMock()(command); + }); + + const getVersion = await loadGetVersion(); + + await expect(getVersion({ type: 'nightly' })).rejects.toThrow( + 'git rev-parse timed out after 10s — local git may be unresponsive', + ); + }); });