improve jinja tool calling

This commit is contained in:
Concedo 2026-04-01 17:01:37 +08:00
parent 08cf75ce4b
commit feac72cb05

View file

@ -2954,22 +2954,68 @@ def toolcall_to_normalized_json(text): #convert weird formats into standard tool
for i in range(min(len(keys), len(values))):
params[keys[i].strip()] = values[i].strip()
return json.dumps({"name": fn_name, "arguments": params})
def parse_deepseek_r1_sep(text: str) -> str:
text = re.sub(r'<tool▁calls▁begin>(.*?)<tool▁calls▁end>', r'\1',
text, flags=re.DOTALL).strip()
sep = '<tool▁sep>'
if sep not in text:
return text
parts = [p.strip() for p in text.split(sep) if p.strip()]
results = []
for part in parts:
lines = part.split('\n', 1)
fn_name = lines[0].strip()
args_block = lines[1] if len(lines) > 1 else '{}'
args_block = re.sub(r'^```(?:json)?\s*', '', args_block.strip())
args_block = re.sub(r'\s*```$', '', args_block.strip())
try:
results.append({"name": fn_name, "arguments": json.loads(args_block)})
except Exception:
pass
if not results:
return text
return json.dumps(results) if len(results) > 1 else json.dumps(results[0])
def parse_minimax(text: str) -> str:
results = []
for invoke in re.finditer(
r'<invoke\s+name=["\']?([^"\'>\s]+)["\']?>(.*?)</invoke>',
text, re.DOTALL
):
fn_name = invoke.group(1).strip()
params = {}
for p in re.finditer(
r'<parameter\s+name=["\']?([^"\'>\s]+)["\']?>(.*?)</parameter>',
invoke.group(2), re.DOTALL
):
val = p.group(2).strip()
try:
params[p.group(1).strip()] = json.loads(val)
except Exception:
params[p.group(1).strip()] = val
results.append({"name": fn_name, "arguments": params})
if not results:
return text
return json.dumps(results) if len(results) > 1 else json.dumps(results[0])
#if we are already valid JSON, return
check_ok = extract_json_from_string(text)
if check_ok and len(check_ok)>0:
return text #is valid JSON or parsable
# handle glm with args
if "<arg_key>" in text and "<arg_value>" in text:
if "<arg_key>" in text and "<arg_value>" in text: # handle glm with args
return parse_glm(text)
# handle qwen3.5
if "<function=" in text:
if "<function=" in text: # handle qwen3.5
return parse_qwen35(text)
# handle glm without args
if ' ' not in text and '\n' not in text:
if "<invoke " in text: #minimax
return parse_minimax(text)
if '<tool▁sep>' in text: #deepseek
return parse_deepseek_r1_sep(text)
if ' ' not in text and '\n' not in text: # handle glm without args
return parse_glm(text)
return text #fallback
@ -2978,12 +3024,16 @@ def repack_toolcall_tags(text: str):
tool_calls = []
if not text:
return tool_calls
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
text = re.sub(r'<thinking>.*?</thinking>', '', text, flags=re.DOTALL)
text = re.sub(r'<reasoning>.*?</reasoning>', '', text, flags=re.DOTALL)
text = text.strip()
tcpairs = [
("<tool_call>", "</tool_call>"),
("<seed:tool_call>", "</seed:tool_call>"),
("<|tool_call_begin|>", "<|tool_call_end|>"),
("<tool▁call▁begin>", "<tool▁call▁end>")
("<tool▁call▁begin>", "<tool▁call▁end>"),
("<minimax:tool_call>", "</minimax:tool_call>"),
]
found = False
for start, end in tcpairs:
@ -3001,7 +3051,7 @@ def repack_toolcall_tags(text: str):
tool_calls = extract_json_from_string(text)
return tool_calls
def format_jinja(messages, tools, chat_template_kwargs=None):
def format_jinja(messages_orig, tools, chat_template_kwargs=None):
try:
def strftime_now(format='%Y-%m-%d %H:%M:%S'):
return datetime.now().strftime(format)
@ -3014,6 +3064,7 @@ def format_jinja(messages, tools, chat_template_kwargs=None):
from jinja2.sandbox import ImmutableSandboxedEnvironment
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
# sanitize messages to remove none types
messages = json.loads(json.dumps(messages_orig))
for m in messages:
if m.get("content") is None:
del m["content"]
@ -4285,7 +4336,13 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
# first, check and potentially segment multiple tags for multi-tool calls
tool_calls = repack_toolcall_tags(recvtxt)
if tool_calls and len(tool_calls)>0:
tool_calls = [normalize_tool_call(obj) for obj in tool_calls]
flat = []
for obj in tool_calls:
if isinstance(obj, list):
flat.extend(obj)
else:
flat.append(obj)
tool_calls = [normalize_tool_call(obj) for obj in flat]
for tc in tool_calls:
tcarg = tc.get("function",{}).get("arguments",None)
tc["id"] = f"call_{random.randint(10000, 99999)}"