diff --git a/koboldcpp.py b/koboldcpp.py index 001a2fb7c..c8ff715f3 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -114,7 +114,8 @@ has_audio_support = False has_vision_support = False cached_chat_template = None savedata_obj = None -mcp_connections = [] #every element is linked to one mcp source, contains obj {"process":optional_MCPStdioClient, "url":"optional_str", "tools":[]} +mcp_connections = [] #every element is linked to one mcp source, contains obj {"client":obj, "tools":[]} +mcp_lock = threading.Lock() multiplayer_story_data_compressed = None #stores the full compressed story of the current multiplayer session multiplayer_turn_major = 1 # to keep track of when a client needs to sync their stories multiplayer_turn_minor = 1 @@ -4102,7 +4103,7 @@ Change Mode
return def do_POST(self): - global modelbusy, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit + global modelbusy, requestsinqueue, currentusergenkey, totalgens, pendingabortkey, lastuploadedcomfyimg, lastgeneratedcomfyimg, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, net_save_slots, has_vision_support, savestate_limit, mcp_lock contlenstr = self.headers['content-length'] content_length = 0 body = None @@ -4470,12 +4471,38 @@ Change Mode
} response_body = (json.dumps(reply).encode()) elif method == "tools/list": - reply = {} + reply = { + "jsonrpc": "2.0", + "id": random.randint(100000, 999999), + "result": {"tools": []}, + } + with mcp_lock: + for conn in mcp_connections: + currtools = conn["tools"] + for tool in currtools: + reply["result"]["tools"].append(tool) response_body = (json.dumps(reply).encode()) + elif method == "tools/call": + foundtool = False + callparams = tempbody.get("params",{}) + callname = callparams.get("name","") + with mcp_lock: + for conn in mcp_connections: + currtools = conn["tools"] + currclient = conn["client"] + for tool in currtools: + if currclient and tool.get("name","")!="" and tool.get("name","")==callname: + foundtool = True + mcpresp = currclient.send(tempbody) + response_body = (json.dumps(mcpresp).encode()) + break + if not foundtool: + response_code = 400 + response_body = (json.dumps({"error": {"code": -32700, "message": "Tool not found"}}).encode()) else: #probably a notify, send empty response - self.send_response(200) - return - except Exception: + response_body = (json.dumps({}).encode()) + except Exception as e: + print(f"MCP Call Error: {e}") response_code = 400 response_body = (json.dumps({"error": {"code": -32700, "message": "Parse error"}}).encode()) @@ -7590,6 +7617,73 @@ def register_koboldcpp(): except Exception as e: print(f"Register Extensions: An error occurred: {e}") +def load_mcp_async(args): + global mcp_connections, mcp_lock + filepath = os.path.abspath(args.mcpfile) + if not filepath.lower().endswith(".json"): + filepath += ".json" + args.mcpfile += ".json" + try: + print(f"MCP start loading json file at '{filepath}'...") + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + loaded = json.load(f) + if not isinstance(loaded, dict): + raise ValueError("MCP config must be a JSON object") + servers = loaded.get("mcpServers") + if not isinstance(servers, dict): + raise ValueError("MCP config missing 'mcpServers' object") + for name, cfg in servers.items(): + try: + if not isinstance(cfg, dict): + raise ValueError(f"MCP server '{name}' must be an object") + mcpurl = cfg.get("url", "") + mcpcmd = cfg.get("command","") + if mcpcmd and not mcpurl: + mcpargs = cfg.get("args", []) + mcpenv = cfg.get("env", {}) + client = MCPStdioClient(command=mcpcmd,largs=mcpargs,env=mcpenv) + elif mcpurl: + headers = cfg.get("headers", {}) + client = MCPHTTPClient(url=mcpurl, headers=headers) + else: + raise ValueError(f"MCP server '{name}' missing 'command' and 'url'") + with mcp_lock: + mcp_connections.append({"client":client,"tools":[]}) + except Exception as e: + print(f"MCP Init Error: {e}") + for conn in list(mcp_connections): + try: + init_payload = { + "jsonrpc": "2.0", + "id": random.randint(100000, 999999), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "koboldcpp", "version": "1.0.0"} + } + } + toolget_payload = { + "jsonrpc": "2.0", + "id": random.randint(100000, 999999), + "method": "tools/list", + "params": {} + } + resp1 = conn["client"].send(init_payload) + if "result" not in resp1: + continue + resp2 = conn["client"].send(toolget_payload) + if "result" not in resp2 or "tools" not in resp2["result"]: + continue + with mcp_lock: + conn["tools"] = resp2["result"]["tools"] + except Exception as e: + print(f"MCP Setup Error: {e}") + print(f"Completed load of MCP json file at '{filepath}'.") + except Exception as e: + print(f"Failed to parse MCP json file at '{filepath}': {e}") + + def unregister_koboldcpp(): try: if os.name == 'nt': @@ -7985,7 +8079,7 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): friendlymodelname = "koboldcpp/" + sanitize_string(newmdldisplayname) # horde worker settings - global maxhordelen, maxhordectx, showdebug, has_multiplayer, savedata_obj, mcp_connections + global maxhordelen, maxhordectx, showdebug, has_multiplayer, savedata_obj if args.hordemodelname and args.hordemodelname!="": friendlymodelname = args.hordemodelname if args.debugmode == 1 or args.gendefaults: @@ -8031,69 +8125,7 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): print(f"Failed to access savedatafile '{filepath}': {e}") if args.mcpfile and isinstance(args.mcpfile, str): - filepath = os.path.abspath(args.mcpfile) # Ensure it's an absolute path - if not filepath.lower().endswith(".json"): - filepath += ".json" - args.mcpfile += ".json" - try: - with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: - loaded = json.load(f) - if not isinstance(loaded, dict): - raise ValueError("MCP config must be a JSON object") - servers = loaded.get("mcpServers") - if not isinstance(servers, dict): - raise ValueError("MCP config missing 'mcpServers' object") - #start all mcp servers, initialize them and fetch their tools - mcp_connections = [] # each item is {"client":obj, "tools":[]} - for name, cfg in servers.items(): - try: - if not isinstance(cfg, dict): - raise ValueError(f"MCP server '{name}' must be an object") - mcpurl = cfg.get("url", "") #only one of these should be filled - mcpcmd = cfg.get("command","") - if mcpcmd and not mcpurl: #stdio type - mcpargs = cfg.get("args", []) - mcpenv = cfg.get("env", {}) - client = MCPStdioClient(command=mcpcmd,largs=mcpargs,env=mcpenv) - mcp_connections.append({"client":client,"tools":[]}) - elif mcpurl: - headers = cfg.get("headers", {}) - client = MCPHTTPClient(url=mcpurl, headers=headers) - mcp_connections.append({"client":client,"tools":[]}) - else: - raise ValueError(f"MCP server '{name}' missing 'command' and 'url'") - except Exception as e: - print(f"MCP Init Error: {e}") - for conn in mcp_connections: #establish init and tool for each server - try: - init_payload = { - "jsonrpc": "2.0", - "id": random.randint(100000, 999999), - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "koboldcpp", "version": "1.0.0"} - } - } - toolget_payload = { - "jsonrpc": "2.0", - "id": random.randint(100000, 999999), - "method": "tools/list", - "params": {} - } - resp1 = conn["client"].send(init_payload) - if "result" not in resp1: - continue - resp2 = conn["client"].send(toolget_payload) - if "result" not in resp2 or "tools" not in resp2["result"]: - continue - conn["tools"] = resp2["result"]["tools"] - except Exception as e: - print(f"MCP Setup Error: {e}") - print(f"Loaded existing MCP json file at '{filepath}'.") - except Exception as e: - print(f"Failed to parse MCP json file at '{filepath}': {e}") + threading.Thread(target=load_mcp_async, args=(args,), daemon=True).start() if args.highpriority: print("Setting process to Higher Priority - Use Caution")