mirror of
https://github.com/badlogic/pi-mono.git
synced 2026-04-28 06:19:43 +00:00
fix(anthropic): harden tool-call streaming and recovery
This commit is contained in:
parent
fd91acec8b
commit
a74ca758ab
7 changed files with 289 additions and 14 deletions
|
|
@ -476,6 +476,16 @@ async function prepareToolCall(
|
|||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
): Promise<PreparedToolCall | ImmediateToolCallOutcome> {
|
||||
if (toolCall.argumentsParseError) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(
|
||||
`Invalid tool arguments JSON for "${toolCall.name}": ${toolCall.argumentsParseError}`,
|
||||
),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const tool = currentContext.tools?.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -307,6 +307,94 @@ describe("agentLoop with AgentMessage", () => {
|
|||
}
|
||||
});
|
||||
|
||||
it("should skip tool execution when tool arguments JSON is invalid", async () => {
|
||||
const toolSchema = Type.Object({});
|
||||
let executed = false;
|
||||
const tool: AgentTool<typeof toolSchema, { ok: true }> = {
|
||||
name: "noop",
|
||||
label: "Noop",
|
||||
description: "No-op tool",
|
||||
parameters: toolSchema,
|
||||
async execute() {
|
||||
executed = true;
|
||||
return {
|
||||
content: [{ type: "text", text: "executed" }],
|
||||
details: { ok: true },
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: "",
|
||||
messages: [],
|
||||
tools: [tool],
|
||||
};
|
||||
|
||||
const userPrompt: AgentMessage = createUserMessage("run noop");
|
||||
const config: AgentLoopConfig = {
|
||||
model: createModel(),
|
||||
convertToLlm: identityConverter,
|
||||
};
|
||||
|
||||
let callIndex = 0;
|
||||
const events: AgentEvent[] = [];
|
||||
const stream = agentLoop([userPrompt], context, config, undefined, () => {
|
||||
const mockStream = new MockAssistantStream();
|
||||
queueMicrotask(() => {
|
||||
if (callIndex === 0) {
|
||||
const message = createAssistantMessage(
|
||||
[
|
||||
{
|
||||
type: "toolCall",
|
||||
id: "tool-1",
|
||||
name: "noop",
|
||||
arguments: {},
|
||||
argumentsParseError: "parse error",
|
||||
},
|
||||
],
|
||||
"toolUse",
|
||||
);
|
||||
mockStream.push({ type: "done", reason: "toolUse", message });
|
||||
} else {
|
||||
const message = createAssistantMessage([{ type: "text", text: "done" }]);
|
||||
mockStream.push({ type: "done", reason: "stop", message });
|
||||
}
|
||||
callIndex++;
|
||||
});
|
||||
return mockStream;
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(executed).toBe(false);
|
||||
|
||||
const toolExecutionEnd = events.find(
|
||||
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> =>
|
||||
event.type === "tool_execution_end" && event.toolCallId === "tool-1",
|
||||
);
|
||||
expect(toolExecutionEnd).toBeDefined();
|
||||
expect(toolExecutionEnd?.isError).toBe(true);
|
||||
|
||||
const toolResultMessage = events.find(
|
||||
(event): event is Extract<AgentEvent, { type: "message_end" }> =>
|
||||
event.type === "message_end" &&
|
||||
event.message.role === "toolResult" &&
|
||||
event.message.toolCallId === "tool-1",
|
||||
);
|
||||
expect(toolResultMessage).toBeDefined();
|
||||
if (toolResultMessage && toolResultMessage.message.role === "toolResult") {
|
||||
expect(toolResultMessage.message.isError).toBe(true);
|
||||
expect(toolResultMessage.message.content).toEqual([
|
||||
{
|
||||
type: "text",
|
||||
text: 'Invalid tool arguments JSON for "noop": parse error',
|
||||
},
|
||||
]);
|
||||
}
|
||||
});
|
||||
|
||||
it("should execute mutated beforeToolCall args without revalidation", async () => {
|
||||
const toolSchema = Type.Object({ value: Type.String() });
|
||||
const executed: Array<string | number> = [];
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import type {
|
|||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { parseStreamingJson, parseStreamingJsonWithIndicator } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
|
|
@ -196,6 +196,35 @@ function mergeHeaders(...headerSources: (Record<string, string> | undefined)[]):
|
|||
return merged;
|
||||
}
|
||||
|
||||
type ParsedToolInput = {
|
||||
arguments: Record<string, unknown>;
|
||||
argumentsParseError?: string;
|
||||
};
|
||||
|
||||
function parseToolInputJson(rawJson: string): ParsedToolInput {
|
||||
const [args, err] = parseStreamingJsonWithIndicator<Record<string, unknown>>(rawJson);
|
||||
if (err !== undefined) {
|
||||
return {
|
||||
arguments: args,
|
||||
argumentsParseError: err,
|
||||
};
|
||||
} else {
|
||||
return { arguments: args };
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeToolInput(input: unknown): ParsedToolInput {
|
||||
if (!input) {
|
||||
return { arguments: {} };
|
||||
} else if (typeof input === "string") {
|
||||
return parseToolInputJson(input);
|
||||
} else {
|
||||
return {
|
||||
arguments: input as Record<string, unknown>,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOptions> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
|
|
@ -256,7 +285,7 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
|
|||
if (nextParams !== undefined) {
|
||||
params = nextParams as MessageCreateParamsStreaming;
|
||||
}
|
||||
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
|
||||
const anthropicStream = await client.messages.create({ ...params, stream: true }, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
|
||||
|
|
@ -304,13 +333,17 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
|
|||
output.content.push(block);
|
||||
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
const normalizedInput = normalizeToolInput(event.content_block.input);
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: isOAuth
|
||||
? fromClaudeCodeName(event.content_block.name, context.tools)
|
||||
: event.content_block.name,
|
||||
arguments: (event.content_block.input as Record<string, any>) ?? {},
|
||||
arguments: normalizedInput.arguments,
|
||||
...(normalizedInput.argumentsParseError
|
||||
? { argumentsParseError: normalizedInput.argumentsParseError }
|
||||
: {}),
|
||||
partialJson: "",
|
||||
index: event.index,
|
||||
};
|
||||
|
|
@ -347,7 +380,7 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
|
|||
const block = blocks[index];
|
||||
if (block && block.type === "toolCall") {
|
||||
block.partialJson += event.delta.partial_json;
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
block.arguments = parseStreamingJson<Record<string, unknown>>(block.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: index,
|
||||
|
|
@ -383,7 +416,15 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
|
|||
partial: output,
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
if (block.partialJson.trim().length > 0) {
|
||||
const parsedInput = parseToolInputJson(block.partialJson);
|
||||
block.arguments = parsedInput.arguments;
|
||||
if (parsedInput.argumentsParseError) {
|
||||
block.argumentsParseError = parsedInput.argumentsParseError;
|
||||
} else {
|
||||
delete block.argumentsParseError;
|
||||
}
|
||||
}
|
||||
// Finalize in-place and strip the scratch buffer so replay only
|
||||
// carries parsed arguments.
|
||||
delete (block as { partialJson?: string }).partialJson;
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ export interface ToolCall {
|
|||
id: string;
|
||||
name: string;
|
||||
arguments: Record<string, any>;
|
||||
argumentsParseError?: string; // Anthropic specific: holds parsing errors from bad JSON
|
||||
thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,21 +8,33 @@ import { parse as partialParse } from "partial-json";
|
|||
* @returns Parsed object or empty object if parsing fails
|
||||
*/
|
||||
export function parseStreamingJson<T = any>(partialJson: string | undefined): T {
|
||||
return parseStreamingJsonWithIndicator(partialJson)[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to parse potentially incomplete JSON during streaming.
|
||||
* Always returns a valid object, even if the JSON is incomplete.
|
||||
*
|
||||
* @param partialJson The partial JSON string from streaming
|
||||
* @returns Parsed object or empty object if parsing fails + a non empty string of an error if parsing failed due to invalid JSON.
|
||||
*/
|
||||
export function parseStreamingJsonWithIndicator<T = any>(partialJson: string | undefined): [T, string | undefined] {
|
||||
if (!partialJson || partialJson.trim() === "") {
|
||||
return {} as T;
|
||||
return [{} as T, undefined];
|
||||
}
|
||||
|
||||
// Try standard parsing first (fastest for complete JSON)
|
||||
try {
|
||||
return JSON.parse(partialJson) as T;
|
||||
} catch {
|
||||
return [JSON.parse(partialJson) as T, undefined];
|
||||
} catch (err) {
|
||||
const parseError = `${err}` || "invalid json";
|
||||
// Try partial-json for incomplete JSON
|
||||
try {
|
||||
const result = partialParse(partialJson);
|
||||
return (result ?? {}) as T;
|
||||
return [(result ?? {}) as T, parseError];
|
||||
} catch {
|
||||
// If all parsing fails, return empty object
|
||||
return {} as T;
|
||||
return [{} as T, parseError];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
117
packages/ai/test/anthropic-fine-grained-tool-streaming.test.ts
Normal file
117
packages/ai/test/anthropic-fine-grained-tool-streaming.test.ts
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
import { describe, expect, it, vi } from "vitest";
|
||||
import { getModel } from "../src/models.js";
|
||||
import type { Context } from "../src/types.js";
|
||||
|
||||
const mockState = vi.hoisted(() => ({
|
||||
createCalled: false,
|
||||
streamCalled: false,
|
||||
}));
|
||||
|
||||
vi.mock("@anthropic-ai/sdk", () => {
|
||||
const fakeStream = {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
yield {
|
||||
type: "message_start",
|
||||
message: {
|
||||
id: "msg_test",
|
||||
usage: {
|
||||
input_tokens: 12,
|
||||
output_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
},
|
||||
},
|
||||
};
|
||||
yield {
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: "tool_use",
|
||||
id: "toolu_test",
|
||||
name: "edit",
|
||||
input: {},
|
||||
},
|
||||
};
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: { type: "input_json_delta", partial_json: '{"project_id": 3' },
|
||||
};
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: { type: "input_json_delta", partial_json: ', "ref": "HEAD"' },
|
||||
};
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: { type: "input_json_delta", partial_json: ', "path": ' },
|
||||
};
|
||||
yield {
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: { type: "input_json_delta", partial_json: "}" },
|
||||
};
|
||||
yield {
|
||||
type: "content_block_stop",
|
||||
index: 0,
|
||||
};
|
||||
yield {
|
||||
type: "message_delta",
|
||||
delta: { stop_reason: "tool_use" },
|
||||
usage: {
|
||||
input_tokens: 12,
|
||||
output_tokens: 4,
|
||||
cache_read_input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
class FakeAnthropic {
|
||||
messages = {
|
||||
create: async (_params: Record<string, unknown>) => {
|
||||
mockState.createCalled = true;
|
||||
return fakeStream;
|
||||
},
|
||||
stream: (_params: Record<string, unknown>) => {
|
||||
mockState.streamCalled = true;
|
||||
throw new Error("messages.stream should not be called");
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return { default: FakeAnthropic };
|
||||
});
|
||||
|
||||
describe("anthropic fine-grained tool streaming", () => {
|
||||
it("keeps recoverable partial arguments when streamed tool-input JSON is malformed", async () => {
|
||||
const model = getModel("anthropic", "claude-sonnet-4-5");
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "run edit", timestamp: Date.now() }],
|
||||
};
|
||||
|
||||
const { streamAnthropic } = await import("../src/providers/anthropic.js");
|
||||
const result = await streamAnthropic(model, context, { apiKey: "sk-ant-api03-test" }).result();
|
||||
|
||||
expect(mockState.createCalled).toBe(true);
|
||||
expect(mockState.streamCalled).toBe(false);
|
||||
expect(result.stopReason).toBe("toolUse");
|
||||
expect(result.content).toHaveLength(1);
|
||||
|
||||
const toolCall = result.content[0];
|
||||
expect(toolCall?.type).toBe("toolCall");
|
||||
if (!toolCall || toolCall.type !== "toolCall") {
|
||||
throw new Error("Expected toolCall block");
|
||||
}
|
||||
|
||||
expect(toolCall.arguments).toEqual({
|
||||
project_id: 3,
|
||||
ref: "HEAD",
|
||||
});
|
||||
expect(toolCall.argumentsParseError).toBeTruthy();
|
||||
expect(toolCall.argumentsParseError).toContain("JSON");
|
||||
expect("partialJson" in toolCall).toBe(false);
|
||||
});
|
||||
});
|
||||
|
|
@ -4,7 +4,8 @@ import type { Context } from "../src/types.js";
|
|||
|
||||
const mockState = vi.hoisted(() => ({
|
||||
constructorOpts: undefined as Record<string, unknown> | undefined,
|
||||
streamParams: undefined as Record<string, unknown> | undefined,
|
||||
createParams: undefined as Record<string, unknown> | undefined,
|
||||
streamCalled: false,
|
||||
}));
|
||||
|
||||
vi.mock("@anthropic-ai/sdk", () => {
|
||||
|
|
@ -32,10 +33,14 @@ vi.mock("@anthropic-ai/sdk", () => {
|
|||
mockState.constructorOpts = opts;
|
||||
}
|
||||
messages = {
|
||||
stream: (params: Record<string, unknown>) => {
|
||||
mockState.streamParams = params;
|
||||
create: async (params: Record<string, unknown>) => {
|
||||
mockState.createParams = params;
|
||||
return fakeStream;
|
||||
},
|
||||
stream: (_params: Record<string, unknown>) => {
|
||||
mockState.streamCalled = true;
|
||||
throw new Error("messages.stream should not be called");
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -79,7 +84,8 @@ describe("Copilot Claude via Anthropic Messages", () => {
|
|||
expect(beta).not.toContain("fine-grained-tool-streaming");
|
||||
|
||||
// Payload is valid Anthropic Messages format
|
||||
const params = mockState.streamParams!;
|
||||
expect(mockState.streamCalled).toBe(false);
|
||||
const params = mockState.createParams!;
|
||||
expect(params.model).toBe("claude-sonnet-4");
|
||||
expect(params.stream).toBe(true);
|
||||
expect(params.max_tokens).toBeGreaterThan(0);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue