improvements to tool calling logic (merged changes from old PR branch) (#1855)

* improvements to tool calling logic (merged changes from old PR branch)

* added some tweaks for improved tool calls to reuse old ctx, but needs testing. refer to PR.

* fixes to some stuff that concedo's modifications broke

* fixed error in reasoning

* extremely hacky way to cache tool list please fix

* oops forgot to add this

* slightly less hacky way to preserve the tool list in context

* prevented unintended toolcalls from happening when LLM states something irrelevant to toolcall decision

* fixed something that broke koboldlite

* fixed bug added by concedo that broke jinja tools

* experimental further compression of tools array, needs testing

* reverted experimental further compression of tools array

* final cleanup

* add newline after memory insert

* changed tool reasoning to always be in json format to enforce including final decision

* used new json format to skip extra llm call when not necessary

* more catching of possible bad llm output

* further cleanup

* got it down to just one llm call!

* better json format

* even better json format

* further refinement to json format

* further refinement to json format

* fixed broken tool calling

* single-call enforced json method now seems to work well. removed fallbacks as they are no longer required.

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
Rose 2025-11-23 15:41:31 +01:00 committed by GitHub
parent 2ef03a824e
commit eeb7363985
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

157
koboldcpp.py Normal file → Executable file
View file

@ -2527,34 +2527,82 @@ def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_
# tools handling: Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
tools_array = genparams.get('tools', [])
chosen_tool = genparams.get('tool_choice', "auto")
messages = genparams.get('messages',[])
toolmem = genparams.get("memory","")
# first handle auto mode, determine whether a tool is needed
used_tool_json = None
if not curr_ctx:
return None
# get user's last message and last tool call results
last_user_message = ""
tool_call_results = ""
if messages:
reversed_messages = list(reversed(messages))
for message in reversed_messages:
if message["role"] == "user":
last_user_message = message["content"]
last_user_message = f"\n\nUser's current request: {last_user_message}"
break
tool_call_chunk = []
for message in reversed_messages:
if message["role"] == "tool":
tool_call_chunk.append(message["content"])
else:
break
tmp_tool_replies = list(reversed(tool_call_chunk))
if tmp_tool_replies and len(tmp_tool_replies)>0:
tool_call_results = f"\n\nTool call responses: {tmp_tool_replies}"
if tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool!="none":
tools_string = json.dumps(tools_array, indent=0)
should_use_tools = True
if chosen_tool=="auto":
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
custom_tools_prompt = "Can the user query be answered by a listed tool above? (One word response: yes or no):"
if is_followup_tool:
custom_tools_prompt = "Can the user query be further answered by another listed tool above? (If response is already complete, reply NO) (One word response: yes or no):"
if chosen_tool=="auto" or chosen_tool=="required":
# note: message string already contains the instruct start tag!
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
temptoolnames = extract_all_names_from_tool_array(tools_array)
tempjson = {}
if chosen_tool=="required":
custom_tools_prompt_json_format = "Respond with a JSON object using this structure:\r\n{\r\n \"tool_name\": \"exact_tool_name_here\"\r\n}\r\n\r\nRules:\r\n- You must pick one of the tools to use, pick the most suitable tool."
tempjson = {"type":"object","properties":{"tool_name":{"type":"string","enum":temptoolnames}},"required":["tool_name"],"additionalProperties":False}
else:
temptoolnames.append("null")
custom_tools_prompt_json_format = "Respond with a JSON object using this structure:\r\n{\r\n \"reasoning\": \"Your reasoning here\",\r\n \"final_decision\": \"yes\" or \"no\",\r\n \"tool_name\": \"exact_tool_name_here\" or \"null\"\r\n}\r\n\r\nRules:\r\n- Output only the JSON object. Do NOT add anything before or after the json object.\r\n- final_decision must be exactly \"yes\" or \"no\"\r\n- tool_name must be either an exact tool name, or if no tool is required, an empty string: \"\"\r\n- Keep reasoning short, maximum one or two sentences.\r\n- No unnecessary comments"
tempjson = {"type":"object","properties":{"reasoning":{"type":"string"},"final_decision":{"type":"string","enum":["yes","no","Yes","No","YES","NO"," yes"," no"," Yes"," No"," YES"," NO"]},"tool_name":{"type":"string","enum":temptoolnames}},"required":["reasoning","final_decision","tool_name"],"additionalProperties":False}
toolquerygrammar = convert_json_to_gbnf(tempjson)
if not is_followup_tool:
custom_tools_prompt = "Is calling one of the tools listed above absolutely essential to answer user's current request, or is a tool call optional?"
custom_tools_prompt_processed = f"{curr_ctx}{last_user_message}\n\n{custom_tools_prompt} {custom_tools_prompt_json_format}{assistant_message_start}"
else:
custom_tools_prompt = "Given the tool call response to the user's current request, is another tool call needed to further answer user's message?"
custom_tools_prompt_processed = f"{curr_ctx}{last_user_message}{tool_call_results}\n\n{custom_tools_prompt} {custom_tools_prompt_json_format}{assistant_message_start}"
# first, prompt to see if a tool call is needed using the prompt above.
# the result is a short explanation by the LLM on why a tool call is or is not needed, along with it's final decision at the end.
temp_poll = {
"prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{assistant_message_start}",
"max_length":5,
"prompt": custom_tools_prompt_processed,
"memory": toolmem,
"max_length":300,
"temperature":0.1,
"top_k":1,
"rep_pen":1,
"ban_eos_token":False,
"grammar":pollgrammar
}
"grammar":toolquerygrammar
}
temp_poll_result = generate(genparams=temp_poll)
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
should_use_tools = False
temp_poll_text = temp_poll_result['text'].strip().rstrip('.')
temp_poll_data_arr = extract_json_from_string(temp_poll_text)
temp_poll_data = temp_poll_data_arr[0] if (temp_poll_data_arr and len(temp_poll_data_arr)>0) else None
if temp_poll_data:
if chosen_tool!="required" and ("yes" not in temp_poll_data.get("final_decision","").lower() or "null" in temp_poll_data.get("tool_name","").lower()):
should_use_tools = False
elif (chosen_tool=="auto" or chosen_tool=="required") and "null" not in temp_poll_data.get("tool_name","").lower():
chosen_tool = temp_poll_data.get("tool_name","").lower().strip()
if not args.quiet:
print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")
print(f"\n[TOOLCALL REASONING]: {temp_poll_text}")
if should_use_tools:
#first, try and extract a specific tool if selected
@ -2567,38 +2615,26 @@ def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_
toolnames = extract_all_names_from_tool_array(tools_array)
if len(toolnames) == 1:
used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
else:
pollgrammar = ""
for name in toolnames:
pollgrammar += ("" if pollgrammar=="" else " | ")
pollgrammar += "\"" + name + "\""
pollgrammar += " | \"no_tool\""
pollgrammar = r'root ::= ' + pollgrammar
decide_tool_prompt = "Which of the listed tools should be used next? Pick exactly one. If no tool is suitable, reply no_tool. (Reply directly with the selected tool's name):"
temp_poll = {
"prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{assistant_message_start}",
"max_length":16,
"temperature":0.1,
"top_k":1,
"rep_pen":1,
"ban_eos_token":False,
"grammar":pollgrammar
}
temp_poll_result = generate(genparams=temp_poll)
if temp_poll_result:
raw = temp_poll_result['text'].lower()
if "no_tool" in raw:
print(f"\nNo suitable tool found.")
else:
for name in toolnames:
if name.lower() in raw:
used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
if not args.quiet:
print(f"\nAttempting to use tool: {name}")
break
return used_tool_json
def compress_tools_array(tools_array):
tools_array_filtered = []
for tool_dict in tools_array:
tool_data = tool_dict['function']
tool_props = {}
params = tool_data.get("parameters", {})
props = params.get("properties", {})
for prop_name, prop_data in props.items():
tool_props[prop_name] = prop_data['type']
tools_array_filtered.append({
"name": tool_data['name'],
"description": tool_data['description'],
"properties": tool_props
})
return tools_array_filtered
def transform_genparams(genparams, api_format, use_jinja):
global chatcompl_adapter, maxctx
@ -2704,7 +2740,7 @@ ws ::= | " " | "\n" [ \t]{0,20}
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
assistant_message_gen = adapter_obj.get("assistant_gen", assistant_message_start)
tools_message_start = adapter_obj.get("tools_start", "\nTool Results:\n")
tools_message_start = adapter_obj.get("tools_start", "")
tools_message_end = adapter_obj.get("tools_end", "")
images_added = []
audio_added = []
@ -2747,6 +2783,13 @@ ws ::= | " " | "\n" [ \t]{0,20}
if jinjatools and len(jinjatools)>0:
genparams["using_openai_tools"] = True
else:
if jinjatools:
# inject the tools list at the top of the context window, even if context has shifted
# uses koboldcpp's special memory parameter
tools_string = f"{system_message_start}### Available Tools:\n{json.dumps(compress_tools_array(jinjatools), indent=0)}{system_message_end}\n"
exist_mem = genparams.get('memory', "")
genparams["memory"] = tools_string + exist_mem
for message in messages_array:
message_index += 1
if message['role'] == "system":
@ -2757,6 +2800,9 @@ ws ::= | " " | "\n" [ \t]{0,20}
messages_string += assistant_message_start
elif message['role'] == "tool":
messages_string += tools_message_start
tcid = message.get("tool_call_id","")
tcid = ("" if not tcid else f" {tcid}")
messages_string += f"\nReceived results of function call{tcid}:\n"
# content can be a string or an array of objects
curr_content = message.get("content",None)
@ -2768,9 +2814,16 @@ ws ::= | " " | "\n" [ \t]{0,20}
if not curr_content:
if "tool_calls" in message:
try:
if len(message.get("tool_calls"))>0:
tcfnname = message.get("tool_calls")[0].get("function").get("name")
messages_string += f"\n(Made a function call to {tcfnname})\n"
nlstart = True
for tc in message.get("tool_calls"):
if nlstart:
nlstart = False
messages_string += "\n"
tcid = tc.get("id","")
tcfnname = tc.get("function").get("name")
tcfnargs = tc.get("function").get("arguments","")
tcfnargs = (f" with arguments={tcfnargs}" if tcfnargs else "")
messages_string += f"(Made a function call {tcid} to {tcfnname}{tcfnargs})\n"
except Exception:
messages_string += "\n(Made a function call)\n"
pass # do nothing
@ -2815,7 +2868,6 @@ ws ::= | " " | "\n" [ \t]{0,20}
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{assistant_message_start}"
if message['role'] == "system":
messages_string += system_message_end
elif message['role'] == "user":
@ -4210,10 +4262,6 @@ Change Mode<br>
is_embeddings = False
response_body = None
use_jinja = args.jinja
if use_jinja and not args.jinja_tools:
tmptools = genparams.get('tools', [])
if tmptools and len(tmptools) > 0:
use_jinja = False # not allowed to use tools with jinja
if self.path.endswith('/api/admin/check_state'):
if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword):
@ -4349,6 +4397,11 @@ Change Mode<br>
if args.debugmode >= 1:
trunc_len = 32000
if use_jinja and not args.jinja_tools:
tmptools = genparams.get('tools', [])
if tmptools and len(tmptools) > 0:
use_jinja = False # not allowed to use tools with jinja
printablegenparams_raw = truncate_long_json(genparams,trunc_len)
utfprint("\nInput: " + json.dumps(printablegenparams_raw,ensure_ascii=False),1)