From feac72cb05d85c0e8c8a9cf0e62d1fec9425146a Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Wed, 1 Apr 2026 17:01:37 +0800
Subject: [PATCH] improve jinja tool calling
---
koboldcpp.py | 75 +++++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 66 insertions(+), 9 deletions(-)
diff --git a/koboldcpp.py b/koboldcpp.py
index 648f585d5..21a60d58f 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -2954,22 +2954,68 @@ def toolcall_to_normalized_json(text): #convert weird formats into standard tool
for i in range(min(len(keys), len(values))):
params[keys[i].strip()] = values[i].strip()
return json.dumps({"name": fn_name, "arguments": params})
+ def parse_deepseek_r1_sep(text: str) -> str:
+ text = re.sub(r'<|tool▁calls▁begin|>(.*?)<|tool▁calls▁end|>', r'\1',
+ text, flags=re.DOTALL).strip()
+ sep = '<|tool▁sep|>'
+ if sep not in text:
+ return text
+ parts = [p.strip() for p in text.split(sep) if p.strip()]
+ results = []
+ for part in parts:
+ lines = part.split('\n', 1)
+ fn_name = lines[0].strip()
+ args_block = lines[1] if len(lines) > 1 else '{}'
+ args_block = re.sub(r'^```(?:json)?\s*', '', args_block.strip())
+ args_block = re.sub(r'\s*```$', '', args_block.strip())
+ try:
+ results.append({"name": fn_name, "arguments": json.loads(args_block)})
+ except Exception:
+ pass
+ if not results:
+ return text
+ return json.dumps(results) if len(results) > 1 else json.dumps(results[0])
+ def parse_minimax(text: str) -> str:
+ results = []
+ for invoke in re.finditer(
+ r'\s]+)["\']?>(.*?)',
+ text, re.DOTALL
+ ):
+ fn_name = invoke.group(1).strip()
+ params = {}
+ for p in re.finditer(
+ r'\s]+)["\']?>(.*?)',
+ invoke.group(2), re.DOTALL
+ ):
+ val = p.group(2).strip()
+ try:
+ params[p.group(1).strip()] = json.loads(val)
+ except Exception:
+ params[p.group(1).strip()] = val
+ results.append({"name": fn_name, "arguments": params})
+ if not results:
+ return text
+ return json.dumps(results) if len(results) > 1 else json.dumps(results[0])
+
#if we are already valid JSON, return
check_ok = extract_json_from_string(text)
if check_ok and len(check_ok)>0:
return text #is valid JSON or parsable
- # handle glm with args
- if "" in text and "" in text:
+ if "" in text and "" in text: # handle glm with args
return parse_glm(text)
- # handle qwen3.5
- if "' in text: #deepseek
+ return parse_deepseek_r1_sep(text)
+
+ if ' ' not in text and '\n' not in text: # handle glm without args
return parse_glm(text)
return text #fallback
@@ -2978,12 +3024,16 @@ def repack_toolcall_tags(text: str):
tool_calls = []
if not text:
return tool_calls
+ text = re.sub(r'.*?', '', text, flags=re.DOTALL)
+ text = re.sub(r'.*?', '', text, flags=re.DOTALL)
+ text = re.sub(r'.*?', '', text, flags=re.DOTALL)
text = text.strip()
tcpairs = [
("", ""),
("", ""),
("<|tool_call_begin|>", "<|tool_call_end|>"),
- ("<|tool▁call▁begin|>", "<|tool▁call▁end|>")
+ ("<|tool▁call▁begin|>", "<|tool▁call▁end|>"),
+ ("", ""),
]
found = False
for start, end in tcpairs:
@@ -3001,7 +3051,7 @@ def repack_toolcall_tags(text: str):
tool_calls = extract_json_from_string(text)
return tool_calls
-def format_jinja(messages, tools, chat_template_kwargs=None):
+def format_jinja(messages_orig, tools, chat_template_kwargs=None):
try:
def strftime_now(format='%Y-%m-%d %H:%M:%S'):
return datetime.now().strftime(format)
@@ -3014,6 +3064,7 @@ def format_jinja(messages, tools, chat_template_kwargs=None):
from jinja2.sandbox import ImmutableSandboxedEnvironment
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
# sanitize messages to remove none types
+ messages = json.loads(json.dumps(messages_orig))
for m in messages:
if m.get("content") is None:
del m["content"]
@@ -4285,7 +4336,13 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
# first, check and potentially segment multiple tags for multi-tool calls
tool_calls = repack_toolcall_tags(recvtxt)
if tool_calls and len(tool_calls)>0:
- tool_calls = [normalize_tool_call(obj) for obj in tool_calls]
+ flat = []
+ for obj in tool_calls:
+ if isinstance(obj, list):
+ flat.extend(obj)
+ else:
+ flat.append(obj)
+ tool_calls = [normalize_tool_call(obj) for obj in flat]
for tc in tool_calls:
tcarg = tc.get("function",{}).get("arguments",None)
tc["id"] = f"call_{random.randint(10000, 99999)}"