Mirror of https://github.com/badlogic/pi-mono.git (synced 2026-04-28 06:19:43 +00:00)
285 lines, 7.9 KiB, TypeScript
import { existsSync, mkdirSync, rmSync } from "node:fs";
|
|
import { tmpdir } from "node:os";
|
|
import { join } from "node:path";
|
|
import { Agent } from "@mariozechner/pi-agent-core";
|
|
import {
|
|
type AssistantMessage,
|
|
type AssistantMessageEvent,
|
|
EventStream,
|
|
getModel,
|
|
type Model,
|
|
} from "@mariozechner/pi-ai";
|
|
import { afterEach, describe, expect, it, vi } from "vitest";
|
|
import { AgentSession } from "../src/core/agent-session.js";
|
|
import type { AgentSessionRuntime } from "../src/core/agent-session-runtime.js";
|
|
import { AuthStorage } from "../src/core/auth-storage.js";
|
|
import { ModelRegistry } from "../src/core/model-registry.js";
|
|
import { SessionManager } from "../src/core/session-manager.js";
|
|
import { SettingsManager } from "../src/core/settings-manager.js";
|
|
import { runRpcMode } from "../src/modes/rpc/rpc-mode.js";
|
|
import { createTestResourceLoader } from "./utilities.js";
|
|
|
|
/**
 * Hoisted capture state shared between the module mocks and the tests.
 *
 * - `outputLines` accumulates every chunk passed to the mocked `writeRawStdout`.
 * - `lineHandler` holds the callback handed to the mocked `attachJsonlLineReader`,
 *   letting tests push raw command lines into RPC mode directly.
 */
const rpcIo = vi.hoisted(() => {
	const state: { outputLines: string[]; lineHandler: ((line: string) => void) | undefined } = {
		outputLines: [],
		lineHandler: undefined,
	};
	return state;
});
|
|
|
|
// Route stdout writes into rpcIo.outputLines instead of the real process stdout.
vi.mock("../src/core/output-guard.js", () => {
	const takeOverStdout = vi.fn();
	const writeRawStdout = (line: string) => {
		rpcIo.outputLines.push(line);
	};
	return { takeOverStdout, writeRawStdout };
});

// The interactive theme is irrelevant to RPC mode; stub it with an empty object.
vi.mock("../src/modes/interactive/theme/theme.js", () => {
	return { theme: {} };
});

// Capture the line callback RPC mode registers so tests can feed it input,
// and serialize each value as one JSON document followed by a newline.
vi.mock("../src/modes/rpc/jsonl.js", () => {
	const attachJsonlLineReader = vi.fn((_input: NodeJS.ReadableStream, handleLine: (line: string) => void) => {
		rpcIo.lineHandler = handleLine;
		const detach = () => {};
		return detach;
	});
	const serializeJsonLine = (value: unknown) => JSON.stringify(value) + "\n";
	return { attachJsonlLineReader, serializeJsonLine };
});
|
|
|
|
/**
 * Test double for a provider's assistant-message stream.
 *
 * The first callback passed to `EventStream` flags `done` and `error` events as
 * terminal. The second extracts the final `AssistantMessage` from such an event
 * and throws for any other event type.
 */
class MockAssistantStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
	constructor() {
		super(
			(event) => event.type === "error" || event.type === "done",
			(event) => {
				switch (event.type) {
					case "done":
						return event.message;
					case "error":
						return event.error;
					default:
						throw new Error("Unexpected event type");
				}
			},
		);
	}
}
|
|
|
|
/**
 * Builds a minimal anthropic/claude-sonnet-4-5 assistant message containing a
 * single text block, zeroed usage/cost figures, and a `stop` stop reason.
 *
 * @param text - text placed in the message's only content block.
 * @returns a fresh message stamped with the current time.
 */
function createAssistantMessage(text: string): AssistantMessage {
	const zeroUsage = {
		input: 0,
		output: 0,
		cacheRead: 0,
		cacheWrite: 0,
		totalTokens: 0,
		cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
	};

	const message: AssistantMessage = {
		role: "assistant",
		content: [{ type: "text", text }],
		api: "anthropic-messages",
		provider: "anthropic",
		model: "claude-sonnet-4-5",
		usage: zeroUsage,
		stopReason: "stop",
		timestamp: Date.now(),
	};
	return message;
}
|
|
|
|
/** A single JSON object decoded from the captured RPC stdout stream. */
type ParsedOutputLine = { [field: string]: unknown };
|
|
|
|
/**
 * Decodes captured stdout chunks into JSON records.
 *
 * Each chunk may hold several newline-separated JSON documents. Blank or
 * whitespace-only lines are skipped. Malformed JSON propagates the
 * `JSON.parse` error.
 *
 * @param outputLines - raw chunks as written to the mocked stdout.
 * @returns the decoded records, in output order.
 */
function parseOutputLines(outputLines: string[]): ParsedOutputLine[] {
	const records: ParsedOutputLine[] = [];
	for (const chunk of outputLines) {
		for (const line of chunk.split("\n")) {
			if (line.trim() === "") {
				continue;
			}
			records.push(JSON.parse(line) as ParsedOutputLine);
		}
	}
	return records;
}
|
|
|
|
/**
 * Returns the `prompt` command responses emitted for a given request id.
 *
 * @param outputLines - raw captured stdout chunks.
 * @param id - request id the responses must carry.
 * @returns every record with `type: "response"`, `command: "prompt"` and a matching `id`.
 */
function getPromptResponses(outputLines: string[], id: string): ParsedOutputLine[] {
	const isMatchingPromptResponse = (record: ParsedOutputLine): boolean =>
		record.type === "response" && record.command === "prompt" && record.id === id;
	return parseOutputLines(outputLines).filter(isMatchingPromptResponse);
}
|
|
|
|
/**
 * Resolves (with no value) after roughly `ms` milliseconds.
 *
 * @param ms - delay handed to `setTimeout`.
 */
function sleep(ms: number): Promise<void> {
	return new Promise<void>((done) => {
		setTimeout(done, ms);
	});
}
|
|
|
|
/**
 * Builds an in-memory AgentSession wrapped in a minimal AgentSessionRuntime
 * stub, suitable for driving RPC mode in tests.
 *
 * The agent never talks to a real provider. Each stream request emits a
 * `start` event on the next microtask, then a `done` event carrying the text
 * "done" after `responseDelayMs`.
 *
 * @param options.withAuth - when true, registers a runtime API key for the
 *   "anthropic" provider in the session's auth storage.
 * @param options.responseDelayMs - delay before the mocked stream completes.
 * @param options.model - model to use; defaults to anthropic/claude-sonnet-4-5.
 * @returns the runtime host plus an async `cleanup` that aborts any in-flight
 *   stream, disposes the session and always removes the temp directory.
 * @throws Error if the default model cannot be resolved.
 */
function createRuntimeHost(options: { withAuth: boolean; responseDelayMs: number; model?: Model<any> }): {
	runtimeHost: AgentSessionRuntime;
	cleanup: () => Promise<void>;
} {
	// Resolve the model before touching the filesystem so a missing model
	// cannot leave an orphaned temp directory behind.
	const model = options.model ?? getModel("anthropic", "claude-sonnet-4-5");
	if (!model) {
		throw new Error("Test model not found");
	}

	const tempDir = join(tmpdir(), `pi-rpc-prompt-${Date.now()}-${Math.random().toString(36).slice(2)}`);
	mkdirSync(tempDir, { recursive: true });

	const agent = new Agent({
		getApiKey: () => "test-key",
		initialState: {
			model,
			systemPrompt: "Test",
			tools: [],
		},
		// Fake provider stream: "start" right away, "done" after the configured delay.
		streamFn: (_model, _context, _options) => {
			const stream = new MockAssistantStream();
			queueMicrotask(() => {
				stream.push({ type: "start", partial: createAssistantMessage("") });
				setTimeout(() => {
					stream.push({ type: "done", reason: "stop", message: createAssistantMessage("done") });
				}, options.responseDelayMs);
			});
			return stream;
		},
	});

	const sessionManager = SessionManager.inMemory();
	const settingsManager = SettingsManager.create(tempDir, tempDir);
	const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
	const modelRegistry = ModelRegistry.create(authStorage, tempDir);
	if (options.withAuth) {
		authStorage.setRuntimeApiKey("anthropic", "test-key");
	}

	const session = new AgentSession({
		agent,
		sessionManager,
		settingsManager,
		cwd: tempDir,
		modelRegistry,
		resourceLoader: createTestResourceLoader(),
	});

	// Only `session` is real; session-switching operations are stubbed to report cancellation.
	const runtimeHost = {
		session,
		newSession: vi.fn(async () => ({ cancelled: true })),
		switchSession: vi.fn(async () => ({ cancelled: true })),
		fork: vi.fn(async () => ({ cancelled: true, selectedText: "" })),
		dispose: vi.fn(async () => {}),
	} as unknown as AgentSessionRuntime;

	return {
		runtimeHost,
		cleanup: async () => {
			try {
				try {
					if (session.isStreaming) {
						await session.abort();
					}
				} catch {
					// Best effort: an abort failure must not prevent the rest of teardown.
				}
				session.dispose();
			} finally {
				// Always remove the temp directory, even if dispose() throws.
				if (existsSync(tempDir)) {
					rmSync(tempDir, { recursive: true, force: true });
				}
			}
		},
	};
}
|
|
|
|
/**
 * Starts RPC mode against a fresh runtime host and waits until it has
 * registered its JSONL line handler with the mocked reader.
 *
 * The shared `rpcIo` capture state is reset first. If the handler never
 * appears, the runtime host is cleaned up before the error is rethrown, so
 * the session and temp directory do not leak.
 *
 * @returns the registered line handler (feed it raw JSON command strings)
 *   and the runtime host's `cleanup`, which callers must await.
 */
