mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
massively improved tool calling
This commit is contained in:
parent
c4df151298
commit
748dfcc2e4
1 changed files with 94 additions and 35 deletions
129
koboldcpp.py
129
koboldcpp.py
|
@ -1982,6 +1982,8 @@ def extract_json_from_string(input_string):
|
|||
parsed_json = None
|
||||
try: # First check if model exported perfect json
|
||||
parsed_json = json.loads(input_string)
|
||||
if not isinstance(parsed_json, list):
|
||||
parsed_json = [parsed_json]
|
||||
return parsed_json
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -1997,6 +1999,8 @@ def extract_json_from_string(input_string):
|
|||
for potential_json in potential_jsons:
|
||||
try:
|
||||
parsed_json = json.loads(potential_json)
|
||||
if not isinstance(parsed_json, list):
|
||||
parsed_json = [parsed_json]
|
||||
return parsed_json
|
||||
except Exception:
|
||||
continue
|
||||
|
@ -2039,6 +2043,35 @@ def parse_last_logprobs(lastlogprobs):
|
|||
logprobsdict['content'].append(lp_content_item)
|
||||
return logprobsdict
|
||||
|
||||
def extract_tool_info_from_tool_array(chosen_tool, tools_array):
|
||||
found_function = ""
|
||||
found_tooljson = None
|
||||
try:
|
||||
if isinstance(chosen_tool, str):
|
||||
found_function = chosen_tool
|
||||
elif isinstance(chosen_tool, dict): #if we can match the tool name, we must use that tool, remove all other tools
|
||||
found_function = chosen_tool.get('function').get('name')
|
||||
#if we find the function in tools, remove all other tools except the one matching the function name
|
||||
for tool in tools_array:
|
||||
if found_function and tool.get('type') == "function" and tool.get('function').get('name').lower() == found_function.lower():
|
||||
found_tooljson = tool
|
||||
break
|
||||
except Exception:
|
||||
# In case of any issues, just revert back to no specified function
|
||||
print("Tools parsing not valid - discarded")
|
||||
pass
|
||||
return found_tooljson
|
||||
|
||||
def extract_all_names_from_tool_array(tools_array):
|
||||
toolnames = []
|
||||
for tool in tools_array:
|
||||
try:
|
||||
if tool.get('type') == "function" and tool.get('function').get('name'):
|
||||
toolnames.append(tool.get('function').get('name'))
|
||||
except Exception:
|
||||
pass
|
||||
return toolnames
|
||||
|
||||
def transform_genparams(genparams, api_format):
|
||||
global chatcompl_adapter, maxctx
|
||||
|
||||
|
@ -2120,32 +2153,6 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0
|
|||
ws ::= | " " | "\n" [ \t]{0,20}
|
||||
"""
|
||||
|
||||
# tools handling
|
||||
tools_array = genparams.get('tools', [])
|
||||
chosen_tool = genparams.get('tool_choice', "auto")
|
||||
tool_json_formatting_instruction = "\nUse this style of JSON object formatting to give your answer if you think the user is asking you to perform an action: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": "insert the name of the function you want to call", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
|
||||
if tools_array and len(tools_array) > 0 and chosen_tool is not None:
|
||||
try:
|
||||
specified_function = ""
|
||||
if isinstance(chosen_tool, str):
|
||||
specified_function = chosen_tool
|
||||
elif isinstance(chosen_tool, dict): #if we can match the tool name, we must use that tool, remove all other tools
|
||||
specified_function = chosen_tool.get('function').get('name')
|
||||
located_tooljson = None
|
||||
#if we find the function in tools, remove all other tools except the one matching the function name
|
||||
for tool in tools_array:
|
||||
if specified_function and tool.get('type') == "function" and tool.get('function').get('name') == specified_function:
|
||||
located_tooljson = tool
|
||||
break
|
||||
if located_tooljson:
|
||||
tools_array = []
|
||||
tools_array.append(located_tooljson)
|
||||
tool_json_formatting_instruction = f"\nThe user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
|
||||
except Exception:
|
||||
# In case of any issues, just revert back to no specified function
|
||||
print("Tools parsing not valid - discarded")
|
||||
pass
|
||||
|
||||
# handle structured outputs
|
||||
respformat = genparams.get('response_format', None)
|
||||
if respformat:
|
||||
|
@ -2191,9 +2198,11 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
messages_string += "\n(Attached Image)\n"
|
||||
# If last message, add any tools calls after message content and before message end token if any
|
||||
if message['role'] == "user" and message_index == len(messages_array):
|
||||
# Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
|
||||
# tools handling: Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
|
||||
tools_array = genparams.get('tools', [])
|
||||
chosen_tool = genparams.get('tool_choice', "auto")
|
||||
# first handle auto mode, determine whether a tool is needed
|
||||
if tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool!="none":
|
||||
#if auto mode, determine whether a tool is needed
|
||||
tools_string = json.dumps(tools_array, indent=0)
|
||||
should_use_tools = True
|
||||
user_end = assistant_message_start
|
||||
|
@ -2218,15 +2227,64 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")
|
||||
|
||||
if should_use_tools:
|
||||
messages_string += tools_string
|
||||
messages_string += tool_json_formatting_instruction
|
||||
#first, try and extract a specific tool if selected
|
||||
used_tool_json = extract_tool_info_from_tool_array(chosen_tool, tools_array)
|
||||
if used_tool_json: #already found the tool we want, remove all others
|
||||
pass
|
||||
elif len(tools_array)==1:
|
||||
used_tool_json = tools_array[0]
|
||||
else: # we have to find the tool we want the old fashioned way
|
||||
toolnames = extract_all_names_from_tool_array(tools_array)
|
||||
if len(toolnames) == 1:
|
||||
used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
|
||||
else:
|
||||
pollgrammar = ""
|
||||
for name in toolnames:
|
||||
pollgrammar += ("" if pollgrammar=="" else " | ")
|
||||
pollgrammar += "\"" + name + "\""
|
||||
pollgrammar = r'root ::= ' + pollgrammar
|
||||
decide_tool_prompt = "Which of the listed tools should be used? Pick exactly one. (Reply directly with the selected tool's name):"
|
||||
temp_poll = {
|
||||
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{user_end}",
|
||||
"max_length":8,
|
||||
"temperature":0.1,
|
||||
"top_k":1,
|
||||
"rep_pen":1,
|
||||
"ban_eos_token":False,
|
||||
"grammar":pollgrammar
|
||||
}
|
||||
temp_poll_result = generate(genparams=temp_poll)
|
||||
if temp_poll_result:
|
||||
raw = temp_poll_result['text'].lower()
|
||||
for name in toolnames:
|
||||
if name.lower() in raw:
|
||||
used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
|
||||
if not args.quiet:
|
||||
print(f"\nAttempting to use tool: {name}")
|
||||
break
|
||||
|
||||
if used_tool_json:
|
||||
toolparamjson = None
|
||||
toolname = None
|
||||
# Set temperature low automatically if function calling
|
||||
genparams["temperature"] = 0.1
|
||||
genparams["using_openai_tools"] = True
|
||||
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
|
||||
genparams["grammar"] = jsongrammar
|
||||
try:
|
||||
toolname = used_tool_json.get('function').get('name')
|
||||
toolparamjson = used_tool_json.get('function').get('parameters')
|
||||
bettergrammarjson = {"type":"array","items":{"type":"object","properties":{"id":{"type":"string","enum":["call_001"]},"type":{"type":"string","enum":["function"]},"function":{"type":"object","properties":{"name":{"type":"string"},"arguments":{}},"required":["name","arguments"],"additionalProperties":False}},"required":["id","type","function"],"additionalProperties":False}}
|
||||
bettergrammarjson["items"]["properties"]["function"]["properties"]["arguments"] = toolparamjson
|
||||
decoded = convert_json_to_gbnf(bettergrammarjson)
|
||||
if decoded:
|
||||
genparams["grammar"] = decoded
|
||||
except Exception:
|
||||
pass
|
||||
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
|
||||
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{user_end}"
|
||||
|
||||
# Set temperature low automatically if function calling
|
||||
genparams["temperature"] = 0.1
|
||||
genparams["using_openai_tools"] = True
|
||||
|
||||
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
|
||||
genparams["grammar"] = jsongrammar
|
||||
if message['role'] == "system":
|
||||
messages_string += system_message_end
|
||||
elif message['role'] == "user":
|
||||
|
@ -2480,6 +2538,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
if tool_calls and len(tool_calls)>0:
|
||||
for tc in tool_calls:
|
||||
tcarg = tc.get("function",{}).get("arguments",None)
|
||||
tc["id"] = f"call_{random.randint(10000, 99999)}"
|
||||
if tcarg and not isinstance(tcarg, str):
|
||||
tc["function"]["arguments"] = json.dumps(tcarg)
|
||||
recvtxt = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue