diff --git a/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts b/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts index e28e9046d..d4566fcf3 100644 --- a/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts +++ b/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts @@ -16,7 +16,11 @@ import { type ContentBlock, type SDKUserMessage, } from '@qwen-code/sdk'; -import { SDKTestHelper, createSharedTestOptions } from './test-helper.js'; +import { + SDKTestHelper, + createSharedTestOptions, + createResultWaiter, +} from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); @@ -254,6 +258,12 @@ describe('AbortController and Process Lifecycle (E2E)', () => { describe('Closed stdin behavior (asyncGenerator prompt)', () => { it('should reject control requests after stdin closes', async () => { + const resultWaiter = createResultWaiter(1); + let promptDoneResolve: () => void = () => {}; + const promptDonePromise = new Promise((resolve) => { + promptDoneResolve = resolve; + }); + async function* createPrompt(): AsyncIterable { yield { type: 'user', @@ -264,6 +274,9 @@ describe('AbortController and Process Lifecycle (E2E)', () => { }, parent_tool_use_id: null, }; + + await resultWaiter.waitForResult(0); + promptDoneResolve(); } const q = query({ @@ -281,13 +294,14 @@ describe('AbortController and Process Lifecycle (E2E)', () => { for await (const message of q) { if (isSDKResultMessage(message)) { firstResultReceived = true; + resultWaiter.notifyResult(); break; } } expect(firstResultReceived).toBe(true); - - await new Promise((resolve) => setTimeout(resolve, 50)); + await promptDonePromise; + q.endInput(); await expect(q.setPermissionMode('default')).rejects.toThrow( 'Input stream closed', diff --git a/integration-tests/sdk-typescript/mcp-server.test.ts b/integration-tests/sdk-typescript/mcp-server.test.ts index 9b3f21938..cf1de26d4 100644 --- a/integration-tests/sdk-typescript/mcp-server.test.ts +++ b/integration-tests/sdk-typescript/mcp-server.test.ts @@ -19,6 +19,7 @@ import { type SDKMessage, type ToolUseBlock, type SDKSystemMessage, + type SDKUserMessage, } from '@qwen-code/sdk'; import { SDKTestHelper, @@ -26,6 +27,7 @@ import { extractText, findToolUseBlocks, createSharedTestOptions, + createResultWaiter, } from './test-helper.js'; const SHARED_TEST_OPTIONS = { @@ -296,6 +298,176 @@ describe('MCP Server Integration (E2E)', () => { await q.close(); } }); + + it('should support multi-turn asyncGenerator prompt with MCP tools', async () => { + const resultWaiter = createResultWaiter(2); + + async function* createMultiTurnPrompt(): AsyncIterable { + const sessionId = crypto.randomUUID(); + + yield { + type: 'user', + session_id: sessionId, + message: { + role: 'user', + content: 'Use the add tool to calculate 2 + 3. Give me the result.', + }, + parent_tool_use_id: null, + }; + + await resultWaiter.waitForResult(0); + + yield { + type: 'user', + session_id: sessionId, + message: { + role: 'user', + content: + 'Now use the multiply tool to calculate 5 * 4. Give me the result.', + }, + parent_tool_use_id: null, + }; + + await resultWaiter.waitForResult(1); + } + + const q = query({ + prompt: createMultiTurnPrompt(), + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + debug: false, + mcpServers: { + 'test-math-server': { + command: 'node', + args: [serverScriptPath], + }, + }, + }, + }); + + const messages: SDKMessage[] = []; + let assistantText = ''; + const toolCalls: string[] = []; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks(message); + toolUseBlocks.forEach((block) => { + toolCalls.push(block.name); + }); + assistantText += extractText(message.message.content); + } + } + + expect(toolCalls).toContain('add'); + expect(toolCalls).toContain('multiply'); + expect(assistantText).toMatch(/5/); + expect(assistantText).toMatch(/20/); + + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + } finally { + await q.close(); + } + }); + + it('should support multi-turn MCP tools with canUseTool', async () => { + const canUseToolCalls: Array<{ toolName: string }> = []; + const resultWaiter = createResultWaiter(2); + + async function* createMultiTurnPrompt(): AsyncIterable { + const sessionId = crypto.randomUUID(); + + yield { + type: 'user', + session_id: sessionId, + message: { + role: 'user', + content: 'Use the add tool to calculate 9 + 1. Give me the result.', + }, + parent_tool_use_id: null, + }; + + await resultWaiter.waitForResult(0); + + yield { + type: 'user', + session_id: sessionId, + message: { + role: 'user', + content: + 'Now use the multiply tool to calculate 4 * 3. Give me the result.', + }, + parent_tool_use_id: null, + }; + + await resultWaiter.waitForResult(1); + } + + const q = query({ + prompt: createMultiTurnPrompt(), + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + permissionMode: 'default', + canUseTool: async (toolName) => { + canUseToolCalls.push({ toolName }); + return { + behavior: 'allow', + updatedInput: {}, + }; + }, + debug: false, + mcpServers: { + 'test-math-server': { + command: 'node', + args: [serverScriptPath], + }, + }, + }, + }); + + const messages: SDKMessage[] = []; + let assistantText = ''; + const toolCalls: string[] = []; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks(message); + toolUseBlocks.forEach((block) => { + toolCalls.push(block.name); + }); + assistantText += extractText(message.message.content); + } + } + + expect(toolCalls).toContain('add'); + expect(toolCalls).toContain('multiply'); + expect(canUseToolCalls.map((call) => call.toolName)).toEqual( + expect.arrayContaining(['add', 'multiply']), + ); + expect(assistantText).toMatch(/10/); + expect(assistantText).toMatch(/12/); + + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + } finally { + await q.close(); + } + }); }); describe('MCP Tool Message Flow', () => { diff --git a/integration-tests/sdk-typescript/multi-turn.test.ts b/integration-tests/sdk-typescript/multi-turn.test.ts index c1b96cc7c..4cf845fc5 100644 --- a/integration-tests/sdk-typescript/multi-turn.test.ts +++ b/integration-tests/sdk-typescript/multi-turn.test.ts @@ -22,7 +22,11 @@ import { type ControlMessage, type ToolUseBlock, } from '@qwen-code/sdk'; -import { SDKTestHelper, createSharedTestOptions } from './test-helper.js'; +import { + SDKTestHelper, + createSharedTestOptions, + createResultWaiter, +} from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); @@ -76,6 +80,8 @@ describe('Multi-Turn Conversations (E2E)', () => { describe('AsyncIterable Prompt Support', () => { it('should handle multi-turn conversation using AsyncIterable prompt', async () => { + const resultWaiter = createResultWaiter(3); + // Create multi-turn conversation generator async function* createMultiTurnConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -90,7 +96,7 @@ describe('Multi-Turn Conversations (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 100)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -102,7 +108,7 @@ describe('Multi-Turn Conversations (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 100)); + await resultWaiter.waitForResult(1); yield { type: 'user', @@ -113,6 +119,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(2); } // Create multi-turn query using AsyncIterable prompt @@ -133,6 +141,9 @@ describe('Multi-Turn Conversations (E2E)', () => { for await (const message of q) { messages.push(message); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { assistantMessages.push(message); const text = extractText(message.message.content); @@ -153,6 +164,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }); it('should maintain session context across turns', async () => { + const resultWaiter = createResultWaiter(2); + async function* createContextualConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -162,12 +175,12 @@ describe('Multi-Turn Conversations (E2E)', () => { message: { role: 'user', content: - 'Suppose we have 3 rabbits and 4 carrots. How many animals are there?', + 'Suppose we have 3 rabbits and 4 carrots. Identify: How many **animals** are there?', }, parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 200)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -178,6 +191,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -193,6 +208,9 @@ describe('Multi-Turn Conversations (E2E)', () => { try { for await (const message of q) { + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { assistantMessages.push(message); } @@ -213,6 +231,8 @@ describe('Multi-Turn Conversations (E2E)', () => { describe('Tool Usage in Multi-Turn', () => { it('should handle tool usage across multiple turns', async () => { + const resultWaiter = createResultWaiter(2); + async function* createToolConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -226,7 +246,7 @@ describe('Multi-Turn Conversations (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 200)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -237,6 +257,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -257,6 +279,9 @@ describe('Multi-Turn Conversations (E2E)', () => { for await (const message of q) { messages.push(message); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { assistantMessages.push(message); const hasToolUseBlock = message.message.content.some( @@ -286,6 +311,8 @@ describe('Multi-Turn Conversations (E2E)', () => { describe('Message Flow and Sequencing', () => { it('should process messages in correct sequence', async () => { + const resultWaiter = createResultWaiter(2); + async function* createSequentialConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -299,7 +326,7 @@ describe('Multi-Turn Conversations (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 100)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -310,6 +337,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -329,6 +358,9 @@ describe('Multi-Turn Conversations (E2E)', () => { const messageType = getMessageType(message); messageSequence.push(messageType); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { const text = extractText(message.message.content); assistantResponses.push(text); @@ -351,6 +383,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }); it('should handle conversation completion correctly', async () => { + const resultWaiter = createResultWaiter(2); + async function* createSimpleConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -364,7 +398,7 @@ describe('Multi-Turn Conversations (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 100)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -375,6 +409,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -394,6 +430,7 @@ describe('Multi-Turn Conversations (E2E)', () => { messageCount++; if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); completedNaturally = true; expect(message.subtype).toBe('success'); } @@ -441,6 +478,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }); it('should handle conversation with delays', async () => { + const resultWaiter = createResultWaiter(2); + async function* createDelayedConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -455,7 +494,7 @@ describe('Multi-Turn Conversations (E2E)', () => { } as SDKUserMessage; // Longer delay to test patience - await new Promise((resolve) => setTimeout(resolve, 500)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -466,6 +505,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -481,6 +522,9 @@ describe('Multi-Turn Conversations (E2E)', () => { try { for await (const message of q) { + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { assistantMessages.push(message); } @@ -495,6 +539,8 @@ describe('Multi-Turn Conversations (E2E)', () => { describe('Partial Messages in Multi-Turn', () => { it('should receive partial messages when includePartialMessages is enabled', async () => { + const resultWaiter = createResultWaiter(2); + async function* createMultiTurnConversation(): AsyncIterable { const sessionId = crypto.randomUUID(); @@ -508,7 +554,7 @@ describe('Multi-Turn Conversations (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 100)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -519,6 +565,8 @@ describe('Multi-Turn Conversations (E2E)', () => { }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -539,6 +587,9 @@ describe('Multi-Turn Conversations (E2E)', () => { for await (const message of q) { messages.push(message); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKPartialAssistantMessage(message)) { partialMessageCount++; } diff --git a/integration-tests/sdk-typescript/permission-control.test.ts b/integration-tests/sdk-typescript/permission-control.test.ts index eee344755..4c253dc28 100644 --- a/integration-tests/sdk-typescript/permission-control.test.ts +++ b/integration-tests/sdk-typescript/permission-control.test.ts @@ -31,6 +31,7 @@ import { hasErrorToolResults, findSystemMessage, findToolCalls, + createResultWaiter, } from './test-helper.js'; const TEST_TIMEOUT = 30000; @@ -44,6 +45,7 @@ const SHARED_TEST_OPTIONS = createSharedTestOptions(); function createStreamingInputWithControlPoint( firstMessage: string, secondMessage: string, + resultWaiter: { waitForResult: (index: number) => Promise }, ): { generator: AsyncIterable; resume: () => void; @@ -66,7 +68,7 @@ function createStreamingInputWithControlPoint( parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 200)); + await resultWaiter.waitForResult(0); await resumePromise; @@ -81,6 +83,8 @@ function createStreamingInputWithControlPoint( }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); })(); const resume = () => { @@ -320,9 +324,11 @@ describe('Permission Control (E2E)', () => { describe('setPermissionMode API', () => { it('should change permission mode from default to yolo', async () => { + const resultWaiter = createResultWaiter(2); const { generator, resume } = createStreamingInputWithControlPoint( 'What is 1 + 1?', 'What is 2 + 2?', + resultWaiter, ); const q = query({ @@ -361,6 +367,9 @@ describe('Permission Control (E2E)', () => { resolvers.second?.(); } } + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } } })(); @@ -397,9 +406,11 @@ describe('Permission Control (E2E)', () => { }); it('should change permission mode from yolo to plan', async () => { + const resultWaiter = createResultWaiter(2); const { generator, resume } = createStreamingInputWithControlPoint( 'What is 3 + 3?', 'What is 4 + 4?', + resultWaiter, ); const q = query({ @@ -437,6 +448,9 @@ describe('Permission Control (E2E)', () => { resolvers.second?.(); } } + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } } })(); @@ -473,9 +487,11 @@ describe('Permission Control (E2E)', () => { }); it('should change permission mode to auto-edit', async () => { + const resultWaiter = createResultWaiter(2); const { generator, resume } = createStreamingInputWithControlPoint( 'What is 5 + 5?', 'What is 6 + 6?', + resultWaiter, ); const q = query({ @@ -513,6 +529,9 @@ describe('Permission Control (E2E)', () => { resolvers.second?.(); } } + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } } })(); @@ -584,9 +603,11 @@ describe('Permission Control (E2E)', () => { input: Record; }> = []; + const resultWaiter = createResultWaiter(2); const { generator, resume } = createStreamingInputWithControlPoint( 'Create a file named first.txt', 'Create a file named second.txt', + resultWaiter, ); const q = query({ @@ -630,6 +651,7 @@ describe('Permission Control (E2E)', () => { secondResponseReceived = true; resolvers.second?.(); } + resultWaiter.notifyResult(); } } })(); diff --git a/integration-tests/sdk-typescript/system-control.test.ts b/integration-tests/sdk-typescript/system-control.test.ts index a977e6471..0ae28c4c5 100644 --- a/integration-tests/sdk-typescript/system-control.test.ts +++ b/integration-tests/sdk-typescript/system-control.test.ts @@ -8,9 +8,14 @@ import { query, isSDKAssistantMessage, isSDKSystemMessage, + isSDKResultMessage, type SDKUserMessage, } from '@qwen-code/sdk'; -import { SDKTestHelper, createSharedTestOptions } from './test-helper.js'; +import { + SDKTestHelper, + createSharedTestOptions, + createResultWaiter, +} from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); @@ -26,6 +31,7 @@ const SHARED_TEST_OPTIONS = createSharedTestOptions(); function createStreamingInputWithControlPoint( firstMessage: string, secondMessage: string, + resultWaiter: { waitForResult: (index: number) => Promise }, ): { generator: AsyncIterable; resume: () => void; @@ -48,7 +54,7 @@ function createStreamingInputWithControlPoint( parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 200)); + await resultWaiter.waitForResult(0); await resumePromise; @@ -63,6 +69,8 @@ function createStreamingInputWithControlPoint( }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(1); })(); const resume = () => { @@ -89,9 +97,11 @@ describe('System Control (E2E)', () => { describe('setModel API', () => { it('should change model dynamically during streaming input', async () => { + const resultWaiter = createResultWaiter(2); const { generator, resume } = createStreamingInputWithControlPoint( 'Tell me the model name.', 'Tell me the model name now again.', + resultWaiter, ); const q = query({ @@ -126,6 +136,9 @@ describe('System Control (E2E)', () => { if (isSDKSystemMessage(message)) { systemMessages.push({ model: message.model }); } + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { if (!firstResponseReceived) { firstResponseReceived = true; @@ -181,6 +194,7 @@ describe('System Control (E2E)', () => { it('should handle multiple model changes in sequence', async () => { const sessionId = crypto.randomUUID(); + const resultWaiter = createResultWaiter(3); let resumeResolve1: (() => void) | null = null; let resumeResolve2: (() => void) | null = null; const resumePromise1 = new Promise((resolve) => { @@ -198,7 +212,7 @@ describe('System Control (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 200)); + await resultWaiter.waitForResult(0); await resumePromise1; await new Promise((resolve) => setTimeout(resolve, 200)); @@ -209,7 +223,7 @@ describe('System Control (E2E)', () => { parent_tool_use_id: null, } as SDKUserMessage; - await new Promise((resolve) => setTimeout(resolve, 200)); + await resultWaiter.waitForResult(1); await resumePromise2; await new Promise((resolve) => setTimeout(resolve, 200)); @@ -219,6 +233,8 @@ describe('System Control (E2E)', () => { message: { role: 'user', content: 'Third message' }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(2); })(); const q = query({ @@ -246,6 +262,9 @@ describe('System Control (E2E)', () => { if (isSDKSystemMessage(message)) { systemMessages.push({ model: message.model }); } + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } if (isSDKAssistantMessage(message)) { if (responseCount < resolvers.length) { resolvers[responseCount]?.(); @@ -318,6 +337,7 @@ describe('System Control (E2E)', () => { describe('supportedCommands API', () => { it('should return list of supported slash commands', async () => { const sessionId = crypto.randomUUID(); + const resultWaiter = createResultWaiter(1); const generator = (async function* () { yield { type: 'user', @@ -325,6 +345,8 @@ describe('System Control (E2E)', () => { message: { role: 'user', content: 'Hello' }, parent_tool_use_id: null, } as SDKUserMessage; + + await resultWaiter.waitForResult(0); })(); const q = query({ @@ -343,6 +365,9 @@ describe('System Control (E2E)', () => { const messageConsumer = (async () => { try { for await (const _message of q) { + if (isSDKResultMessage(_message)) { + resultWaiter.notifyResult(); + } // Just consume messages } } catch (error) { diff --git a/integration-tests/sdk-typescript/test-helper.ts b/integration-tests/sdk-typescript/test-helper.ts index d7efc026c..07f44f890 100644 --- a/integration-tests/sdk-typescript/test-helper.ts +++ b/integration-tests/sdk-typescript/test-helper.ts @@ -655,6 +655,29 @@ export function hasErrorToolResults(messages: SDKMessage[]): boolean { // Streaming Input Utilities // ============================================================================ +export function createResultWaiter(expectedResults: number): { + waitForResult: (index: number) => Promise; + notifyResult: () => void; +} { + const resolvers: Array<() => void> = []; + const promises = Array.from({ length: expectedResults }, () => { + return new Promise((resolve) => { + resolvers.push(resolve); + }); + }); + let resolvedCount = 0; + + return { + waitForResult: (index: number) => promises[index], + notifyResult: () => { + if (resolvedCount < resolvers.length) { + resolvers[resolvedCount]?.(); + resolvedCount += 1; + } + }, + }; +} + /** * Create a simple streaming input from an array of message contents */ diff --git a/integration-tests/sdk-typescript/tool-control.test.ts b/integration-tests/sdk-typescript/tool-control.test.ts index 90819aad1..aecb98ae6 100644 --- a/integration-tests/sdk-typescript/tool-control.test.ts +++ b/integration-tests/sdk-typescript/tool-control.test.ts @@ -15,6 +15,7 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { query, isSDKAssistantMessage, + isSDKResultMessage, type SDKMessage, type SDKUserMessage, } from '@qwen-code/sdk'; @@ -25,6 +26,7 @@ import { findToolResults, assertSuccessfulCompletion, createSharedTestOptions, + createResultWaiter, } from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); @@ -751,6 +753,7 @@ describe('Tool Control Parameters (E2E)', () => { async () => { await helper.createFile('test.txt', 'original content'); + const resultWaiter = createResultWaiter(1); const canUseToolCalls: Array<{ toolName: string; input: Record; @@ -768,7 +771,7 @@ describe('Tool Control Parameters (E2E)', () => { parent_tool_use_id: null, }; - await new Promise((resolve) => setTimeout(resolve, 3000)); + await resultWaiter.waitForResult(0); } const q = query({ @@ -795,6 +798,9 @@ describe('Tool Control Parameters (E2E)', () => { try { for await (const message of q) { messages.push(message); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } } const toolCalls = findToolCalls(messages); @@ -827,6 +833,7 @@ describe('Tool Control Parameters (E2E)', () => { async () => { await helper.createFile('test.txt', 'original content'); + const resultWaiter = createResultWaiter(1); // Create an async generator that yields a single message async function* createPrompt(): AsyncIterable { yield { @@ -838,7 +845,7 @@ describe('Tool Control Parameters (E2E)', () => { }, parent_tool_use_id: null, }; - await new Promise((resolve) => setTimeout(resolve, 3000)); + await resultWaiter.waitForResult(0); } const q = query({ @@ -866,6 +873,9 @@ describe('Tool Control Parameters (E2E)', () => { try { for await (const message of q) { messages.push(message); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } } // write_file should have been attempted but stream was closed @@ -892,6 +902,7 @@ describe('Tool Control Parameters (E2E)', () => { async () => { await helper.createFile('data.txt', 'initial data'); + const resultWaiter = createResultWaiter(2); const canUseToolCalls: string[] = []; // Create an async generator that yields multiple messages @@ -908,8 +919,7 @@ describe('Tool Control Parameters (E2E)', () => { parent_tool_use_id: null, }; - // Small delay to simulate multi-turn conversation - await new Promise((resolve) => setTimeout(resolve, 100)); + await resultWaiter.waitForResult(0); yield { type: 'user', @@ -920,6 +930,8 @@ describe('Tool Control Parameters (E2E)', () => { }, parent_tool_use_id: null, }; + + await resultWaiter.waitForResult(1); } const q = query({ @@ -942,6 +954,9 @@ describe('Tool Control Parameters (E2E)', () => { try { for await (const message of q) { messages.push(message); + if (isSDKResultMessage(message)) { + resultWaiter.notifyResult(); + } } const toolCalls = findToolCalls(messages); @@ -951,17 +966,14 @@ describe('Tool Control Parameters (E2E)', () => { expect(toolNames).toContain('read_file'); expect(toolNames).toContain('write_file'); - // canUseTool should not be called once stream is closed - expect(canUseToolCalls).toHaveLength(0); + expect(canUseToolCalls).toContain('write_file'); const writeFileResults = findToolResults(messages, 'write_file'); expect(writeFileResults.length).toBeGreaterThan(0); - for (const result of writeFileResults) { - expect(result.content).toContain('Error: Input closed'); - } const content = await helper.readFile('data.txt'); - expect(content).toBe('initial data'); + expect(content).toContain('initial data'); + expect(content).toContain(' - updated'); } finally { await q.close(); } diff --git a/packages/sdk-typescript/src/transport/ProcessTransport.ts b/packages/sdk-typescript/src/transport/ProcessTransport.ts index 6d71c69e0..ff4518833 100644 --- a/packages/sdk-typescript/src/transport/ProcessTransport.ts +++ b/packages/sdk-typescript/src/transport/ProcessTransport.ts @@ -282,9 +282,9 @@ export class ProcessTransport implements Transport { if (this.childStdin.writableEnded || this.childStdin.destroyed) { this.inputClosed = true; logger.warn( - `Cannot write to ${this.childStdin.writableEnded ? 'ended' : 'destroyed'} stdin stream, ignoring write`, + `Cannot write to ${this.childStdin.writableEnded ? 'ended' : 'destroyed'} stdin stream`, ); - return; + throw new Error('Input stream closed'); } if (this.childProcess?.killed || this.childProcess?.exitCode !== null) { @@ -319,10 +319,9 @@ export class ProcessTransport implements Transport { errorMsg.includes('write after end'); if (isStreamClosedError) { - // Soft-fail: log and return without throwing or changing ready state this.inputClosed = true; logger.warn(`Stream closed, cannot write: ${errorMsg}`); - return; + throw new Error('Input stream closed'); } // For other errors, maintain original behavior