async function startRpcMode(options: { withAuth: boolean; responseDelayMs: number; model?: Model<any> }): Promise<{
	lineHandler: (line: string) => void;
	cleanup: () => Promise<void>;
}> {
	rpcIo.outputLines = [];
	rpcIo.lineHandler = undefined;

	const { runtimeHost, cleanup } = createRuntimeHost(options);
	// Deliberately not awaited: presumably runRpcMode keeps serving until shutdown.
	// A rejection should surface through vitest's unhandled-rejection reporting.
	void runRpcMode(runtimeHost);

	let lineHandler: (line: string) => void;
	try {
		// Returning the handler from the callback both waits for registration and
		// yields a non-undefined value without a non-null assertion.
		lineHandler = await vi.waitFor(() => {
			const handler = rpcIo.lineHandler;
			if (!handler) {
				throw new Error("RPC mode has not registered a JSONL line handler yet");
			}
			return handler;
		});
	} catch (error) {
		await cleanup();
		throw error;
	}

	return { lineHandler, cleanup };
}
|
|
|
|
describe("RPC prompt response semantics", () => {
	afterEach(() => {
		rpcIo.outputLines = [];
		rpcIo.lineHandler = undefined;
	});

	/**
	 * Asserts that exactly one `prompt` response with `id` has been written so
	 * far and that it matches `expected` plus the common id/type/command envelope.
	 */
	function expectSinglePromptResponse(id: string, expected: Record<string, unknown>): void {
		const responses = getPromptResponses(rpcIo.outputLines, id);
		expect(responses).toHaveLength(1);
		expect(responses[0]).toMatchObject({ id, type: "response", command: "prompt", ...expected });
	}

	it("emits one failure response when prompt preflight rejects", async () => {
		// A provider with no configured credentials makes preflight fail.
		const { lineHandler, cleanup } = await startRpcMode({
			withAuth: false,
			responseDelayMs: 0,
			model: {
				id: "fake-model",
				name: "Fake Model",
				api: "openai-completions",
				provider: "fake-provider",
				baseUrl: "https://example.invalid",
				reasoning: false,
				input: [],
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
				contextWindow: 0,
				maxTokens: 0,
			},
		});

		try {
			lineHandler(JSON.stringify({ id: "b1", type: "prompt", message: "Hello" }));

			await vi.waitFor(() =>
				expectSinglePromptResponse("b1", {
					success: false,
					error: expect.stringContaining(
						"No API key found for fake-provider.\n\nUse /login or set an API key environment variable. See ",
					),
				}),
			);
		} finally {
			await cleanup();
		}
	});

	it("emits one success response when prompt preflight succeeds", async () => {
		const { lineHandler, cleanup } = await startRpcMode({ withAuth: true, responseDelayMs: 0 });

		try {
			lineHandler(JSON.stringify({ id: "b2", type: "prompt", message: "Hello" }));

			await vi.waitFor(() => expectSinglePromptResponse("b2", { success: true }));
		} finally {
			await cleanup();
		}
	});

	it("emits one success response when prompt is queued during streaming", async () => {
		const { lineHandler, cleanup } = await startRpcMode({ withAuth: true, responseDelayMs: 100 });

		try {
			lineHandler(JSON.stringify({ id: "b3-start", type: "prompt", message: "Start" }));
			await vi.waitFor(() => {
				expect(getPromptResponses(rpcIo.outputLines, "b3-start")).toHaveLength(1);
			});

			// Discard the first prompt's output so only the queued prompt's lines are inspected.
			rpcIo.outputLines = [];
			lineHandler(
				JSON.stringify({
					id: "b3",
					type: "prompt",
					message: "Queue this",
					streamingBehavior: "followUp",
				}),
			);

			await vi.waitFor(() => expectSinglePromptResponse("b3", { success: true }));

			// Let the in-flight stream finish (and the queued follow-up start), then
			// confirm no second "b3" response was emitted in the meantime.
			await sleep(150);
			expectSinglePromptResponse("b3", { success: true });
		} finally {
			await cleanup();
		}
	});
});
|