mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-06-02 07:19:23 +00:00
improvements to tool calling logic (merged changes from old PR branch) (#1855)
* improvements to tool calling logic (merged changes from old PR branch) * added some tweaks for improved tool calls to reuse old ctx, but needs testing. refer to PR. * fixes to some stuff that concedo's modifications broke * fixed error in reasoning * extremely hacky way to cache tool list please fix * oops forgot to add this * slightly less hacky way to preserve the tool list in context * prevented unintended toolcalls from happening when LLM states something irrelevant to toolcall decision * fixed something that broke koboldlite * fixed bug added by concedo that broke jinja tools * experimental further compression of tools array, needs testing * reverted experimental further compression of tools array * final cleanup * add newline after memory insert * changed tool reasoning to always be in json format to enforce including final decision * used new json format to skip extra llm call when not necessary * more catching of possible bad llm output * further cleanup * got it down to just one llm call! * better json format * even better json format * further refinement to json format * further refinement to json format * fixed broken tool calling * single-call enforced json method now seems to work well. removed fallbacks as they are no longer required. --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent
2ef03a824e
commit
eeb7363985
1 changed files with 105 additions and 52 deletions
157
koboldcpp.py
Normal file → Executable file
157
koboldcpp.py
Normal file → Executable file
|
|
@ -2527,34 +2527,82 @@ def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_
|
|||
# tools handling: Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
|
||||
tools_array = genparams.get('tools', [])
|
||||
chosen_tool = genparams.get('tool_choice', "auto")
|
||||
messages = genparams.get('messages',[])
|
||||
toolmem = genparams.get("memory","")
|
||||
|
||||
# first handle auto mode, determine whether a tool is needed
|
||||
used_tool_json = None
|
||||
if not curr_ctx:
|
||||
return None
|
||||
|
||||
# get user's last message and last tool call results
|
||||
last_user_message = ""
|
||||
tool_call_results = ""
|
||||
|
||||
if messages:
|
||||
reversed_messages = list(reversed(messages))
|
||||
for message in reversed_messages:
|
||||
if message["role"] == "user":
|
||||
last_user_message = message["content"]
|
||||
last_user_message = f"\n\nUser's current request: {last_user_message}"
|
||||
break
|
||||
tool_call_chunk = []
|
||||
for message in reversed_messages:
|
||||
if message["role"] == "tool":
|
||||
tool_call_chunk.append(message["content"])
|
||||
else:
|
||||
break
|
||||
tmp_tool_replies = list(reversed(tool_call_chunk))
|
||||
if tmp_tool_replies and len(tmp_tool_replies)>0:
|
||||
tool_call_results = f"\n\nTool call responses: {tmp_tool_replies}"
|
||||
|
||||
if tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool!="none":
|
||||
tools_string = json.dumps(tools_array, indent=0)
|
||||
should_use_tools = True
|
||||
if chosen_tool=="auto":
|
||||
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
|
||||
custom_tools_prompt = "Can the user query be answered by a listed tool above? (One word response: yes or no):"
|
||||
if is_followup_tool:
|
||||
custom_tools_prompt = "Can the user query be further answered by another listed tool above? (If response is already complete, reply NO) (One word response: yes or no):"
|
||||
if chosen_tool=="auto" or chosen_tool=="required":
|
||||
# note: message string already contains the instruct start tag!
|
||||
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
|
||||
temptoolnames = extract_all_names_from_tool_array(tools_array)
|
||||
tempjson = {}
|
||||
if chosen_tool=="required":
|
||||
custom_tools_prompt_json_format = "Respond with a JSON object using this structure:\r\n{\r\n \"tool_name\": \"exact_tool_name_here\"\r\n}\r\n\r\nRules:\r\n- You must pick one of the tools to use, pick the most suitable tool."
|
||||
tempjson = {"type":"object","properties":{"tool_name":{"type":"string","enum":temptoolnames}},"required":["tool_name"],"additionalProperties":False}
|
||||
else:
|
||||
temptoolnames.append("null")
|
||||
custom_tools_prompt_json_format = "Respond with a JSON object using this structure:\r\n{\r\n \"reasoning\": \"Your reasoning here\",\r\n \"final_decision\": \"yes\" or \"no\",\r\n \"tool_name\": \"exact_tool_name_here\" or \"null\"\r\n}\r\n\r\nRules:\r\n- Output only the JSON object. Do NOT add anything before or after the json object.\r\n- final_decision must be exactly \"yes\" or \"no\"\r\n- tool_name must be either an exact tool name, or if no tool is required, an empty string: \"\"\r\n- Keep reasoning short, maximum one or two sentences.\r\n- No unnecessary comments"
|
||||
tempjson = {"type":"object","properties":{"reasoning":{"type":"string"},"final_decision":{"type":"string","enum":["yes","no","Yes","No","YES","NO"," yes"," no"," Yes"," No"," YES"," NO"]},"tool_name":{"type":"string","enum":temptoolnames}},"required":["reasoning","final_decision","tool_name"],"additionalProperties":False}
|
||||
toolquerygrammar = convert_json_to_gbnf(tempjson)
|
||||
|
||||
if not is_followup_tool:
|
||||
custom_tools_prompt = "Is calling one of the tools listed above absolutely essential to answer user's current request, or is a tool call optional?"
|
||||
custom_tools_prompt_processed = f"{curr_ctx}{last_user_message}\n\n{custom_tools_prompt} {custom_tools_prompt_json_format}{assistant_message_start}"
|
||||
else:
|
||||
custom_tools_prompt = "Given the tool call response to the user's current request, is another tool call needed to further answer user's message?"
|
||||
custom_tools_prompt_processed = f"{curr_ctx}{last_user_message}{tool_call_results}\n\n{custom_tools_prompt} {custom_tools_prompt_json_format}{assistant_message_start}"
|
||||
|
||||
# first, prompt to see if a tool call is needed using the prompt above.
|
||||
# the result is a short explanation by the LLM on why a tool call is or is not needed, along with it's final decision at the end.
|
||||
temp_poll = {
|
||||
"prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{assistant_message_start}",
|
||||
"max_length":5,
|
||||
"prompt": custom_tools_prompt_processed,
|
||||
"memory": toolmem,
|
||||
"max_length":300,
|
||||
"temperature":0.1,
|
||||
"top_k":1,
|
||||
"rep_pen":1,
|
||||
"ban_eos_token":False,
|
||||
"grammar":pollgrammar
|
||||
}
|
||||
"grammar":toolquerygrammar
|
||||
}
|
||||
temp_poll_result = generate(genparams=temp_poll)
|
||||
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
|
||||
should_use_tools = False
|
||||
temp_poll_text = temp_poll_result['text'].strip().rstrip('.')
|
||||
temp_poll_data_arr = extract_json_from_string(temp_poll_text)
|
||||
temp_poll_data = temp_poll_data_arr[0] if (temp_poll_data_arr and len(temp_poll_data_arr)>0) else None
|
||||
|
||||
if temp_poll_data:
|
||||
if chosen_tool!="required" and ("yes" not in temp_poll_data.get("final_decision","").lower() or "null" in temp_poll_data.get("tool_name","").lower()):
|
||||
should_use_tools = False
|
||||
elif (chosen_tool=="auto" or chosen_tool=="required") and "null" not in temp_poll_data.get("tool_name","").lower():
|
||||
chosen_tool = temp_poll_data.get("tool_name","").lower().strip()
|
||||
|
||||
if not args.quiet:
|
||||
print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")
|
||||
print(f"\n[TOOLCALL REASONING]: {temp_poll_text}")
|
||||
|
||||
if should_use_tools:
|
||||
#first, try and extract a specific tool if selected
|
||||
|
|
@ -2567,38 +2615,26 @@ def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_
|
|||
toolnames = extract_all_names_from_tool_array(tools_array)
|
||||
if len(toolnames) == 1:
|
||||
used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
|
||||
else:
|
||||
pollgrammar = ""
|
||||
for name in toolnames:
|
||||
pollgrammar += ("" if pollgrammar=="" else " | ")
|
||||
pollgrammar += "\"" + name + "\""
|
||||
pollgrammar += " | \"no_tool\""
|
||||
pollgrammar = r'root ::= ' + pollgrammar
|
||||
decide_tool_prompt = "Which of the listed tools should be used next? Pick exactly one. If no tool is suitable, reply no_tool. (Reply directly with the selected tool's name):"
|
||||
temp_poll = {
|
||||
"prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{assistant_message_start}",
|
||||
"max_length":16,
|
||||
"temperature":0.1,
|
||||
"top_k":1,
|
||||
"rep_pen":1,
|
||||
"ban_eos_token":False,
|
||||
"grammar":pollgrammar
|
||||
}
|
||||
temp_poll_result = generate(genparams=temp_poll)
|
||||
if temp_poll_result:
|
||||
raw = temp_poll_result['text'].lower()
|
||||
if "no_tool" in raw:
|
||||
print(f"\nNo suitable tool found.")
|
||||
else:
|
||||
for name in toolnames:
|
||||
if name.lower() in raw:
|
||||
used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
|
||||
if not args.quiet:
|
||||
print(f"\nAttempting to use tool: {name}")
|
||||
break
|
||||
|
||||
return used_tool_json
|
||||
|
||||
def compress_tools_array(tools_array):
|
||||
tools_array_filtered = []
|
||||
for tool_dict in tools_array:
|
||||
tool_data = tool_dict['function']
|
||||
tool_props = {}
|
||||
params = tool_data.get("parameters", {})
|
||||
props = params.get("properties", {})
|
||||
for prop_name, prop_data in props.items():
|
||||
tool_props[prop_name] = prop_data['type']
|
||||
tools_array_filtered.append({
|
||||
"name": tool_data['name'],
|
||||
"description": tool_data['description'],
|
||||
"properties": tool_props
|
||||
})
|
||||
|
||||
return tools_array_filtered
|
||||
|
||||
def transform_genparams(genparams, api_format, use_jinja):
|
||||
global chatcompl_adapter, maxctx
|
||||
|
||||
|
|
@ -2704,7 +2740,7 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
|
||||
assistant_message_end = adapter_obj.get("assistant_end", "")
|
||||
assistant_message_gen = adapter_obj.get("assistant_gen", assistant_message_start)
|
||||
tools_message_start = adapter_obj.get("tools_start", "\nTool Results:\n")
|
||||
tools_message_start = adapter_obj.get("tools_start", "")
|
||||
tools_message_end = adapter_obj.get("tools_end", "")
|
||||
images_added = []
|
||||
audio_added = []
|
||||
|
|
@ -2747,6 +2783,13 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
if jinjatools and len(jinjatools)>0:
|
||||
genparams["using_openai_tools"] = True
|
||||
else:
|
||||
if jinjatools:
|
||||
# inject the tools list at the top of the context window, even if context has shifted
|
||||
# uses koboldcpp's special memory parameter
|
||||
tools_string = f"{system_message_start}### Available Tools:\n{json.dumps(compress_tools_array(jinjatools), indent=0)}{system_message_end}\n"
|
||||
exist_mem = genparams.get('memory', "")
|
||||
genparams["memory"] = tools_string + exist_mem
|
||||
|
||||
for message in messages_array:
|
||||
message_index += 1
|
||||
if message['role'] == "system":
|
||||
|
|
@ -2757,6 +2800,9 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
messages_string += assistant_message_start
|
||||
elif message['role'] == "tool":
|
||||
messages_string += tools_message_start
|
||||
tcid = message.get("tool_call_id","")
|
||||
tcid = ("" if not tcid else f" {tcid}")
|
||||
messages_string += f"\nReceived results of function call{tcid}:\n"
|
||||
|
||||
# content can be a string or an array of objects
|
||||
curr_content = message.get("content",None)
|
||||
|
|
@ -2768,9 +2814,16 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
if not curr_content:
|
||||
if "tool_calls" in message:
|
||||
try:
|
||||
if len(message.get("tool_calls"))>0:
|
||||
tcfnname = message.get("tool_calls")[0].get("function").get("name")
|
||||
messages_string += f"\n(Made a function call to {tcfnname})\n"
|
||||
nlstart = True
|
||||
for tc in message.get("tool_calls"):
|
||||
if nlstart:
|
||||
nlstart = False
|
||||
messages_string += "\n"
|
||||
tcid = tc.get("id","")
|
||||
tcfnname = tc.get("function").get("name")
|
||||
tcfnargs = tc.get("function").get("arguments","")
|
||||
tcfnargs = (f" with arguments={tcfnargs}" if tcfnargs else "")
|
||||
messages_string += f"(Made a function call {tcid} to {tcfnname}{tcfnargs})\n"
|
||||
except Exception:
|
||||
messages_string += "\n(Made a function call)\n"
|
||||
pass # do nothing
|
||||
|
|
@ -2815,7 +2868,6 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
|
||||
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{assistant_message_start}"
|
||||
|
||||
|
||||
if message['role'] == "system":
|
||||
messages_string += system_message_end
|
||||
elif message['role'] == "user":
|
||||
|
|
@ -4210,10 +4262,6 @@ Change Mode<br>
|
|||
is_embeddings = False
|
||||
response_body = None
|
||||
use_jinja = args.jinja
|
||||
if use_jinja and not args.jinja_tools:
|
||||
tmptools = genparams.get('tools', [])
|
||||
if tmptools and len(tmptools) > 0:
|
||||
use_jinja = False # not allowed to use tools with jinja
|
||||
|
||||
if self.path.endswith('/api/admin/check_state'):
|
||||
if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword):
|
||||
|
|
@ -4349,6 +4397,11 @@ Change Mode<br>
|
|||
if args.debugmode >= 1:
|
||||
trunc_len = 32000
|
||||
|
||||
if use_jinja and not args.jinja_tools:
|
||||
tmptools = genparams.get('tools', [])
|
||||
if tmptools and len(tmptools) > 0:
|
||||
use_jinja = False # not allowed to use tools with jinja
|
||||
|
||||
printablegenparams_raw = truncate_long_json(genparams,trunc_len)
|
||||
utfprint("\nInput: " + json.dumps(printablegenparams_raw,ensure_ascii=False),1)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue