mirror of
https://github.com/QwenLM/qwen-code.git
synced 2026-05-11 13:10:23 +00:00
feat(mcp): add reconnect command and implement auto-reconnect logic
- Added a new reconnect command to the MCP CLI. - Implemented auto-reconnect functionality in DiscoveredMCPToolInvocation to handle connection errors with retry logic. - Enhanced tests to cover reconnect scenarios and ensure reliability during connection failures.
This commit is contained in:
parent
9391779cd0
commit
8a2bda67ed
6 changed files with 874 additions and 95 deletions
|
|
@ -9,6 +9,7 @@ import type { CommandModule, Argv } from 'yargs';
|
|||
import { addCommand } from './mcp/add.js';
|
||||
import { removeCommand } from './mcp/remove.js';
|
||||
import { listCommand } from './mcp/list.js';
|
||||
import { reconnectCommand } from './mcp/reconnect.js';
|
||||
|
||||
export const mcpCommand: CommandModule = {
|
||||
command: 'mcp',
|
||||
|
|
@ -18,6 +19,7 @@ export const mcpCommand: CommandModule = {
|
|||
.command(addCommand)
|
||||
.command(removeCommand)
|
||||
.command(listCommand)
|
||||
.command(reconnectCommand)
|
||||
.demandCommand(1, 'You need at least one command before continuing.')
|
||||
.version(false),
|
||||
handler: () => {
|
||||
|
|
|
|||
235
packages/cli/src/commands/mcp/reconnect.test.ts
Normal file
235
packages/cli/src/commands/mcp/reconnect.test.ts
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { reconnectCommand } from './reconnect.js';
|
||||
import { loadSettings } from '../../config/settings.js';
|
||||
import { Config, ExtensionManager } from '@qwen-code/qwen-code-core';
|
||||
|
||||
const mockWriteStdoutLine = vi.hoisted(() => vi.fn());
|
||||
const mockWriteStderrLine = vi.hoisted(() => vi.fn());
|
||||
const mockProcessExit = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('../../utils/stdioHelpers.js', () => ({
|
||||
writeStdoutLine: mockWriteStdoutLine,
|
||||
writeStderrLine: mockWriteStderrLine,
|
||||
}));
|
||||
|
||||
vi.mock('../../config/settings.js', () => ({
|
||||
loadSettings: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../config/trustedFolders.js', () => ({
|
||||
isWorkspaceTrusted: vi.fn().mockReturnValue(true),
|
||||
}));
|
||||
|
||||
vi.mock('@qwen-code/qwen-code-core', () => ({
|
||||
Config: vi.fn(),
|
||||
FileDiscoveryService: vi.fn(),
|
||||
ExtensionManager: vi.fn(),
|
||||
getErrorMessage: (e: unknown) => (e instanceof Error ? e.message : String(e)),
|
||||
}));
|
||||
|
||||
const mockedLoadSettings = loadSettings as vi.Mock;
|
||||
const MockedConfig = Config as vi.Mock;
|
||||
const MockedExtensionManager = ExtensionManager as vi.Mock;
|
||||
|
||||
describe('mcp reconnect command', () => {
|
||||
let mockConfig: {
|
||||
getToolRegistry: vi.Mock;
|
||||
shutdown: vi.Mock;
|
||||
initialize: vi.Mock;
|
||||
};
|
||||
let mockToolRegistry: {
|
||||
discoverToolsForServer: vi.Mock;
|
||||
};
|
||||
let mockExtensionManager: {
|
||||
refreshCache: vi.Mock;
|
||||
getLoadedExtensions: vi.Mock;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockWriteStdoutLine.mockClear();
|
||||
mockWriteStderrLine.mockClear();
|
||||
|
||||
mockToolRegistry = {
|
||||
discoverToolsForServer: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
mockConfig = {
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
shutdown: vi.fn().mockResolvedValue(undefined),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
mockExtensionManager = {
|
||||
refreshCache: vi.fn().mockResolvedValue(undefined),
|
||||
getLoadedExtensions: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
|
||||
MockedConfig.mockImplementation(() => mockConfig);
|
||||
MockedExtensionManager.mockImplementation(() => mockExtensionManager);
|
||||
|
||||
Object.defineProperty(process, 'exit', {
|
||||
value: mockProcessExit,
|
||||
writable: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('reconnect specific server', () => {
|
||||
it('should successfully reconnect a specific server', async () => {
|
||||
mockedLoadSettings.mockReturnValue({
|
||||
merged: {
|
||||
mcpServers: {
|
||||
'test-server': { command: '/path/to/server' },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const handler = reconnectCommand.handler as (
|
||||
argv: Record<string, unknown>,
|
||||
) => Promise<void>;
|
||||
await handler({ 'server-name': 'test-server', all: false });
|
||||
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'Reconnecting to server "test-server"...',
|
||||
);
|
||||
expect(mockToolRegistry.discoverToolsForServer).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'Successfully reconnected to server "test-server".',
|
||||
);
|
||||
});
|
||||
|
||||
it('should print error when server not found', async () => {
|
||||
mockedLoadSettings.mockReturnValue({
|
||||
merged: {
|
||||
mcpServers: {
|
||||
'other-server': { command: '/path/to/server' },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const handler = reconnectCommand.handler as (
|
||||
argv: Record<string, unknown>,
|
||||
) => Promise<void>;
|
||||
await handler({ 'server-name': 'nonexistent-server', all: false });
|
||||
|
||||
expect(mockWriteStderrLine).toHaveBeenCalledWith(
|
||||
'Error: Server "nonexistent-server" not found in configuration.',
|
||||
);
|
||||
expect(mockProcessExit).toHaveBeenCalledWith(1);
|
||||
});
|
||||
|
||||
it('should print error when reconnection fails', async () => {
|
||||
mockedLoadSettings.mockReturnValue({
|
||||
merged: {
|
||||
mcpServers: {
|
||||
'test-server': { command: '/path/to/server' },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
mockToolRegistry.discoverToolsForServer.mockRejectedValue(
|
||||
new Error('Connection refused'),
|
||||
);
|
||||
|
||||
const handler = reconnectCommand.handler as (
|
||||
argv: Record<string, unknown>,
|
||||
) => Promise<void>;
|
||||
await handler({ 'server-name': 'test-server', all: false });
|
||||
|
||||
expect(mockWriteStderrLine).toHaveBeenCalledWith(
|
||||
'Failed to reconnect to server "test-server": Connection refused',
|
||||
);
|
||||
expect(mockProcessExit).toHaveBeenCalledWith(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnect all servers', () => {
|
||||
it('should successfully reconnect all servers', async () => {
|
||||
mockedLoadSettings.mockReturnValue({
|
||||
merged: {
|
||||
mcpServers: {
|
||||
'server-one': { command: '/path/to/server1' },
|
||||
'server-two': { command: '/path/to/server2' },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const handler = reconnectCommand.handler as (
|
||||
argv: Record<string, unknown>,
|
||||
) => Promise<void>;
|
||||
await handler({ 'server-name': undefined, all: true });
|
||||
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'Reconnecting to all MCP servers...\n',
|
||||
);
|
||||
expect(mockToolRegistry.discoverToolsForServer).toHaveBeenCalledWith(
|
||||
'server-one',
|
||||
);
|
||||
expect(mockToolRegistry.discoverToolsForServer).toHaveBeenCalledWith(
|
||||
'server-two',
|
||||
);
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'✓ server-one: Reconnected successfully',
|
||||
);
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'✓ server-two: Reconnected successfully',
|
||||
);
|
||||
});
|
||||
|
||||
it('should print message when no servers configured', async () => {
|
||||
mockedLoadSettings.mockReturnValue({
|
||||
merged: {
|
||||
mcpServers: {},
|
||||
},
|
||||
});
|
||||
|
||||
const handler = reconnectCommand.handler as (
|
||||
argv: Record<string, unknown>,
|
||||
) => Promise<void>;
|
||||
await handler({ 'server-name': undefined, all: true });
|
||||
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'No MCP servers configured.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should report failure for individual servers when reconnecting all', async () => {
|
||||
mockedLoadSettings.mockReturnValue({
|
||||
merged: {
|
||||
mcpServers: {
|
||||
'server-one': { command: '/path/to/server1' },
|
||||
'server-two': { command: '/path/to/server2' },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
mockToolRegistry.discoverToolsForServer
|
||||
.mockResolvedValueOnce(undefined)
|
||||
.mockRejectedValueOnce(new Error('Timeout'));
|
||||
|
||||
const handler = reconnectCommand.handler as (
|
||||
argv: Record<string, unknown>,
|
||||
) => Promise<void>;
|
||||
await handler({ 'server-name': undefined, all: true });
|
||||
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'✓ server-one: Reconnected successfully',
|
||||
);
|
||||
expect(mockWriteStdoutLine).toHaveBeenCalledWith(
|
||||
'✗ server-two: Failed - Timeout',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
163
packages/cli/src/commands/mcp/reconnect.ts
Normal file
163
packages/cli/src/commands/mcp/reconnect.ts
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { CommandModule } from 'yargs';
|
||||
import { loadSettings } from '../../config/settings.js';
|
||||
import { writeStdoutLine, writeStderrLine } from '../../utils/stdioHelpers.js';
|
||||
import {
|
||||
Config,
|
||||
FileDiscoveryService,
|
||||
ExtensionManager,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { isWorkspaceTrusted } from '../../config/trustedFolders.js';
|
||||
import type { MCPServerConfig } from '@qwen-code/qwen-code-core';
|
||||
|
||||
async function getMcpServersFromConfig(): Promise<
|
||||
Record<string, MCPServerConfig>
|
||||
> {
|
||||
const settings = loadSettings();
|
||||
const extensionManager = new ExtensionManager({
|
||||
isWorkspaceTrusted: !!isWorkspaceTrusted(settings.merged),
|
||||
telemetrySettings: settings.merged.telemetry,
|
||||
});
|
||||
await extensionManager.refreshCache();
|
||||
const extensions = extensionManager.getLoadedExtensions();
|
||||
const mcpServers = { ...(settings.merged.mcpServers || {}) };
|
||||
for (const extension of extensions) {
|
||||
if (extension.isActive) {
|
||||
Object.entries(extension.config.mcpServers || {}).forEach(
|
||||
([key, server]) => {
|
||||
if (mcpServers[key]) {
|
||||
return;
|
||||
}
|
||||
mcpServers[key] = {
|
||||
...server,
|
||||
extensionName: extension.config.name,
|
||||
};
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
async function createMinimalConfig(): Promise<Config> {
|
||||
const settings = loadSettings();
|
||||
const cwd = process.cwd();
|
||||
const fileService = new FileDiscoveryService(cwd);
|
||||
|
||||
const config = new Config({
|
||||
sessionId: 'mcp-reconnect',
|
||||
targetDir: cwd,
|
||||
cwd,
|
||||
debugMode: false,
|
||||
mcpServers: settings.merged.mcpServers || {},
|
||||
fileDiscoveryService: fileService,
|
||||
mcpServerCommand: settings.merged.mcp?.serverCommand,
|
||||
});
|
||||
|
||||
await config.initialize();
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
async function reconnectMcpServer(serverName: string): Promise<void> {
|
||||
const mcpServers = await getMcpServersFromConfig();
|
||||
|
||||
if (!mcpServers[serverName]) {
|
||||
writeStderrLine(
|
||||
`Error: Server "${serverName}" not found in configuration.`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
writeStdoutLine(`Reconnecting to server "${serverName}"...`);
|
||||
|
||||
try {
|
||||
const config = await createMinimalConfig();
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
await toolRegistry.discoverToolsForServer(serverName);
|
||||
writeStdoutLine(`Successfully reconnected to server "${serverName}".`);
|
||||
await config.shutdown();
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
writeStderrLine(
|
||||
`Failed to reconnect to server "${serverName}": ${message}`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
async function reconnectAllMcpServers(): Promise<void> {
|
||||
const mcpServers = await getMcpServersFromConfig();
|
||||
const serverNames = Object.keys(mcpServers);
|
||||
|
||||
if (serverNames.length === 0) {
|
||||
writeStdoutLine('No MCP servers configured.');
|
||||
return;
|
||||
}
|
||||
|
||||
writeStdoutLine('Reconnecting to all MCP servers...\n');
|
||||
|
||||
let config: Config | undefined;
|
||||
try {
|
||||
config = await createMinimalConfig();
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
|
||||
for (const serverName of serverNames) {
|
||||
try {
|
||||
await toolRegistry.discoverToolsForServer(serverName);
|
||||
writeStdoutLine(`✓ ${serverName}: Reconnected successfully`);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
writeStdoutLine(`✗ ${serverName}: Failed - ${message}`);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (config) {
|
||||
await config.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const reconnectCommand: CommandModule = {
|
||||
command: 'reconnect [server-name]',
|
||||
describe: 'Reconnect MCP server(s)',
|
||||
builder: (yargs) =>
|
||||
yargs
|
||||
.usage('Usage: qwen mcp reconnect [options] [server-name]')
|
||||
.positional('server-name', {
|
||||
describe: 'Name of the server to reconnect',
|
||||
type: 'string',
|
||||
})
|
||||
.option('all', {
|
||||
alias: 'a',
|
||||
describe: 'Reconnect all configured servers',
|
||||
type: 'boolean',
|
||||
default: false,
|
||||
})
|
||||
.conflicts('server-name', 'all')
|
||||
.check((argv) => {
|
||||
const serverName = argv['server-name'];
|
||||
const all = argv['all'];
|
||||
if (!serverName && !all) {
|
||||
throw new Error(
|
||||
'Please specify a server name or use --all to reconnect all servers.',
|
||||
);
|
||||
}
|
||||
return true;
|
||||
}),
|
||||
handler: async (argv) => {
|
||||
const serverName = argv['server-name'] as string | undefined;
|
||||
const all = argv['all'] as boolean;
|
||||
|
||||
if (all) {
|
||||
await reconnectAllMcpServers();
|
||||
} else if (serverName) {
|
||||
await reconnectMcpServer(serverName);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -23,6 +23,7 @@
|
|||
"src/commands/mcp/add.test.ts",
|
||||
"src/commands/mcp/list.test.ts",
|
||||
"src/commands/mcp/remove.test.ts",
|
||||
"src/commands/mcp/reconnect.test.ts",
|
||||
"src/config/config.integration.test.ts",
|
||||
"src/config/config.test.ts",
|
||||
"src/config/extension.test.ts",
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import type { ToolResult } from './tools.js';
|
|||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
import type { CallableTool, Part } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { updateMCPServerStatus, MCPServerStatus } from './mcp-client.js';
|
||||
|
||||
// Mock @google/genai mcpToTool and CallableTool
|
||||
// We only need to mock the parts of CallableTool that DiscoveredMCPTool uses.
|
||||
|
|
@ -1116,4 +1117,284 @@ describe('DiscoveredMCPTool', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('auto-reconnect on connection error', () => {
|
||||
it('should attempt reconnect and retry on connection error', async () => {
|
||||
const params = { param: 'test' };
|
||||
const mockMcpClient: McpDirectClient = {
|
||||
callTool: vi.fn(),
|
||||
};
|
||||
|
||||
const successResult = {
|
||||
content: [{ type: 'text', text: 'Success after reconnect' }],
|
||||
};
|
||||
|
||||
const newMockMcpClient: McpDirectClient = {
|
||||
callTool: vi.fn().mockResolvedValueOnce(successResult),
|
||||
};
|
||||
|
||||
const newTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
newMockMcpClient,
|
||||
);
|
||||
|
||||
const discoverToolsForServer = vi.fn().mockResolvedValue(undefined);
|
||||
const getTool = vi.fn().mockReturnValue(newTool);
|
||||
const mockConfig = {
|
||||
isTrustedFolder: () => true,
|
||||
getToolRegistry: () => ({
|
||||
discoverToolsForServer,
|
||||
getTool,
|
||||
}),
|
||||
};
|
||||
|
||||
const connectionError = new Error('Connection closed');
|
||||
|
||||
(mockMcpClient.callTool as any).mockRejectedValueOnce(connectionError);
|
||||
|
||||
const reconnectTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig as any,
|
||||
mockMcpClient,
|
||||
);
|
||||
|
||||
const invocation = reconnectTool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(mockMcpClient.callTool).toHaveBeenCalledTimes(1);
|
||||
expect(newMockMcpClient.callTool).toHaveBeenCalledTimes(1);
|
||||
expect(discoverToolsForServer).toHaveBeenCalledWith(serverName);
|
||||
expect(result.llmContent).toEqual([{ text: 'Success after reconnect' }]);
|
||||
});
|
||||
|
||||
it('should not retry on non-connection errors', async () => {
|
||||
const params = { param: 'test' };
|
||||
const mockMcpClient: McpDirectClient = {
|
||||
callTool: vi.fn(),
|
||||
};
|
||||
|
||||
const discoverToolsForServer = vi.fn().mockResolvedValue(undefined);
|
||||
const mockConfig = {
|
||||
isTrustedFolder: () => true,
|
||||
getToolRegistry: () => ({
|
||||
discoverToolsForServer,
|
||||
getTool: vi.fn().mockReturnValue(null),
|
||||
}),
|
||||
};
|
||||
|
||||
const toolError = new Error('Invalid parameters');
|
||||
(mockMcpClient.callTool as any).mockRejectedValue(toolError);
|
||||
|
||||
const reconnectTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig as any,
|
||||
mockMcpClient,
|
||||
);
|
||||
|
||||
const invocation = reconnectTool.build(params);
|
||||
await expect(
|
||||
invocation.execute(new AbortController().signal),
|
||||
).rejects.toThrow('Invalid parameters');
|
||||
|
||||
expect(mockMcpClient.callTool).toHaveBeenCalledTimes(1);
|
||||
expect(discoverToolsForServer).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not retry more than once', async () => {
|
||||
const params = { param: 'test' };
|
||||
const mockMcpClient: McpDirectClient = {
|
||||
callTool: vi.fn(),
|
||||
};
|
||||
|
||||
const secondMockMcpClient: McpDirectClient = {
|
||||
callTool: vi.fn().mockRejectedValue(new Error('ECONNREFUSED')),
|
||||
};
|
||||
|
||||
const secondTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
secondMockMcpClient,
|
||||
);
|
||||
|
||||
const discoverToolsForServer = vi.fn().mockResolvedValue(undefined);
|
||||
const mockConfig = {
|
||||
isTrustedFolder: () => true,
|
||||
getToolRegistry: () => ({
|
||||
discoverToolsForServer,
|
||||
getTool: vi.fn().mockReturnValue(secondTool),
|
||||
}),
|
||||
};
|
||||
|
||||
const connectionError = new Error('ECONNREFUSED');
|
||||
(mockMcpClient.callTool as any).mockRejectedValue(connectionError);
|
||||
|
||||
const reconnectTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig as any,
|
||||
mockMcpClient,
|
||||
);
|
||||
|
||||
const invocation = reconnectTool.build(params);
|
||||
await expect(
|
||||
invocation.execute(new AbortController().signal),
|
||||
).rejects.toThrow('ECONNREFUSED');
|
||||
|
||||
expect(mockMcpClient.callTool).toHaveBeenCalledTimes(1);
|
||||
expect(secondMockMcpClient.callTool).toHaveBeenCalledTimes(1);
|
||||
expect(discoverToolsForServer).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should detect various connection error patterns', async () => {
|
||||
const connectionErrors = [
|
||||
'ECONNREFUSED',
|
||||
'ENOTFOUND',
|
||||
'ECONNRESET',
|
||||
'ETIMEDOUT',
|
||||
'connection closed',
|
||||
'Connection lost',
|
||||
'Not connected',
|
||||
'Disconnected',
|
||||
'Transport closed',
|
||||
];
|
||||
|
||||
for (const errorMsg of connectionErrors) {
|
||||
const params = { param: 'test' };
|
||||
const mockMcpClient: McpDirectClient = {
|
||||
callTool: vi.fn().mockRejectedValueOnce(new Error(errorMsg)),
|
||||
};
|
||||
|
||||
const newMockMcpClient: McpDirectClient = {
|
||||
callTool: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ content: [{ type: 'text', text: 'OK' }] }),
|
||||
};
|
||||
|
||||
const newTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
newMockMcpClient,
|
||||
);
|
||||
|
||||
const discoverToolsForServer = vi.fn().mockResolvedValue(undefined);
|
||||
const mockConfig = {
|
||||
isTrustedFolder: () => true,
|
||||
getToolRegistry: () => ({
|
||||
discoverToolsForServer,
|
||||
getTool: vi.fn().mockReturnValue(newTool),
|
||||
}),
|
||||
};
|
||||
|
||||
const reconnectTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig as any,
|
||||
mockMcpClient,
|
||||
);
|
||||
|
||||
const invocation = reconnectTool.build(params);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(discoverToolsForServer).toHaveBeenCalled();
|
||||
}
|
||||
});
|
||||
|
||||
it('should reconnect when MCP error occurs and server is disconnected', async () => {
|
||||
const params = { param: 'test' };
|
||||
const mockMcpClient: McpDirectClient = {
|
||||
callTool: vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(
|
||||
new Error('MCP error -32602: Invalid request'),
|
||||
),
|
||||
};
|
||||
|
||||
const newMockMcpClient: McpDirectClient = {
|
||||
callTool: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ content: [{ type: 'text', text: 'OK' }] }),
|
||||
};
|
||||
|
||||
const newTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
newMockMcpClient,
|
||||
);
|
||||
|
||||
const discoverToolsForServer = vi.fn().mockResolvedValue(undefined);
|
||||
const mockConfig = {
|
||||
isTrustedFolder: () => true,
|
||||
getToolRegistry: () => ({
|
||||
discoverToolsForServer,
|
||||
getTool: vi.fn().mockReturnValue(newTool),
|
||||
}),
|
||||
};
|
||||
|
||||
updateMCPServerStatus(serverName, MCPServerStatus.DISCONNECTED);
|
||||
|
||||
const reconnectTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig as any,
|
||||
mockMcpClient,
|
||||
);
|
||||
|
||||
const invocation = reconnectTool.build(params);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(discoverToolsForServer).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ import {
|
|||
import type { CallableTool, FunctionCall, Part } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { createDebugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const debugLogger = createDebugLogger('MCP_TOOL');
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
|
||||
|
|
@ -111,6 +114,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
|||
ToolResult
|
||||
> {
|
||||
private static readonly allowlist: Set<string> = new Set();
|
||||
private static readonly MAX_RECONNECT_RETRIES = 3;
|
||||
|
||||
constructor(
|
||||
private readonly mcpTool: CallableTool,
|
||||
|
|
@ -123,6 +127,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
|||
private readonly mcpClient?: McpDirectClient,
|
||||
private readonly mcpTimeout?: number,
|
||||
private readonly annotations?: McpToolAnnotations,
|
||||
private readonly retryCount: number = 0,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
|
@ -192,6 +197,36 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
|||
return false;
|
||||
}
|
||||
|
||||
private async attemptReconnect(): Promise<DiscoveredMCPTool | null> {
|
||||
if (!this.cliConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
debugLogger.info(
|
||||
`Attempting to reconnect MCP server '${this.serverName}'...`,
|
||||
);
|
||||
const toolRegistry = this.cliConfig.getToolRegistry();
|
||||
await toolRegistry.discoverToolsForServer(this.serverName);
|
||||
|
||||
const newTool = toolRegistry.getTool(
|
||||
`mcp__${this.serverName}__${this.serverToolName}`,
|
||||
);
|
||||
if (newTool instanceof DiscoveredMCPTool) {
|
||||
debugLogger.info(
|
||||
`Successfully reconnected to MCP server '${this.serverName}'`,
|
||||
);
|
||||
return newTool;
|
||||
}
|
||||
return null;
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Failed to reconnect MCP server '${this.serverName}': ${error}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: ToolResultDisplay) => void,
|
||||
|
|
@ -214,60 +249,91 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
|||
signal: AbortSignal,
|
||||
updateOutput?: (output: ToolResultDisplay) => void,
|
||||
): Promise<ToolResult> {
|
||||
const callToolResult = await this.mcpClient!.callTool(
|
||||
{
|
||||
name: this.serverToolName,
|
||||
arguments: this.params as Record<string, unknown>,
|
||||
},
|
||||
undefined,
|
||||
{
|
||||
onprogress: (progress) => {
|
||||
if (updateOutput) {
|
||||
const progressData: McpToolProgressData = {
|
||||
type: 'mcp_tool_progress',
|
||||
progress: progress.progress,
|
||||
...(progress.total != null && { total: progress.total }),
|
||||
...(progress.message != null && { message: progress.message }),
|
||||
};
|
||||
updateOutput(progressData);
|
||||
}
|
||||
try {
|
||||
const callToolResult = await this.mcpClient!.callTool(
|
||||
{
|
||||
name: this.serverToolName,
|
||||
arguments: this.params as Record<string, unknown>,
|
||||
},
|
||||
timeout: this.mcpTimeout,
|
||||
signal,
|
||||
},
|
||||
);
|
||||
undefined,
|
||||
{
|
||||
onprogress: (progress) => {
|
||||
if (updateOutput) {
|
||||
const progressData: McpToolProgressData = {
|
||||
type: 'mcp_tool_progress',
|
||||
progress: progress.progress,
|
||||
...(progress.total != null && { total: progress.total }),
|
||||
...(progress.message != null && { message: progress.message }),
|
||||
};
|
||||
updateOutput(progressData);
|
||||
}
|
||||
},
|
||||
timeout: this.mcpTimeout,
|
||||
signal,
|
||||
},
|
||||
);
|
||||
|
||||
// Wrap the raw CallToolResult into the Part[] format that the
|
||||
// existing transform/display functions expect.
|
||||
const rawResponseParts = wrapMcpCallToolResultAsParts(
|
||||
this.serverToolName,
|
||||
callToolResult,
|
||||
);
|
||||
// Wrap the raw CallToolResult into the Part[] format that the
|
||||
// existing transform/display functions expect.
|
||||
const rawResponseParts = wrapMcpCallToolResultAsParts(
|
||||
this.serverToolName,
|
||||
callToolResult,
|
||||
);
|
||||
|
||||
// Ensure the response is not an error
|
||||
if (this.isMCPToolError(rawResponseParts)) {
|
||||
const errorMessage = `MCP tool '${
|
||||
this.serverToolName
|
||||
}' reported tool error for function call: ${safeJsonStringify({
|
||||
name: this.serverToolName,
|
||||
args: this.params,
|
||||
})} with response: ${safeJsonStringify(rawResponseParts)}`;
|
||||
return {
|
||||
llmContent: errorMessage,
|
||||
returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.MCP_TOOL_ERROR,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const transformedParts = transformMcpContentToParts(rawResponseParts);
|
||||
|
||||
// Ensure the response is not an error
|
||||
if (this.isMCPToolError(rawResponseParts)) {
|
||||
const errorMessage = `MCP tool '${
|
||||
this.serverToolName
|
||||
}' reported tool error for function call: ${safeJsonStringify({
|
||||
name: this.serverToolName,
|
||||
args: this.params,
|
||||
})} with response: ${safeJsonStringify(rawResponseParts)}`;
|
||||
return {
|
||||
llmContent: errorMessage,
|
||||
returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.MCP_TOOL_ERROR,
|
||||
},
|
||||
llmContent: transformedParts,
|
||||
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
|
||||
};
|
||||
} catch (error) {
|
||||
debugLogger.error(`MCP server error '${this.serverName}': ${error}`);
|
||||
|
||||
// Attempt reconnection with retry limit
|
||||
if (this.retryCount < DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES) {
|
||||
const newTool = await this.attemptReconnect();
|
||||
if (newTool) {
|
||||
const newInvocation = new DiscoveredMCPToolInvocation(
|
||||
newTool['mcpTool'],
|
||||
this.serverName,
|
||||
this.serverToolName,
|
||||
this.displayName,
|
||||
this.trust,
|
||||
this.params,
|
||||
this.cliConfig,
|
||||
newTool['mcpClient'],
|
||||
this.mcpTimeout,
|
||||
this.annotations,
|
||||
this.retryCount + 1,
|
||||
);
|
||||
return newInvocation.execute(signal, updateOutput);
|
||||
}
|
||||
} else {
|
||||
debugLogger.error(
|
||||
`Max reconnection attempts (${DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES}) reached for MCP server '${this.serverName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
|
||||
const transformedParts = transformMcpContentToParts(rawResponseParts);
|
||||
|
||||
return {
|
||||
llmContent: transformedParts,
|
||||
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -285,59 +351,90 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
|||
];
|
||||
|
||||
// Race MCP tool call with abort signal to respect cancellation
|
||||
const rawResponseParts = await new Promise<Part[]>((resolve, reject) => {
|
||||
if (signal.aborted) {
|
||||
const error = new Error('Tool call aborted');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
return;
|
||||
try {
|
||||
const rawResponseParts = await new Promise<Part[]>((resolve, reject) => {
|
||||
if (signal.aborted) {
|
||||
const error = new Error('Tool call aborted');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
const onAbort = () => {
|
||||
cleanup();
|
||||
const error = new Error('Tool call aborted');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
};
|
||||
const cleanup = () => {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
|
||||
this.mcpTool
|
||||
.callTool(functionCalls)
|
||||
.then((res) => {
|
||||
cleanup();
|
||||
resolve(res);
|
||||
})
|
||||
.catch((err) => {
|
||||
cleanup();
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
|
||||
// Ensure the response is not an error
|
||||
if (this.isMCPToolError(rawResponseParts)) {
|
||||
const errorMessage = `MCP tool '${
|
||||
this.serverToolName
|
||||
}' reported tool error for function call: ${safeJsonStringify(
|
||||
functionCalls[0],
|
||||
)} with response: ${safeJsonStringify(rawResponseParts)}`;
|
||||
return {
|
||||
llmContent: errorMessage,
|
||||
returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.MCP_TOOL_ERROR,
|
||||
},
|
||||
};
|
||||
}
|
||||
const onAbort = () => {
|
||||
cleanup();
|
||||
const error = new Error('Tool call aborted');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
};
|
||||
const cleanup = () => {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
|
||||
this.mcpTool
|
||||
.callTool(functionCalls)
|
||||
.then((res) => {
|
||||
cleanup();
|
||||
resolve(res);
|
||||
})
|
||||
.catch((err) => {
|
||||
cleanup();
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
const transformedParts = transformMcpContentToParts(rawResponseParts);
|
||||
|
||||
// Ensure the response is not an error
|
||||
if (this.isMCPToolError(rawResponseParts)) {
|
||||
const errorMessage = `MCP tool '${
|
||||
this.serverToolName
|
||||
}' reported tool error for function call: ${safeJsonStringify(
|
||||
functionCalls[0],
|
||||
)} with response: ${safeJsonStringify(rawResponseParts)}`;
|
||||
return {
|
||||
llmContent: errorMessage,
|
||||
returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.MCP_TOOL_ERROR,
|
||||
},
|
||||
llmContent: transformedParts,
|
||||
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
|
||||
};
|
||||
} catch (error) {
|
||||
debugLogger.error(`MCP server error '${this.serverName}': ${error}`);
|
||||
|
||||
// Attempt reconnection with retry limit
|
||||
if (this.retryCount < DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES) {
|
||||
const newTool = await this.attemptReconnect();
|
||||
if (newTool) {
|
||||
const newInvocation = new DiscoveredMCPToolInvocation(
|
||||
newTool['mcpTool'],
|
||||
this.serverName,
|
||||
this.serverToolName,
|
||||
this.displayName,
|
||||
this.trust,
|
||||
this.params,
|
||||
this.cliConfig,
|
||||
newTool['mcpClient'],
|
||||
this.mcpTimeout,
|
||||
this.annotations,
|
||||
this.retryCount + 1,
|
||||
);
|
||||
return newInvocation.execute(signal);
|
||||
}
|
||||
} else {
|
||||
debugLogger.error(
|
||||
`Max reconnection attempts (${DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES}) reached for MCP server '${this.serverName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
|
||||
const transformedParts = transformMcpContentToParts(rawResponseParts);
|
||||
|
||||
return {
|
||||
llmContent: transformedParts,
|
||||
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
|
||||
};
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue