mirror of
https://github.com/QwenLM/qwen-code.git
synced 2026-05-01 21:20:44 +00:00
Merge pull request #1691 from QwenLM/fix/subagent-tool-restriction
fix(core): enforce tool restrictions in subagents
This commit is contained in:
commit
da3dd56fd8
2 changed files with 269 additions and 42 deletions
|
|
@ -38,6 +38,8 @@ import {
|
|||
SubAgentEventEmitter,
|
||||
SubAgentEventType,
|
||||
type SubAgentStreamTextEvent,
|
||||
type SubAgentToolCallEvent,
|
||||
type SubAgentToolResultEvent,
|
||||
} from './subagent-events.js';
|
||||
import type {
|
||||
ModelConfig,
|
||||
|
|
@ -933,5 +935,165 @@ describe('subagent.ts', () => {
|
|||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('runNonInteractive - Tool Restriction Enforcement (Issue #1121)', () => {
|
||||
const promptConfig: PromptConfig = { systemPrompt: 'Execute task.' };
|
||||
|
||||
it('should NOT execute tools that are not in the allowed tools list', async () => {
|
||||
// Define two tools: one allowed (read_file), one not allowed (edit_file)
|
||||
const readFileToolDef: FunctionDeclaration = {
|
||||
name: 'read_file',
|
||||
description: 'Reads a file',
|
||||
parameters: { type: Type.OBJECT, properties: {} },
|
||||
};
|
||||
const editFileToolDef: FunctionDeclaration = {
|
||||
name: 'edit_file',
|
||||
description: 'Edits a file',
|
||||
parameters: { type: Type.OBJECT, properties: {} },
|
||||
};
|
||||
|
||||
// Track which tools were executed
|
||||
const executedTools: string[] = [];
|
||||
|
||||
const readFileInvocation = {
|
||||
params: { path: 'test.txt' },
|
||||
getDescription: vi.fn().mockReturnValue('Read file'),
|
||||
toolLocations: vi.fn().mockReturnValue([]),
|
||||
shouldConfirmExecute: vi.fn().mockResolvedValue(false),
|
||||
execute: vi.fn().mockImplementation(async () => {
|
||||
executedTools.push('read_file');
|
||||
return {
|
||||
llmContent: 'file contents',
|
||||
returnDisplay: 'Read file contents',
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const editFileInvocation = {
|
||||
params: { path: 'test.txt', content: 'malicious content' },
|
||||
getDescription: vi.fn().mockReturnValue('Edit file'),
|
||||
toolLocations: vi.fn().mockReturnValue([]),
|
||||
shouldConfirmExecute: vi.fn().mockResolvedValue(false),
|
||||
execute: vi.fn().mockImplementation(async () => {
|
||||
executedTools.push('edit_file');
|
||||
return {
|
||||
llmContent: 'file edited',
|
||||
returnDisplay: 'Edited file',
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const readFileTool = {
|
||||
name: 'read_file',
|
||||
displayName: 'Read File',
|
||||
description: 'Read file contents',
|
||||
kind: 'READ' as const,
|
||||
schema: readFileToolDef,
|
||||
build: vi.fn().mockImplementation(() => readFileInvocation),
|
||||
canUpdateOutput: false,
|
||||
isOutputMarkdown: true,
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
|
||||
const editFileTool = {
|
||||
name: 'edit_file',
|
||||
displayName: 'Edit File',
|
||||
description: 'Edit file contents',
|
||||
kind: 'WRITE' as const,
|
||||
schema: editFileToolDef,
|
||||
build: vi.fn().mockImplementation(() => editFileInvocation),
|
||||
canUpdateOutput: false,
|
||||
isOutputMarkdown: true,
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
|
||||
const { config } = await createMockConfig({
|
||||
// Only return read_file in the filtered list (this is what the subagent should see)
|
||||
getFunctionDeclarationsFiltered: vi
|
||||
.fn()
|
||||
.mockReturnValue([readFileToolDef]),
|
||||
// But the full registry has both tools (simulating the bug)
|
||||
getFunctionDeclarations: vi
|
||||
.fn()
|
||||
.mockReturnValue([readFileToolDef, editFileToolDef]),
|
||||
getTool: vi.fn().mockImplementation((name: string) => {
|
||||
if (name === 'read_file') return readFileTool;
|
||||
if (name === 'edit_file') return editFileTool;
|
||||
return undefined;
|
||||
}),
|
||||
});
|
||||
|
||||
// Only allow read_file in the subagent's tool config
|
||||
const toolConfig: ToolConfig = { tools: ['read_file'] };
|
||||
|
||||
// Model calls BOTH read_file (allowed) AND edit_file (NOT allowed)
|
||||
// This simulates the bug where the model hallucinates an unauthorized tool call
|
||||
mockSendMessageStream.mockImplementation(
|
||||
createMockStream([
|
||||
[
|
||||
{
|
||||
id: 'call_read',
|
||||
name: 'read_file',
|
||||
args: { path: 'test.txt' },
|
||||
},
|
||||
{
|
||||
id: 'call_edit',
|
||||
name: 'edit_file', // This tool is NOT in the allowed list!
|
||||
args: { path: 'test.txt', content: 'malicious content' },
|
||||
},
|
||||
],
|
||||
'stop',
|
||||
]),
|
||||
);
|
||||
|
||||
// Track emitted events
|
||||
const toolCallEvents: SubAgentToolCallEvent[] = [];
|
||||
const toolResultEvents: SubAgentToolResultEvent[] = [];
|
||||
|
||||
// Create event emitter BEFORE the scope and subscribe to events
|
||||
const eventEmitter = new SubAgentEventEmitter();
|
||||
eventEmitter.on(SubAgentEventType.TOOL_CALL, (event: unknown) => {
|
||||
toolCallEvents.push(event as SubAgentToolCallEvent);
|
||||
});
|
||||
eventEmitter.on(SubAgentEventType.TOOL_RESULT, (event: unknown) => {
|
||||
toolResultEvents.push(event as SubAgentToolResultEvent);
|
||||
});
|
||||
|
||||
const scope = await SubAgentScope.create(
|
||||
'test-agent',
|
||||
config,
|
||||
promptConfig,
|
||||
defaultModelConfig,
|
||||
defaultRunConfig,
|
||||
toolConfig,
|
||||
eventEmitter,
|
||||
);
|
||||
|
||||
await scope.runNonInteractive(new ContextState());
|
||||
|
||||
// 1. Only allowed tool should be executed
|
||||
expect(executedTools).toContain('read_file');
|
||||
expect(executedTools).not.toContain('edit_file');
|
||||
expect(editFileInvocation.execute).not.toHaveBeenCalled();
|
||||
|
||||
// 2. TOOL_CALL events should be emitted for BOTH tools (for visibility)
|
||||
expect(toolCallEvents).toHaveLength(2);
|
||||
expect(toolCallEvents.map((e) => e.name)).toContain('read_file');
|
||||
expect(toolCallEvents.map((e) => e.name)).toContain('edit_file');
|
||||
|
||||
// 3. TOOL_RESULT events should be emitted for both
|
||||
expect(toolResultEvents).toHaveLength(2);
|
||||
|
||||
// 4. Verify blocked tool result has success=false and error message
|
||||
const editResult = toolResultEvents.find((e) => e.name === 'edit_file');
|
||||
expect(editResult).toBeDefined();
|
||||
expect(editResult!.success).toBe(false);
|
||||
expect(editResult!.error).toContain('not found');
|
||||
expect(editResult!.callId).toBe('call_edit');
|
||||
|
||||
// 5. Verify allowed tool result has success=true
|
||||
const readResult = toolResultEvents.find((e) => e.name === 'read_file');
|
||||
expect(readResult).toBeDefined();
|
||||
expect(readResult!.success).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -487,6 +487,7 @@ export class SubAgentScope {
|
|||
abortController,
|
||||
promptId,
|
||||
turnCounter,
|
||||
toolsList,
|
||||
currentResponseId,
|
||||
);
|
||||
} else {
|
||||
|
|
@ -585,10 +586,67 @@ export class SubAgentScope {
|
|||
abortController: AbortController,
|
||||
promptId: string,
|
||||
currentRound: number,
|
||||
toolsList: FunctionDeclaration[],
|
||||
responseId?: string,
|
||||
): Promise<Content[]> {
|
||||
const toolResponseParts: Part[] = [];
|
||||
|
||||
// Build allowed tool names set for filtering
|
||||
const allowedToolNames = new Set(toolsList.map((t) => t.name));
|
||||
|
||||
// Filter unauthorized tool calls before scheduling
|
||||
const authorizedCalls: FunctionCall[] = [];
|
||||
for (const fc of functionCalls) {
|
||||
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
|
||||
|
||||
if (!allowedToolNames.has(fc.name)) {
|
||||
const toolName = String(fc.name);
|
||||
const errorMessage = `Tool "${toolName}" not found. Tools must use the exact names provided.`;
|
||||
|
||||
// Emit TOOL_CALL event for visibility
|
||||
this.eventEmitter?.emit(SubAgentEventType.TOOL_CALL, {
|
||||
subagentId: this.subagentId,
|
||||
round: currentRound,
|
||||
callId,
|
||||
name: toolName,
|
||||
args: fc.args ?? {},
|
||||
description: `Tool "${toolName}" not found`,
|
||||
timestamp: Date.now(),
|
||||
} as SubAgentToolCallEvent);
|
||||
|
||||
// Build function response part (used for both event and LLM)
|
||||
const functionResponsePart = {
|
||||
functionResponse: {
|
||||
id: callId,
|
||||
name: toolName,
|
||||
response: { error: errorMessage },
|
||||
},
|
||||
};
|
||||
|
||||
// Emit TOOL_RESULT event with error (include responseParts for UI rendering)
|
||||
this.eventEmitter?.emit(SubAgentEventType.TOOL_RESULT, {
|
||||
subagentId: this.subagentId,
|
||||
round: currentRound,
|
||||
callId,
|
||||
name: toolName,
|
||||
success: false,
|
||||
error: errorMessage,
|
||||
responseParts: [functionResponsePart],
|
||||
resultDisplay: errorMessage,
|
||||
durationMs: 0,
|
||||
timestamp: Date.now(),
|
||||
} as SubAgentToolResultEvent);
|
||||
|
||||
// Record blocked tool call in stats
|
||||
this.recordToolCallStats(toolName, false, 0, errorMessage);
|
||||
|
||||
// Add function response for LLM
|
||||
toolResponseParts.push(functionResponsePart);
|
||||
continue;
|
||||
}
|
||||
authorizedCalls.push(fc);
|
||||
}
|
||||
|
||||
// Build scheduler
|
||||
const responded = new Set<string>();
|
||||
let resolveBatch: (() => void) | null = null;
|
||||
|
|
@ -605,33 +663,8 @@ export class SubAgentScope {
|
|||
? call.response.error?.message
|
||||
: undefined;
|
||||
|
||||
// Update aggregate stats
|
||||
this.executionStats.totalToolCalls += 1;
|
||||
if (success) {
|
||||
this.executionStats.successfulToolCalls += 1;
|
||||
} else {
|
||||
this.executionStats.failedToolCalls += 1;
|
||||
}
|
||||
|
||||
// Per-tool usage
|
||||
const tu = this.toolUsage.get(toolName) || {
|
||||
count: 0,
|
||||
success: 0,
|
||||
failure: 0,
|
||||
totalDurationMs: 0,
|
||||
averageDurationMs: 0,
|
||||
};
|
||||
tu.count += 1;
|
||||
if (success) {
|
||||
tu.success += 1;
|
||||
} else {
|
||||
tu.failure += 1;
|
||||
tu.lastError = errorMessage || 'Unknown error';
|
||||
}
|
||||
tu.totalDurationMs = (tu.totalDurationMs || 0) + duration;
|
||||
tu.averageDurationMs =
|
||||
tu.count > 0 ? tu.totalDurationMs / tu.count : 0;
|
||||
this.toolUsage.set(toolName, tu);
|
||||
// Record stats
|
||||
this.recordToolCallStats(toolName, success, duration, errorMessage);
|
||||
|
||||
// Emit tool result event
|
||||
this.eventEmitter?.emit(SubAgentEventType.TOOL_RESULT, {
|
||||
|
|
@ -642,12 +675,6 @@ export class SubAgentScope {
|
|||
success,
|
||||
error: errorMessage,
|
||||
responseParts: call.response.responseParts,
|
||||
/**
|
||||
* Tools like todoWrite will add some extra contents to the result,
|
||||
* making it unable to deserialize the `responseParts` to a JSON object.
|
||||
* While `resultDisplay` is normally a string, if not we stringify it,
|
||||
* so that we can deserialize it to a JSON object when needed.
|
||||
*/
|
||||
resultDisplay: call.response.resultDisplay
|
||||
? typeof call.response.resultDisplay === 'string'
|
||||
? call.response.resultDisplay
|
||||
|
|
@ -657,14 +684,6 @@ export class SubAgentScope {
|
|||
timestamp: Date.now(),
|
||||
} as SubAgentToolResultEvent);
|
||||
|
||||
// Update statistics service
|
||||
this.stats.recordToolCall(
|
||||
toolName,
|
||||
success,
|
||||
duration,
|
||||
this.toolUsage.get(toolName)?.lastError,
|
||||
);
|
||||
|
||||
// post-tool hook
|
||||
await this.hooks?.postToolUse?.({
|
||||
subagentId: this.subagentId,
|
||||
|
|
@ -736,7 +755,7 @@ export class SubAgentScope {
|
|||
});
|
||||
|
||||
// Prepare requests and emit TOOL_CALL events
|
||||
const requests: ToolCallRequestInfo[] = functionCalls.map((fc) => {
|
||||
const requests: ToolCallRequestInfo[] = authorizedCalls.map((fc) => {
|
||||
const toolName = String(fc.name || 'unknown');
|
||||
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
|
||||
const args = (fc.args ?? {}) as Record<string, unknown>;
|
||||
|
|
@ -902,6 +921,52 @@ export class SubAgentScope {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Records tool call statistics for both successful and failed tool calls.
|
||||
* This includes updating aggregate stats, per-tool usage, and the statistics service.
|
||||
*/
|
||||
private recordToolCallStats(
|
||||
toolName: string,
|
||||
success: boolean,
|
||||
durationMs: number,
|
||||
errorMessage?: string,
|
||||
): void {
|
||||
// Update aggregate stats
|
||||
this.executionStats.totalToolCalls += 1;
|
||||
if (success) {
|
||||
this.executionStats.successfulToolCalls += 1;
|
||||
} else {
|
||||
this.executionStats.failedToolCalls += 1;
|
||||
}
|
||||
|
||||
// Per-tool usage
|
||||
const tu = this.toolUsage.get(toolName) || {
|
||||
count: 0,
|
||||
success: 0,
|
||||
failure: 0,
|
||||
totalDurationMs: 0,
|
||||
averageDurationMs: 0,
|
||||
};
|
||||
tu.count += 1;
|
||||
if (success) {
|
||||
tu.success += 1;
|
||||
} else {
|
||||
tu.failure += 1;
|
||||
tu.lastError = errorMessage || 'Unknown error';
|
||||
}
|
||||
tu.totalDurationMs = (tu.totalDurationMs || 0) + durationMs;
|
||||
tu.averageDurationMs = tu.count > 0 ? tu.totalDurationMs / tu.count : 0;
|
||||
this.toolUsage.set(toolName, tu);
|
||||
|
||||
// Update statistics service
|
||||
this.stats.recordToolCall(
|
||||
toolName,
|
||||
success,
|
||||
durationMs,
|
||||
this.toolUsage.get(toolName)?.lastError,
|
||||
);
|
||||
}
|
||||
|
||||
private buildChatSystemPrompt(context: ContextState): string {
|
||||
if (!this.promptConfig.systemPrompt) {
|
||||
// This should ideally be caught in createChatObject, but serves as a safeguard.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue