Merge pull request #1691 from QwenLM/fix/subagent-tool-restriction

fix(core): enforce tool restrictions in subagents
This commit is contained in:
pomelo 2026-02-04 14:30:46 +08:00 committed by GitHub
commit da3dd56fd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 269 additions and 42 deletions

View file

@ -38,6 +38,8 @@ import {
SubAgentEventEmitter,
SubAgentEventType,
type SubAgentStreamTextEvent,
type SubAgentToolCallEvent,
type SubAgentToolResultEvent,
} from './subagent-events.js';
import type {
ModelConfig,
@ -933,5 +935,165 @@ describe('subagent.ts', () => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
});
});
describe('runNonInteractive - Tool Restriction Enforcement (Issue #1121)', () => {
const promptConfig: PromptConfig = { systemPrompt: 'Execute task.' };
it('should NOT execute tools that are not in the allowed tools list', async () => {
// Define two tools: one allowed (read_file), one not allowed (edit_file)
const readFileToolDef: FunctionDeclaration = {
name: 'read_file',
description: 'Reads a file',
parameters: { type: Type.OBJECT, properties: {} },
};
const editFileToolDef: FunctionDeclaration = {
name: 'edit_file',
description: 'Edits a file',
parameters: { type: Type.OBJECT, properties: {} },
};
// Track which tools were executed
const executedTools: string[] = [];
const readFileInvocation = {
params: { path: 'test.txt' },
getDescription: vi.fn().mockReturnValue('Read file'),
toolLocations: vi.fn().mockReturnValue([]),
shouldConfirmExecute: vi.fn().mockResolvedValue(false),
execute: vi.fn().mockImplementation(async () => {
executedTools.push('read_file');
return {
llmContent: 'file contents',
returnDisplay: 'Read file contents',
};
}),
};
const editFileInvocation = {
params: { path: 'test.txt', content: 'malicious content' },
getDescription: vi.fn().mockReturnValue('Edit file'),
toolLocations: vi.fn().mockReturnValue([]),
shouldConfirmExecute: vi.fn().mockResolvedValue(false),
execute: vi.fn().mockImplementation(async () => {
executedTools.push('edit_file');
return {
llmContent: 'file edited',
returnDisplay: 'Edited file',
};
}),
};
const readFileTool = {
name: 'read_file',
displayName: 'Read File',
description: 'Read file contents',
kind: 'READ' as const,
schema: readFileToolDef,
build: vi.fn().mockImplementation(() => readFileInvocation),
canUpdateOutput: false,
isOutputMarkdown: true,
} as unknown as AnyDeclarativeTool;
const editFileTool = {
name: 'edit_file',
displayName: 'Edit File',
description: 'Edit file contents',
kind: 'WRITE' as const,
schema: editFileToolDef,
build: vi.fn().mockImplementation(() => editFileInvocation),
canUpdateOutput: false,
isOutputMarkdown: true,
} as unknown as AnyDeclarativeTool;
const { config } = await createMockConfig({
// Only return read_file in the filtered list (this is what the subagent should see)
getFunctionDeclarationsFiltered: vi
.fn()
.mockReturnValue([readFileToolDef]),
// But the full registry has both tools (simulating the bug)
getFunctionDeclarations: vi
.fn()
.mockReturnValue([readFileToolDef, editFileToolDef]),
getTool: vi.fn().mockImplementation((name: string) => {
if (name === 'read_file') return readFileTool;
if (name === 'edit_file') return editFileTool;
return undefined;
}),
});
// Only allow read_file in the subagent's tool config
const toolConfig: ToolConfig = { tools: ['read_file'] };
// Model calls BOTH read_file (allowed) AND edit_file (NOT allowed)
// This simulates the bug where the model hallucinates an unauthorized tool call
mockSendMessageStream.mockImplementation(
createMockStream([
[
{
id: 'call_read',
name: 'read_file',
args: { path: 'test.txt' },
},
{
id: 'call_edit',
name: 'edit_file', // This tool is NOT in the allowed list!
args: { path: 'test.txt', content: 'malicious content' },
},
],
'stop',
]),
);
// Track emitted events
const toolCallEvents: SubAgentToolCallEvent[] = [];
const toolResultEvents: SubAgentToolResultEvent[] = [];
// Create event emitter BEFORE the scope and subscribe to events
const eventEmitter = new SubAgentEventEmitter();
eventEmitter.on(SubAgentEventType.TOOL_CALL, (event: unknown) => {
toolCallEvents.push(event as SubAgentToolCallEvent);
});
eventEmitter.on(SubAgentEventType.TOOL_RESULT, (event: unknown) => {
toolResultEvents.push(event as SubAgentToolResultEvent);
});
const scope = await SubAgentScope.create(
'test-agent',
config,
promptConfig,
defaultModelConfig,
defaultRunConfig,
toolConfig,
eventEmitter,
);
await scope.runNonInteractive(new ContextState());
// 1. Only allowed tool should be executed
expect(executedTools).toContain('read_file');
expect(executedTools).not.toContain('edit_file');
expect(editFileInvocation.execute).not.toHaveBeenCalled();
// 2. TOOL_CALL events should be emitted for BOTH tools (for visibility)
expect(toolCallEvents).toHaveLength(2);
expect(toolCallEvents.map((e) => e.name)).toContain('read_file');
expect(toolCallEvents.map((e) => e.name)).toContain('edit_file');
// 3. TOOL_RESULT events should be emitted for both
expect(toolResultEvents).toHaveLength(2);
// 4. Verify blocked tool result has success=false and error message
const editResult = toolResultEvents.find((e) => e.name === 'edit_file');
expect(editResult).toBeDefined();
expect(editResult!.success).toBe(false);
expect(editResult!.error).toContain('not found');
expect(editResult!.callId).toBe('call_edit');
// 5. Verify allowed tool result has success=true
const readResult = toolResultEvents.find((e) => e.name === 'read_file');
expect(readResult).toBeDefined();
expect(readResult!.success).toBe(true);
});
});
});
});

