diff --git a/backend/app/utils/toolkit/notion_mcp_toolkit.py b/backend/app/utils/toolkit/notion_mcp_toolkit.py index 137c48cd1..36928aa0a 100644 --- a/backend/app/utils/toolkit/notion_mcp_toolkit.py +++ b/backend/app/utils/toolkit/notion_mcp_toolkit.py @@ -1,14 +1,13 @@ import os -from typing import Any, ClassVar, Dict, List, Optional, Set -from camel.toolkits import FunctionTool, NotionMCPToolkit as BaseNotionMCPToolkit -from app.component.command import bun +from typing import Any, Dict, List +from loguru import logger +from camel.toolkits import FunctionTool from app.component.environment import env -from app.service.task import Agents from app.utils.toolkit.abstract_toolkit import AbstractToolkit from camel.toolkits.mcp_toolkit import MCPToolkit -class NotionMCPToolkit(BaseNotionMCPToolkit, AbstractToolkit): +class NotionMCPToolkit(MCPToolkit, AbstractToolkit): def __init__( self, @@ -18,25 +17,84 @@ class NotionMCPToolkit(BaseNotionMCPToolkit, AbstractToolkit): self.api_task_id = api_task_id if timeout is None: timeout = 120.0 - super().__init__(timeout) - self._mcp_toolkit = MCPToolkit( - config_dict={ - "mcpServers": { - "notionMCP": { - "command": "npx", - "args": [ - "-y", - "mcp-remote", - "https://mcp.notion.com/mcp", - ], - "env": { - "MCP_REMOTE_CONFIG_DIR": env("MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth")), - }, - } - } - }, - timeout=timeout, - ) + + config_dict={ + "mcpServers": { + "notionMCP": { + "command": "npx", + "args": [ + "-y", + "mcp-remote", + "https://mcp.notion.com/mcp", + ], + "env": { + "MCP_REMOTE_CONFIG_DIR": env("MCP_REMOTE_CONFIG_DIR", os.path.expanduser("~/.mcp-auth")), + }, + } + } + } + super().__init__(config_dict=config_dict, timeout=timeout) + + def get_tools(self) -> List[FunctionTool]: + r"""Returns a list of tools provided by the NotionMCPToolkit. + + Returns: + List[FunctionTool]: List of available tools. + """ + all_tools = [] + for client in self.clients: + try: + original_build_schema = client._build_tool_schema + + def create_wrapper(orig_func): + def wrapper(mcp_tool): + return self._build_custom_tool_schema( + mcp_tool, orig_func + ) + + return wrapper + + client._build_tool_schema = create_wrapper( # type: ignore[method-assign] + original_build_schema + ) + + client_tools = client.get_tools() + all_tools.extend(client_tools) + + client._build_tool_schema = original_build_schema # type: ignore[method-assign] + + except Exception as e: + logger.error(f"Failed to get tools from client: {e}") + return all_tools + + def _build_custom_tool_schema(self, mcp_tool, original_build_schema): + r"""Build tool schema with custom modifications.""" + schema = original_build_schema(mcp_tool) + self._customize_function_parameters(schema) + return schema + + def _customize_function_parameters(self, schema: Dict[str, Any]) -> None: + r"""Customize function parameters for specific functions. + + This method allows modifying parameter descriptions or other schema + attributes for specific functions. + """ + function_info = schema.get("function", {}) + function_name = function_info.get("name", "") + parameters = function_info.get("parameters", {}) + properties = parameters.get("properties", {}) + + # Modify the notion-create-pages function to make parent optional + if function_name == "notion-create-pages": + if "parent" in properties: + # Update the parent parameter description + properties["parent"]["description"] = ( + "Optional. The parent under which the new pages will be created. " + "This can be a page (page_id), a database page (database_id), or " + "a data source/collection under a database (data_source_id). " + "If omitted, the new pages will be created as private pages at the workspace level. " + "Use data_source_id when you have a collection:// URL from the fetch tool." + ) @classmethod async def get_can_use_tools(cls, api_task_id: str) -> list[FunctionTool]: @@ -45,7 +103,7 @@ class NotionMCPToolkit(BaseNotionMCPToolkit, AbstractToolkit): try: await toolkit.connect() # Use subclass implementation that inlines upstream processing - all_tools = BaseNotionMCPToolkit.get_tools(toolkit) + all_tools = toolkit.get_tools() for item in all_tools: setattr(item, "_toolkit_name", cls.__name__) tools.append(item) diff --git a/src/components/AddWorker/IntegrationList.tsx b/src/components/AddWorker/IntegrationList.tsx index 027a1fbf0..8a2dc55d2 100644 --- a/src/components/AddWorker/IntegrationList.tsx +++ b/src/components/AddWorker/IntegrationList.tsx @@ -15,6 +15,7 @@ interface IntegrationItem { name: string; desc: string; env_vars: string[]; + toolkit?: string; // Add toolkit field onInstall: () => void | Promise; } @@ -316,8 +317,11 @@ export default function IntegrationList({ "Github", ].includes(item.name) ) { - if (item.env_vars.length === 0 || isInstalled) { - addOption(item, true); + if (item.env_vars.length === 0 || isInstalled) { + // Ensure toolkit field is passed and normalized for known cases + const normalizedToolkit = + item.name === "Notion" ? "notion_mcp_toolkit" : item.toolkit; + addOption({ ...item, toolkit: normalizedToolkit }, true); } else { handleInstall(item); }