mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 17:22:04 +00:00
improve jinja tool calling
This commit is contained in:
parent
08cf75ce4b
commit
feac72cb05
1 changed file with 66 additions and 9 deletions
75
koboldcpp.py
75
koboldcpp.py
|
|
@ -2954,22 +2954,68 @@ def toolcall_to_normalized_json(text): #convert weird formats into standard tool
|
|||
for i in range(min(len(keys), len(values))):
|
||||
params[keys[i].strip()] = values[i].strip()
|
||||
return json.dumps({"name": fn_name, "arguments": params})
|
||||
def parse_deepseek_r1_sep(text: str) -> str:
    """Convert DeepSeek-R1 style tool-call text into normalized JSON text.

    The optional ``<|tool▁calls▁begin|>`` ... ``<|tool▁calls▁end|>`` wrapper is
    removed first. The remainder is split on ``<|tool▁sep|>``; in every
    non-empty piece the first line is taken as the function name and the rest
    (optionally wrapped in a Markdown code fence) is parsed as the JSON
    arguments. A piece with no second line is treated as a call with ``{}``
    arguments.

    Args:
        text: Raw tool-call text produced by the model.

    Returns:
        A JSON object string when exactly one call was parsed, a JSON array
        string when several were parsed, or the wrapper-stripped text when no
        separator is present or no piece carried parsable JSON arguments.
    """
    # The bar inside the tags must be matched literally: an unescaped "|" is
    # the regex alternation operator, which would make the pattern delete every
    # "<" and ">" in the text (and with them the separator tag). The fullwidth
    # bar U+FF5C is accepted as well as the ASCII one, since the tokens may be
    # spelled with either character.
    bar = r'[|\uff5c]'
    wrapper_pattern = rf'<{bar}tool▁calls▁begin{bar}>(.*?)<{bar}tool▁calls▁end{bar}>'
    sep_pattern = re.compile(rf'<{bar}tool▁sep{bar}>')

    text = re.sub(wrapper_pattern, r'\1', text, flags=re.DOTALL).strip()
    if not sep_pattern.search(text):
        return text

    # NOTE(review): anything before the first separator (e.g. a "function"
    # type prefix) is also treated as a zero-argument call -- confirm that this
    # is what callers expect.
    parts = [piece.strip() for piece in sep_pattern.split(text) if piece.strip()]

    results = []
    for part in parts:
        name_line, newline, remainder = part.partition('\n')
        args_block = remainder if newline else '{}'
        # Drop an optional ``` / ```json fence around the arguments.
        args_block = re.sub(r'^```(?:json)?\s*', '', args_block.strip())
        args_block = re.sub(r'\s*```$', '', args_block.strip())
        try:
            arguments = json.loads(args_block)
        except (ValueError, RecursionError):
            # Best effort: a piece whose arguments are not valid JSON is skipped.
            continue
        results.append({"name": name_line.strip(), "arguments": arguments})

    if not results:
        return text
    return json.dumps(results) if len(results) > 1 else json.dumps(results[0])
|
||||
def parse_minimax(text: str) -> str:
    """Convert MiniMax ``<invoke>`` / ``<parameter>`` markup into JSON text.

    Every ``<invoke name=...>...</invoke>`` element becomes one call whose
    arguments come from its nested ``<parameter name=...>...</parameter>``
    elements. A parameter value is decoded as JSON when possible and kept as
    the stripped raw string otherwise.

    Returns a JSON object string for a single call, a JSON array string for
    several calls, and the input text unchanged when no ``<invoke>`` element
    is found.
    """
    invoke_re = re.compile(
        r'<invoke\s+name=["\']?([^"\'>\s]+)["\']?>(.*?)</invoke>', re.DOTALL
    )
    parameter_re = re.compile(
        r'<parameter\s+name=["\']?([^"\'>\s]+)["\']?>(.*?)</parameter>', re.DOTALL
    )

    def decode(raw):
        # Values that are not valid JSON are passed through as plain strings.
        try:
            return json.loads(raw)
        except Exception:
            return raw

    calls = [
        {
            "name": call.group(1).strip(),
            "arguments": {
                arg.group(1).strip(): decode(arg.group(2).strip())
                for arg in parameter_re.finditer(call.group(2))
            },
        }
        for call in invoke_re.finditer(text)
    ]

    if not calls:
        return text
    if len(calls) == 1:
        return json.dumps(calls[0])
    return json.dumps(calls)
|
||||
|
||||
|
||||
#if we are already valid JSON, return
|
||||
check_ok = extract_json_from_string(text)
|
||||
if check_ok and len(check_ok)>0:
|
||||
return text #is valid JSON or parsable
|
||||
|
||||
# handle glm with args
|
||||
if "<arg_key>" in text and "<arg_value>" in text:
|
||||
if "<arg_key>" in text and "<arg_value>" in text: # handle glm with args
|
||||
return parse_glm(text)
|
||||
|
||||
# handle qwen3.5
|
||||
if "<function=" in text:
|
||||
if "<function=" in text: # handle qwen3.5
|
||||
return parse_qwen35(text)
|
||||
|
||||
# handle glm without args
|
||||
if ' ' not in text and '\n' not in text:
|
||||
if "<invoke " in text: #minimax
|
||||
return parse_minimax(text)
|
||||
|
||||
if '<|tool▁sep|>' in text: #deepseek
|
||||
return parse_deepseek_r1_sep(text)
|
||||
|
||||
if ' ' not in text and '\n' not in text: # handle glm without args
|
||||
return parse_glm(text)
|
||||
|
||||
return text #fallback
|
||||
|
|
@ -2978,12 +3024,16 @@ def repack_toolcall_tags(text: str):
|
|||
tool_calls = []
|
||||
if not text:
|
||||
return tool_calls
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'<thinking>.*?</thinking>', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'<reasoning>.*?</reasoning>', '', text, flags=re.DOTALL)
|
||||
text = text.strip()
|
||||
tcpairs = [
|
||||
("<tool_call>", "</tool_call>"),
|
||||
("<seed:tool_call>", "</seed:tool_call>"),
|
||||
("<|tool_call_begin|>", "<|tool_call_end|>"),
|
||||
("<|tool▁call▁begin|>", "<|tool▁call▁end|>")
|
||||
("<|tool▁call▁begin|>", "<|tool▁call▁end|>"),
|
||||
("<minimax:tool_call>", "</minimax:tool_call>"),
|
||||
]
|
||||
found = False
|
||||
for start, end in tcpairs:
|
||||
|
|
@ -3001,7 +3051,7 @@ def repack_toolcall_tags(text: str):
|
|||
tool_calls = extract_json_from_string(text)
|
||||
return tool_calls
|
||||
|
||||
def format_jinja(messages, tools, chat_template_kwargs=None):
|
||||
def format_jinja(messages_orig, tools, chat_template_kwargs=None):
|
||||
try:
|
||||
def strftime_now(format='%Y-%m-%d %H:%M:%S'):
|
||||
return datetime.now().strftime(format)
|
||||
|
|
@ -3014,6 +3064,7 @@ def format_jinja(messages, tools, chat_template_kwargs=None):
|
|||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
||||
# sanitize messages to remove none types
|
||||
messages = json.loads(json.dumps(messages_orig))
|
||||
for m in messages:
|
||||
if m.get("content") is None:
|
||||
del m["content"]
|
||||
|
|
@ -4285,7 +4336,13 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
# first, check and potentially segment multiple tags for multi-tool calls
|
||||
tool_calls = repack_toolcall_tags(recvtxt)
|
||||
if tool_calls and len(tool_calls)>0:
|
||||
tool_calls = [normalize_tool_call(obj) for obj in tool_calls]
|
||||
flat = []
|
||||
for obj in tool_calls:
|
||||
if isinstance(obj, list):
|
||||
flat.extend(obj)
|
||||
else:
|
||||
flat.append(obj)
|
||||
tool_calls = [normalize_tool_call(obj) for obj in flat]
|
||||
for tc in tool_calls:
|
||||
tcarg = tc.get("function",{}).get("arguments",None)
|
||||
tc["id"] = f"call_{random.randint(10000, 99999)}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue