diff --git a/packages/cli/src/commands/mcp.ts b/packages/cli/src/commands/mcp.ts index 5e55286c1..1bb9e0314 100644 --- a/packages/cli/src/commands/mcp.ts +++ b/packages/cli/src/commands/mcp.ts @@ -9,6 +9,7 @@ import type { CommandModule, Argv } from 'yargs'; import { addCommand } from './mcp/add.js'; import { removeCommand } from './mcp/remove.js'; import { listCommand } from './mcp/list.js'; +import { reconnectCommand } from './mcp/reconnect.js'; export const mcpCommand: CommandModule = { command: 'mcp', @@ -18,6 +19,7 @@ export const mcpCommand: CommandModule = { .command(addCommand) .command(removeCommand) .command(listCommand) + .command(reconnectCommand) .demandCommand(1, 'You need at least one command before continuing.') .version(false), handler: () => { diff --git a/packages/cli/src/commands/mcp/reconnect.test.ts b/packages/cli/src/commands/mcp/reconnect.test.ts new file mode 100644 index 000000000..eeb049004 --- /dev/null +++ b/packages/cli/src/commands/mcp/reconnect.test.ts @@ -0,0 +1,235 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { reconnectCommand } from './reconnect.js'; +import { loadSettings } from '../../config/settings.js'; +import { Config, ExtensionManager } from '@qwen-code/qwen-code-core'; + +const mockWriteStdoutLine = vi.hoisted(() => vi.fn()); +const mockWriteStderrLine = vi.hoisted(() => vi.fn()); +const mockProcessExit = vi.hoisted(() => vi.fn()); + +vi.mock('../../utils/stdioHelpers.js', () => ({ + writeStdoutLine: mockWriteStdoutLine, + writeStderrLine: mockWriteStderrLine, +})); + +vi.mock('../../config/settings.js', () => ({ + loadSettings: vi.fn(), +})); + +vi.mock('../../config/trustedFolders.js', () => ({ + isWorkspaceTrusted: vi.fn().mockReturnValue(true), +})); + +vi.mock('@qwen-code/qwen-code-core', () => ({ + Config: vi.fn(), + FileDiscoveryService: vi.fn(), + ExtensionManager: vi.fn(), + getErrorMessage: (e: unknown) => (e instanceof Error ? e.message : String(e)), +})); + +const mockedLoadSettings = loadSettings as vi.Mock; +const MockedConfig = Config as vi.Mock; +const MockedExtensionManager = ExtensionManager as vi.Mock; + +describe('mcp reconnect command', () => { + let mockConfig: { + getToolRegistry: vi.Mock; + shutdown: vi.Mock; + initialize: vi.Mock; + }; + let mockToolRegistry: { + discoverToolsForServer: vi.Mock; + }; + let mockExtensionManager: { + refreshCache: vi.Mock; + getLoadedExtensions: vi.Mock; + }; + + beforeEach(() => { + vi.resetAllMocks(); + mockWriteStdoutLine.mockClear(); + mockWriteStderrLine.mockClear(); + + mockToolRegistry = { + discoverToolsForServer: vi.fn().mockResolvedValue(undefined), + }; + + mockConfig = { + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + shutdown: vi.fn().mockResolvedValue(undefined), + initialize: vi.fn().mockResolvedValue(undefined), + }; + + mockExtensionManager = { + refreshCache: vi.fn().mockResolvedValue(undefined), + getLoadedExtensions: vi.fn().mockReturnValue([]), + }; + + MockedConfig.mockImplementation(() => mockConfig); + MockedExtensionManager.mockImplementation(() => mockExtensionManager); + + Object.defineProperty(process, 'exit', { + value: mockProcessExit, + writable: true, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('reconnect specific server', () => { + it('should successfully reconnect a specific server', async () => { + mockedLoadSettings.mockReturnValue({ + merged: { + mcpServers: { + 'test-server': { command: '/path/to/server' }, + }, + }, + }); + + const handler = reconnectCommand.handler as ( + argv: Record, + ) => Promise; + await handler({ 'server-name': 'test-server', all: false }); + + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + 'Reconnecting to server "test-server"...', + ); + expect(mockToolRegistry.discoverToolsForServer).toHaveBeenCalledWith( + 'test-server', + ); + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + 'Successfully reconnected to server "test-server".', + ); + }); + + it('should print error when server not found', async () => { + mockedLoadSettings.mockReturnValue({ + merged: { + mcpServers: { + 'other-server': { command: '/path/to/server' }, + }, + }, + }); + + const handler = reconnectCommand.handler as ( + argv: Record, + ) => Promise; + await handler({ 'server-name': 'nonexistent-server', all: false }); + + expect(mockWriteStderrLine).toHaveBeenCalledWith( + 'Error: Server "nonexistent-server" not found in configuration.', + ); + expect(mockProcessExit).toHaveBeenCalledWith(1); + }); + + it('should print error when reconnection fails', async () => { + mockedLoadSettings.mockReturnValue({ + merged: { + mcpServers: { + 'test-server': { command: '/path/to/server' }, + }, + }, + }); + + mockToolRegistry.discoverToolsForServer.mockRejectedValue( + new Error('Connection refused'), + ); + + const handler = reconnectCommand.handler as ( + argv: Record, + ) => Promise; + await handler({ 'server-name': 'test-server', all: false }); + + expect(mockWriteStderrLine).toHaveBeenCalledWith( + 'Failed to reconnect to server "test-server": Connection refused', + ); + expect(mockProcessExit).toHaveBeenCalledWith(1); + }); + }); + + describe('reconnect all servers', () => { + it('should successfully reconnect all servers', async () => { + mockedLoadSettings.mockReturnValue({ + merged: { + mcpServers: { + 'server-one': { command: '/path/to/server1' }, + 'server-two': { command: '/path/to/server2' }, + }, + }, + }); + + const handler = reconnectCommand.handler as ( + argv: Record, + ) => Promise; + await handler({ 'server-name': undefined, all: true }); + + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + 'Reconnecting to all MCP servers...\n', + ); + expect(mockToolRegistry.discoverToolsForServer).toHaveBeenCalledWith( + 'server-one', + ); + expect(mockToolRegistry.discoverToolsForServer).toHaveBeenCalledWith( + 'server-two', + ); + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + '✓ server-one: Reconnected successfully', + ); + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + '✓ server-two: Reconnected successfully', + ); + }); + + it('should print message when no servers configured', async () => { + mockedLoadSettings.mockReturnValue({ + merged: { + mcpServers: {}, + }, + }); + + const handler = reconnectCommand.handler as ( + argv: Record, + ) => Promise; + await handler({ 'server-name': undefined, all: true }); + + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + 'No MCP servers configured.', + ); + }); + + it('should report failure for individual servers when reconnecting all', async () => { + mockedLoadSettings.mockReturnValue({ + merged: { + mcpServers: { + 'server-one': { command: '/path/to/server1' }, + 'server-two': { command: '/path/to/server2' }, + }, + }, + }); + + mockToolRegistry.discoverToolsForServer + .mockResolvedValueOnce(undefined) + .mockRejectedValueOnce(new Error('Timeout')); + + const handler = reconnectCommand.handler as ( + argv: Record, + ) => Promise; + await handler({ 'server-name': undefined, all: true }); + + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + '✓ server-one: Reconnected successfully', + ); + expect(mockWriteStdoutLine).toHaveBeenCalledWith( + '✗ server-two: Failed - Timeout', + ); + }); + }); +}); diff --git a/packages/cli/src/commands/mcp/reconnect.ts b/packages/cli/src/commands/mcp/reconnect.ts new file mode 100644 index 000000000..b4d10cb71 --- /dev/null +++ b/packages/cli/src/commands/mcp/reconnect.ts @@ -0,0 +1,163 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { CommandModule } from 'yargs'; +import { loadSettings } from '../../config/settings.js'; +import { writeStdoutLine, writeStderrLine } from '../../utils/stdioHelpers.js'; +import { + Config, + FileDiscoveryService, + ExtensionManager, +} from '@qwen-code/qwen-code-core'; +import { isWorkspaceTrusted } from '../../config/trustedFolders.js'; +import type { MCPServerConfig } from '@qwen-code/qwen-code-core'; + +async function getMcpServersFromConfig(): Promise< + Record +> { + const settings = loadSettings(); + const extensionManager = new ExtensionManager({ + isWorkspaceTrusted: !!isWorkspaceTrusted(settings.merged), + telemetrySettings: settings.merged.telemetry, + }); + await extensionManager.refreshCache(); + const extensions = extensionManager.getLoadedExtensions(); + const mcpServers = { ...(settings.merged.mcpServers || {}) }; + for (const extension of extensions) { + if (extension.isActive) { + Object.entries(extension.config.mcpServers || {}).forEach( + ([key, server]) => { + if (mcpServers[key]) { + return; + } + mcpServers[key] = { + ...server, + extensionName: extension.config.name, + }; + }, + ); + } + } + return mcpServers; +} + +async function createMinimalConfig(): Promise { + const settings = loadSettings(); + const cwd = process.cwd(); + const fileService = new FileDiscoveryService(cwd); + + const config = new Config({ + sessionId: 'mcp-reconnect', + targetDir: cwd, + cwd, + debugMode: false, + mcpServers: settings.merged.mcpServers || {}, + fileDiscoveryService: fileService, + mcpServerCommand: settings.merged.mcp?.serverCommand, + }); + + await config.initialize(); + + return config; +} + +async function reconnectMcpServer(serverName: string): Promise { + const mcpServers = await getMcpServersFromConfig(); + + if (!mcpServers[serverName]) { + writeStderrLine( + `Error: Server "${serverName}" not found in configuration.`, + ); + process.exit(1); + } + + writeStdoutLine(`Reconnecting to server "${serverName}"...`); + + try { + const config = await createMinimalConfig(); + const toolRegistry = config.getToolRegistry(); + await toolRegistry.discoverToolsForServer(serverName); + writeStdoutLine(`Successfully reconnected to server "${serverName}".`); + await config.shutdown(); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + writeStderrLine( + `Failed to reconnect to server "${serverName}": ${message}`, + ); + process.exit(1); + } +} + +async function reconnectAllMcpServers(): Promise { + const mcpServers = await getMcpServersFromConfig(); + const serverNames = Object.keys(mcpServers); + + if (serverNames.length === 0) { + writeStdoutLine('No MCP servers configured.'); + return; + } + + writeStdoutLine('Reconnecting to all MCP servers...\n'); + + let config: Config | undefined; + try { + config = await createMinimalConfig(); + const toolRegistry = config.getToolRegistry(); + + for (const serverName of serverNames) { + try { + await toolRegistry.discoverToolsForServer(serverName); + writeStdoutLine(`✓ ${serverName}: Reconnected successfully`); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + writeStdoutLine(`✗ ${serverName}: Failed - ${message}`); + } + } + } finally { + if (config) { + await config.shutdown(); + } + } +} + +export const reconnectCommand: CommandModule = { + command: 'reconnect [server-name]', + describe: 'Reconnect MCP server(s)', + builder: (yargs) => + yargs + .usage('Usage: qwen mcp reconnect [options] [server-name]') + .positional('server-name', { + describe: 'Name of the server to reconnect', + type: 'string', + }) + .option('all', { + alias: 'a', + describe: 'Reconnect all configured servers', + type: 'boolean', + default: false, + }) + .conflicts('server-name', 'all') + .check((argv) => { + const serverName = argv['server-name']; + const all = argv['all']; + if (!serverName && !all) { + throw new Error( + 'Please specify a server name or use --all to reconnect all servers.', + ); + } + return true; + }), + handler: async (argv) => { + const serverName = argv['server-name'] as string | undefined; + const all = argv['all'] as boolean; + + if (all) { + await reconnectAllMcpServers(); + } else if (serverName) { + await reconnectMcpServer(serverName); + } + }, +}; diff --git a/packages/cli/tsconfig.json b/packages/cli/tsconfig.json index cd546eeda..9516949cd 100644 --- a/packages/cli/tsconfig.json +++ b/packages/cli/tsconfig.json @@ -23,6 +23,7 @@ "src/commands/mcp/add.test.ts", "src/commands/mcp/list.test.ts", "src/commands/mcp/remove.test.ts", + "src/commands/mcp/reconnect.test.ts", "src/config/config.integration.test.ts", "src/config/config.test.ts", "src/config/extension.test.ts", diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index 005623afe..7098dec66 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -17,6 +17,7 @@ import type { ToolResult } from './tools.js'; import { ToolConfirmationOutcome } from './tools.js'; import type { CallableTool, Part } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; +import { updateMCPServerStatus, MCPServerStatus } from './mcp-client.js'; // Mock @google/genai mcpToTool and CallableTool // We only need to mock the parts of CallableTool that DiscoveredMCPTool uses. @@ -1116,4 +1117,284 @@ describe('DiscoveredMCPTool', () => { }); }); }); + + describe('auto-reconnect on connection error', () => { + it('should attempt reconnect and retry on connection error', async () => { + const params = { param: 'test' }; + const mockMcpClient: McpDirectClient = { + callTool: vi.fn(), + }; + + const successResult = { + content: [{ type: 'text', text: 'Success after reconnect' }], + }; + + const newMockMcpClient: McpDirectClient = { + callTool: vi.fn().mockResolvedValueOnce(successResult), + }; + + const newTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + undefined, + newMockMcpClient, + ); + + const discoverToolsForServer = vi.fn().mockResolvedValue(undefined); + const getTool = vi.fn().mockReturnValue(newTool); + const mockConfig = { + isTrustedFolder: () => true, + getToolRegistry: () => ({ + discoverToolsForServer, + getTool, + }), + }; + + const connectionError = new Error('Connection closed'); + + (mockMcpClient.callTool as any).mockRejectedValueOnce(connectionError); + + const reconnectTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + mockConfig as any, + mockMcpClient, + ); + + const invocation = reconnectTool.build(params); + const result = await invocation.execute(new AbortController().signal); + + expect(mockMcpClient.callTool).toHaveBeenCalledTimes(1); + expect(newMockMcpClient.callTool).toHaveBeenCalledTimes(1); + expect(discoverToolsForServer).toHaveBeenCalledWith(serverName); + expect(result.llmContent).toEqual([{ text: 'Success after reconnect' }]); + }); + + it('should not retry on non-connection errors', async () => { + const params = { param: 'test' }; + const mockMcpClient: McpDirectClient = { + callTool: vi.fn(), + }; + + const discoverToolsForServer = vi.fn().mockResolvedValue(undefined); + const mockConfig = { + isTrustedFolder: () => true, + getToolRegistry: () => ({ + discoverToolsForServer, + getTool: vi.fn().mockReturnValue(null), + }), + }; + + const toolError = new Error('Invalid parameters'); + (mockMcpClient.callTool as any).mockRejectedValue(toolError); + + const reconnectTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + mockConfig as any, + mockMcpClient, + ); + + const invocation = reconnectTool.build(params); + await expect( + invocation.execute(new AbortController().signal), + ).rejects.toThrow('Invalid parameters'); + + expect(mockMcpClient.callTool).toHaveBeenCalledTimes(1); + expect(discoverToolsForServer).not.toHaveBeenCalled(); + }); + + it('should not retry more than once', async () => { + const params = { param: 'test' }; + const mockMcpClient: McpDirectClient = { + callTool: vi.fn(), + }; + + const secondMockMcpClient: McpDirectClient = { + callTool: vi.fn().mockRejectedValue(new Error('ECONNREFUSED')), + }; + + const secondTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + undefined, + secondMockMcpClient, + ); + + const discoverToolsForServer = vi.fn().mockResolvedValue(undefined); + const mockConfig = { + isTrustedFolder: () => true, + getToolRegistry: () => ({ + discoverToolsForServer, + getTool: vi.fn().mockReturnValue(secondTool), + }), + }; + + const connectionError = new Error('ECONNREFUSED'); + (mockMcpClient.callTool as any).mockRejectedValue(connectionError); + + const reconnectTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + mockConfig as any, + mockMcpClient, + ); + + const invocation = reconnectTool.build(params); + await expect( + invocation.execute(new AbortController().signal), + ).rejects.toThrow('ECONNREFUSED'); + + expect(mockMcpClient.callTool).toHaveBeenCalledTimes(1); + expect(secondMockMcpClient.callTool).toHaveBeenCalledTimes(1); + expect(discoverToolsForServer).toHaveBeenCalledTimes(1); + }); + + it('should detect various connection error patterns', async () => { + const connectionErrors = [ + 'ECONNREFUSED', + 'ENOTFOUND', + 'ECONNRESET', + 'ETIMEDOUT', + 'connection closed', + 'Connection lost', + 'Not connected', + 'Disconnected', + 'Transport closed', + ]; + + for (const errorMsg of connectionErrors) { + const params = { param: 'test' }; + const mockMcpClient: McpDirectClient = { + callTool: vi.fn().mockRejectedValueOnce(new Error(errorMsg)), + }; + + const newMockMcpClient: McpDirectClient = { + callTool: vi + .fn() + .mockResolvedValueOnce({ content: [{ type: 'text', text: 'OK' }] }), + }; + + const newTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + undefined, + newMockMcpClient, + ); + + const discoverToolsForServer = vi.fn().mockResolvedValue(undefined); + const mockConfig = { + isTrustedFolder: () => true, + getToolRegistry: () => ({ + discoverToolsForServer, + getTool: vi.fn().mockReturnValue(newTool), + }), + }; + + const reconnectTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + mockConfig as any, + mockMcpClient, + ); + + const invocation = reconnectTool.build(params); + await invocation.execute(new AbortController().signal); + + expect(discoverToolsForServer).toHaveBeenCalled(); + } + }); + + it('should reconnect when MCP error occurs and server is disconnected', async () => { + const params = { param: 'test' }; + const mockMcpClient: McpDirectClient = { + callTool: vi + .fn() + .mockRejectedValueOnce( + new Error('MCP error -32602: Invalid request'), + ), + }; + + const newMockMcpClient: McpDirectClient = { + callTool: vi + .fn() + .mockResolvedValueOnce({ content: [{ type: 'text', text: 'OK' }] }), + }; + + const newTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + undefined, + newMockMcpClient, + ); + + const discoverToolsForServer = vi.fn().mockResolvedValue(undefined); + const mockConfig = { + isTrustedFolder: () => true, + getToolRegistry: () => ({ + discoverToolsForServer, + getTool: vi.fn().mockReturnValue(newTool), + }), + }; + + updateMCPServerStatus(serverName, MCPServerStatus.DISCONNECTED); + + const reconnectTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + undefined, + mockConfig as any, + mockMcpClient, + ); + + const invocation = reconnectTool.build(params); + await invocation.execute(new AbortController().signal); + + expect(discoverToolsForServer).toHaveBeenCalled(); + }); + }); }); diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 5d48b68c7..b5f520238 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -23,6 +23,9 @@ import { import type { CallableTool, FunctionCall, Part } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; import type { Config } from '../config/config.js'; +import { createDebugLogger } from '../utils/debugLogger.js'; + +const debugLogger = createDebugLogger('MCP_TOOL'); type ToolParams = Record; @@ -111,6 +114,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< ToolResult > { private static readonly allowlist: Set = new Set(); + private static readonly MAX_RECONNECT_RETRIES = 3; constructor( private readonly mcpTool: CallableTool, @@ -123,6 +127,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< private readonly mcpClient?: McpDirectClient, private readonly mcpTimeout?: number, private readonly annotations?: McpToolAnnotations, + private readonly retryCount: number = 0, ) { super(params); } @@ -192,6 +197,36 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< return false; } + private async attemptReconnect(): Promise { + if (!this.cliConfig) { + return null; + } + + try { + debugLogger.info( + `Attempting to reconnect MCP server '${this.serverName}'...`, + ); + const toolRegistry = this.cliConfig.getToolRegistry(); + await toolRegistry.discoverToolsForServer(this.serverName); + + const newTool = toolRegistry.getTool( + `mcp__${this.serverName}__${this.serverToolName}`, + ); + if (newTool instanceof DiscoveredMCPTool) { + debugLogger.info( + `Successfully reconnected to MCP server '${this.serverName}'`, + ); + return newTool; + } + return null; + } catch (error) { + debugLogger.error( + `Failed to reconnect MCP server '${this.serverName}': ${error}`, + ); + return null; + } + } + async execute( signal: AbortSignal, updateOutput?: (output: ToolResultDisplay) => void, @@ -214,60 +249,91 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< signal: AbortSignal, updateOutput?: (output: ToolResultDisplay) => void, ): Promise { - const callToolResult = await this.mcpClient!.callTool( - { - name: this.serverToolName, - arguments: this.params as Record, - }, - undefined, - { - onprogress: (progress) => { - if (updateOutput) { - const progressData: McpToolProgressData = { - type: 'mcp_tool_progress', - progress: progress.progress, - ...(progress.total != null && { total: progress.total }), - ...(progress.message != null && { message: progress.message }), - }; - updateOutput(progressData); - } + try { + const callToolResult = await this.mcpClient!.callTool( + { + name: this.serverToolName, + arguments: this.params as Record, }, - timeout: this.mcpTimeout, - signal, - }, - ); + undefined, + { + onprogress: (progress) => { + if (updateOutput) { + const progressData: McpToolProgressData = { + type: 'mcp_tool_progress', + progress: progress.progress, + ...(progress.total != null && { total: progress.total }), + ...(progress.message != null && { message: progress.message }), + }; + updateOutput(progressData); + } + }, + timeout: this.mcpTimeout, + signal, + }, + ); - // Wrap the raw CallToolResult into the Part[] format that the - // existing transform/display functions expect. - const rawResponseParts = wrapMcpCallToolResultAsParts( - this.serverToolName, - callToolResult, - ); + // Wrap the raw CallToolResult into the Part[] format that the + // existing transform/display functions expect. + const rawResponseParts = wrapMcpCallToolResultAsParts( + this.serverToolName, + callToolResult, + ); + + // Ensure the response is not an error + if (this.isMCPToolError(rawResponseParts)) { + const errorMessage = `MCP tool '${ + this.serverToolName + }' reported tool error for function call: ${safeJsonStringify({ + name: this.serverToolName, + args: this.params, + })} with response: ${safeJsonStringify(rawResponseParts)}`; + return { + llmContent: errorMessage, + returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`, + error: { + message: errorMessage, + type: ToolErrorType.MCP_TOOL_ERROR, + }, + }; + } + + const transformedParts = transformMcpContentToParts(rawResponseParts); - // Ensure the response is not an error - if (this.isMCPToolError(rawResponseParts)) { - const errorMessage = `MCP tool '${ - this.serverToolName - }' reported tool error for function call: ${safeJsonStringify({ - name: this.serverToolName, - args: this.params, - })} with response: ${safeJsonStringify(rawResponseParts)}`; return { - llmContent: errorMessage, - returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`, - error: { - message: errorMessage, - type: ToolErrorType.MCP_TOOL_ERROR, - }, + llmContent: transformedParts, + returnDisplay: getStringifiedResultForDisplay(rawResponseParts), }; + } catch (error) { + debugLogger.error(`MCP server error '${this.serverName}': ${error}`); + + // Attempt reconnection with retry limit + if (this.retryCount < DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES) { + const newTool = await this.attemptReconnect(); + if (newTool) { + const newInvocation = new DiscoveredMCPToolInvocation( + newTool['mcpTool'], + this.serverName, + this.serverToolName, + this.displayName, + this.trust, + this.params, + this.cliConfig, + newTool['mcpClient'], + this.mcpTimeout, + this.annotations, + this.retryCount + 1, + ); + return newInvocation.execute(signal, updateOutput); + } + } else { + debugLogger.error( + `Max reconnection attempts (${DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES}) reached for MCP server '${this.serverName}'`, + ); + } + + throw error; } - - const transformedParts = transformMcpContentToParts(rawResponseParts); - - return { - llmContent: transformedParts, - returnDisplay: getStringifiedResultForDisplay(rawResponseParts), - }; } /** @@ -285,59 +351,90 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< ]; // Race MCP tool call with abort signal to respect cancellation - const rawResponseParts = await new Promise((resolve, reject) => { - if (signal.aborted) { - const error = new Error('Tool call aborted'); - error.name = 'AbortError'; - reject(error); - return; + try { + const rawResponseParts = await new Promise((resolve, reject) => { + if (signal.aborted) { + const error = new Error('Tool call aborted'); + error.name = 'AbortError'; + reject(error); + return; + } + const onAbort = () => { + cleanup(); + const error = new Error('Tool call aborted'); + error.name = 'AbortError'; + reject(error); + }; + const cleanup = () => { + signal.removeEventListener('abort', onAbort); + }; + signal.addEventListener('abort', onAbort, { once: true }); + + this.mcpTool + .callTool(functionCalls) + .then((res) => { + cleanup(); + resolve(res); + }) + .catch((err) => { + cleanup(); + reject(err); + }); + }); + + // Ensure the response is not an error + if (this.isMCPToolError(rawResponseParts)) { + const errorMessage = `MCP tool '${ + this.serverToolName + }' reported tool error for function call: ${safeJsonStringify( + functionCalls[0], + )} with response: ${safeJsonStringify(rawResponseParts)}`; + return { + llmContent: errorMessage, + returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`, + error: { + message: errorMessage, + type: ToolErrorType.MCP_TOOL_ERROR, + }, + }; } - const onAbort = () => { - cleanup(); - const error = new Error('Tool call aborted'); - error.name = 'AbortError'; - reject(error); - }; - const cleanup = () => { - signal.removeEventListener('abort', onAbort); - }; - signal.addEventListener('abort', onAbort, { once: true }); - this.mcpTool - .callTool(functionCalls) - .then((res) => { - cleanup(); - resolve(res); - }) - .catch((err) => { - cleanup(); - reject(err); - }); - }); + const transformedParts = transformMcpContentToParts(rawResponseParts); - // Ensure the response is not an error - if (this.isMCPToolError(rawResponseParts)) { - const errorMessage = `MCP tool '${ - this.serverToolName - }' reported tool error for function call: ${safeJsonStringify( - functionCalls[0], - )} with response: ${safeJsonStringify(rawResponseParts)}`; return { - llmContent: errorMessage, - returnDisplay: `Error: MCP tool '${this.serverToolName}' reported an error.`, - error: { - message: errorMessage, - type: ToolErrorType.MCP_TOOL_ERROR, - }, + llmContent: transformedParts, + returnDisplay: getStringifiedResultForDisplay(rawResponseParts), }; + } catch (error) { + debugLogger.error(`MCP server error '${this.serverName}': ${error}`); + + // Attempt reconnection with retry limit + if (this.retryCount < DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES) { + const newTool = await this.attemptReconnect(); + if (newTool) { + const newInvocation = new DiscoveredMCPToolInvocation( + newTool['mcpTool'], + this.serverName, + this.serverToolName, + this.displayName, + this.trust, + this.params, + this.cliConfig, + newTool['mcpClient'], + this.mcpTimeout, + this.annotations, + this.retryCount + 1, + ); + return newInvocation.execute(signal); + } + } else { + debugLogger.error( + `Max reconnection attempts (${DiscoveredMCPToolInvocation.MAX_RECONNECT_RETRIES}) reached for MCP server '${this.serverName}'`, + ); + } + + throw error; } - - const transformedParts = transformMcpContentToParts(rawResponseParts); - - return { - llmContent: transformedParts, - returnDisplay: getStringifiedResultForDisplay(rawResponseParts), - }; } getDescription(): string {