From feac72cb05d85c0e8c8a9cf0e62d1fec9425146a Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:01:37 +0800 Subject: [PATCH] improve jinja tool calling --- koboldcpp.py | 75 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 648f585d5..21a60d58f 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -2954,22 +2954,68 @@ def toolcall_to_normalized_json(text): #convert weird formats into standard tool for i in range(min(len(keys), len(values))): params[keys[i].strip()] = values[i].strip() return json.dumps({"name": fn_name, "arguments": params}) + def parse_deepseek_r1_sep(text: str) -> str: + text = re.sub(r'<|tool▁calls▁begin|>(.*?)<|tool▁calls▁end|>', r'\1', + text, flags=re.DOTALL).strip() + sep = '<|tool▁sep|>' + if sep not in text: + return text + parts = [p.strip() for p in text.split(sep) if p.strip()] + results = [] + for part in parts: + lines = part.split('\n', 1) + fn_name = lines[0].strip() + args_block = lines[1] if len(lines) > 1 else '{}' + args_block = re.sub(r'^```(?:json)?\s*', '', args_block.strip()) + args_block = re.sub(r'\s*```$', '', args_block.strip()) + try: + results.append({"name": fn_name, "arguments": json.loads(args_block)}) + except Exception: + pass + if not results: + return text + return json.dumps(results) if len(results) > 1 else json.dumps(results[0]) + def parse_minimax(text: str) -> str: + results = [] + for invoke in re.finditer( + r'\s]+)["\']?>(.*?)', + text, re.DOTALL + ): + fn_name = invoke.group(1).strip() + params = {} + for p in re.finditer( + r'\s]+)["\']?>(.*?)', + invoke.group(2), re.DOTALL + ): + val = p.group(2).strip() + try: + params[p.group(1).strip()] = json.loads(val) + except Exception: + params[p.group(1).strip()] = val + results.append({"name": fn_name, "arguments": params}) + if not results: + return text + return json.dumps(results) if len(results) > 1 else json.dumps(results[0]) + #if we are already valid JSON, return check_ok = extract_json_from_string(text) if check_ok and len(check_ok)>0: return text #is valid JSON or parsable - # handle glm with args - if "" in text and "" in text: + if "" in text and "" in text: # handle glm with args return parse_glm(text) - # handle qwen3.5 - if "' in text: #deepseek + return parse_deepseek_r1_sep(text) + + if ' ' not in text and '\n' not in text: # handle glm without args return parse_glm(text) return text #fallback @@ -2978,12 +3024,16 @@ def repack_toolcall_tags(text: str): tool_calls = [] if not text: return tool_calls + text = re.sub(r'.*?', '', text, flags=re.DOTALL) + text = re.sub(r'.*?', '', text, flags=re.DOTALL) + text = re.sub(r'.*?', '', text, flags=re.DOTALL) text = text.strip() tcpairs = [ ("", ""), ("", ""), ("<|tool_call_begin|>", "<|tool_call_end|>"), - ("<|tool▁call▁begin|>", "<|tool▁call▁end|>") + ("<|tool▁call▁begin|>", "<|tool▁call▁end|>"), + ("", ""), ] found = False for start, end in tcpairs: @@ -3001,7 +3051,7 @@ def repack_toolcall_tags(text: str): tool_calls = extract_json_from_string(text) return tool_calls -def format_jinja(messages, tools, chat_template_kwargs=None): +def format_jinja(messages_orig, tools, chat_template_kwargs=None): try: def strftime_now(format='%Y-%m-%d %H:%M:%S'): return datetime.now().strftime(format) @@ -3014,6 +3064,7 @@ def format_jinja(messages, tools, chat_template_kwargs=None): from jinja2.sandbox import ImmutableSandboxedEnvironment jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) # sanitize messages to remove none types + messages = json.loads(json.dumps(messages_orig)) for m in messages: if m.get("content") is None: del m["content"] @@ -4285,7 +4336,13 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler): # first, check and potentially segment multiple tags for multi-tool calls tool_calls = repack_toolcall_tags(recvtxt) if tool_calls and len(tool_calls)>0: - tool_calls = [normalize_tool_call(obj) for obj in tool_calls] + flat = [] + for obj in tool_calls: + if isinstance(obj, list): + flat.extend(obj) + else: + flat.append(obj) + tool_calls = [normalize_tool_call(obj) for obj in flat] for tc in tool_calls: tcarg = tc.get("function",{}).get("arguments",None) tc["id"] = f"call_{random.randint(10000, 99999)}"