View file

@ -487,6 +487,7 @@ export class SubAgentScope {
abortController,
promptId,
turnCounter,
toolsList,
currentResponseId,
);
} else {
@ -585,10 +586,67 @@ export class SubAgentScope {
abortController: AbortController,
promptId: string,
currentRound: number,
toolsList: FunctionDeclaration[],
responseId?: string,
): Promise<Content[]> {
const toolResponseParts: Part[] = [];
// Build allowed tool names set for filtering
const allowedToolNames = new Set(toolsList.map((t) => t.name));
// Filter unauthorized tool calls before scheduling
const authorizedCalls: FunctionCall[] = [];
for (const fc of functionCalls) {
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
if (!allowedToolNames.has(fc.name)) {
const toolName = String(fc.name);
const errorMessage = `Tool "${toolName}" not found. Tools must use the exact names provided.`;
// Emit TOOL_CALL event for visibility
this.eventEmitter?.emit(SubAgentEventType.TOOL_CALL, {
subagentId: this.subagentId,
round: currentRound,
callId,
name: toolName,
args: fc.args ?? {},
description: `Tool "${toolName}" not found`,
timestamp: Date.now(),
} as SubAgentToolCallEvent);
// Build function response part (used for both event and LLM)
const functionResponsePart = {
functionResponse: {
id: callId,
name: toolName,
response: { error: errorMessage },
},
};
// Emit TOOL_RESULT event with error (include responseParts for UI rendering)
this.eventEmitter?.emit(SubAgentEventType.TOOL_RESULT, {
subagentId: this.subagentId,
round: currentRound,
callId,
name: toolName,
success: false,
error: errorMessage,
responseParts: [functionResponsePart],
resultDisplay: errorMessage,
durationMs: 0,
timestamp: Date.now(),
} as SubAgentToolResultEvent);
// Record blocked tool call in stats
this.recordToolCallStats(toolName, false, 0, errorMessage);
// Add function response for LLM
toolResponseParts.push(functionResponsePart);
continue;
}
authorizedCalls.push(fc);
}
// Build scheduler
const responded = new Set<string>();
let resolveBatch: (() => void) | null = null;
@ -605,33 +663,8 @@ export class SubAgentScope {
? call.response.error?.message
: undefined;
// Update aggregate stats
this.executionStats.totalToolCalls += 1;
if (success) {
this.executionStats.successfulToolCalls += 1;
} else {
this.executionStats.failedToolCalls += 1;
}
// Per-tool usage
const tu = this.toolUsage.get(toolName) || {
count: 0,
success: 0,
failure: 0,
totalDurationMs: 0,
averageDurationMs: 0,
};
tu.count += 1;
if (success) {
tu.success += 1;
} else {
tu.failure += 1;
tu.lastError = errorMessage || 'Unknown error';
}
tu.totalDurationMs = (tu.totalDurationMs || 0) + duration;
tu.averageDurationMs =
tu.count > 0 ? tu.totalDurationMs / tu.count : 0;
this.toolUsage.set(toolName, tu);
// Record stats
this.recordToolCallStats(toolName, success, duration, errorMessage);
// Emit tool result event
this.eventEmitter?.emit(SubAgentEventType.TOOL_RESULT, {
@ -642,12 +675,6 @@ export class SubAgentScope {
success,
error: errorMessage,
responseParts: call.response.responseParts,
/**
* Tools like todoWrite will add some extra contents to the result,
* making it unable to deserialize the `responseParts` to a JSON object.
* While `resultDisplay` is normally a string, if not we stringify it,
* so that we can deserialize it to a JSON object when needed.
*/
resultDisplay: call.response.resultDisplay
? typeof call.response.resultDisplay === 'string'
? call.response.resultDisplay
@ -657,14 +684,6 @@ export class SubAgentScope {
timestamp: Date.now(),
} as SubAgentToolResultEvent);
// Update statistics service
this.stats.recordToolCall(
toolName,
success,
duration,
this.toolUsage.get(toolName)?.lastError,
);
// post-tool hook
await this.hooks?.postToolUse?.({
subagentId: this.subagentId,
@ -736,7 +755,7 @@ export class SubAgentScope {
});
// Prepare requests and emit TOOL_CALL events
const requests: ToolCallRequestInfo[] = functionCalls.map((fc) => {
const requests: ToolCallRequestInfo[] = authorizedCalls.map((fc) => {
const toolName = String(fc.name || 'unknown');
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
const args = (fc.args ?? {}) as Record<string, unknown>;
@ -902,6 +921,52 @@ export class SubAgentScope {
}
}
/**
* Records tool call statistics for both successful and failed tool calls.
* This includes updating aggregate stats, per-tool usage, and the statistics service.
*/
private recordToolCallStats(
toolName: string,
success: boolean,
durationMs: number,
errorMessage?: string,
): void {
// Update aggregate stats
this.executionStats.totalToolCalls += 1;
if (success) {
this.executionStats.successfulToolCalls += 1;
} else {
this.executionStats.failedToolCalls += 1;
}
// Per-tool usage
const tu = this.toolUsage.get(toolName) || {
count: 0,
success: 0,
failure: 0,
totalDurationMs: 0,
averageDurationMs: 0,
};
tu.count += 1;
if (success) {
tu.success += 1;
} else {
tu.failure += 1;
tu.lastError = errorMessage || 'Unknown error';
}
tu.totalDurationMs = (tu.totalDurationMs || 0) + durationMs;
tu.averageDurationMs = tu.count > 0 ? tu.totalDurationMs / tu.count : 0;
this.toolUsage.set(toolName, tu);
// Update statistics service
this.stats.recordToolCall(
toolName,
success,
durationMs,
this.toolUsage.get(toolName)?.lastError,
);
}
private buildChatSystemPrompt(context: ContextState): string {
if (!this.promptConfig.systemPrompt) {
// This should ideally be caught in createChatObject, but serves as a safeguard.