Mirror of https://github.com/LostRuins/koboldcpp.git
Rudimentary support for OpenAI chat completions tool calls (#981)
* Rudimentary support for OpenAI chat completions tool calls. Most small models are not smart enough to do this, especially a combined tool call + role-play response, but at least this allows experimentation along these lines with koboldcpp.
* Try to also support a specified function, and tool_choice set to none. Allow the tools start and end messages to be configured in the adapter. Try to force the grammar to a specific function call if one is specified (untested).
* Ensure tools get listed right after the user content and before the end of the user message content.
* Omit the grammars approach and try prompting instead: use more extensive JSON parsing and direct instructions to the model to try to obtain the desired result. This seems to work relatively well with Mistral-7B-Instruct-v.0.3.Q4_K_M.gguf and neuralhermes-2.5-mistral-7b.Q4_K_M.gguf. There is a question of whether this approach is too opinionated; should the instructions be something that can be passed with the prompt template?
* Add back the llama.cpp recommended JSON grammar: go back to adding a grammar, but use only the "official" llama.cpp grammar rather than a custom one just for OpenAI.
* Tidy up; remove unnecessary globals.
* Clarity.
* Fix a missing local variable error. This worked to fix the error I mentioned in my last comment.

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent 5caf5f9770
commit c08309e773
1 changed file with 107 additions and 18 deletions
koboldcpp.py: 125 lines changed (+107, -18)
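Before the diff itself, a minimal client-side sketch of the kind of request this change starts to accept. It assumes a local koboldcpp instance exposing the OpenAI-compatible /v1/chat/completions route on the default port 5001; the get_weather tool and all field values are hypothetical.

```python
import requests  # assumes the "requests" package and a running local koboldcpp instance

payload = {
    "model": "koboldcpp",
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    # Hypothetical tool definition in the OpenAI "tools" shape.
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    # Per the guard in the diff below, a JSON-null tool_choice disables the injection.
    "tool_choice": "auto",
}

r = requests.post("http://localhost:5001/v1/chat/completions", json=payload)
print(r.json()["choices"][0]["message"])
```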
@@ -829,6 +829,33 @@ currfinishreason = "null"
using_gui_launcher = False
using_outdated_flags = False

# Used to parse json for openai tool calls
def extract_json_from_string(input_string):
    parsed_json = None
    try: # First check if model exported perfect json
        parsed_json = json.loads(input_string)
        return parsed_json
    except Exception as e:
        pass
    try: # Next check if all we need is to add brackets to make it perfect json
        parsed_json = json.loads(f"[{input_string}]")
        return parsed_json
    except Exception as e:
        pass
    try:
        # Now use a regular expression to match JSON objects or arrays, in case part is valid json and part is not
        json_pattern = r'(\{.*?\}|\[.*?\])' # was json_pattern = r'(\{.*\}|\[.*\])'
        potential_jsons = re.findall(json_pattern, input_string, re.DOTALL)
        for potential_json in potential_jsons:
            try:
                parsed_json = json.loads(potential_json)
                return parsed_json
            except Exception as e:
                continue
    except Exception as e:
        pass
    return []

def transform_genparams(genparams, api_format):
    #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
    #alias all nonstandard alternative names for rep pen.
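A quick sketch of what each fallback stage of the helper recovers. The function body below is a condensed but behavior-identical restatement of extract_json_from_string above, so the snippet runs standalone; the sample strings are invented.

```python
import json
import re

def extract_json_from_string(input_string):
    # Condensed restatement of the helper above: same three stages.
    try:  # 1) the whole string is already valid JSON
        return json.loads(input_string)
    except Exception:
        pass
    try:  # 2) comma-separated objects that only need wrapping brackets
        return json.loads(f"[{input_string}]")
    except Exception:
        pass
    try:  # 3) salvage the first parseable object/array embedded in prose
        for candidate in re.findall(r'(\{.*?\}|\[.*?\])', input_string, re.DOTALL):
            try:
                return json.loads(candidate)
            except Exception:
                continue
    except Exception:
        pass
    return []

print(extract_json_from_string('[{"id": "call_1", "type": "function"}]'))   # stage 1 -> list
print(extract_json_from_string('{"id": "a"}, {"id": "b"}'))                 # stage 2 -> list of two
print(extract_json_from_string('Sure! {"id": "call_1"} Hope that helps.'))  # stage 3 -> dict
print(extract_json_from_string('no json at all'))                           # -> []
```

Note that stages 1 and 3 can return a bare dict rather than a list; the caller later in this diff only checks truthiness and len(), which works for both.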
@@ -873,15 +900,21 @@ def transform_genparams(genparams, api_format):
user_message_end = adapter_obj.get("user_end", "")
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
tools_message_start = adapter_obj.get("tools_start", "")
tools_message_end = adapter_obj.get("tools_end", "")
images_added = []

message_index = 0
for message in messages_array:
    message_index += 1
    if message['role'] == "system":
        messages_string += system_message_start
    elif message['role'] == "user":
        messages_string += user_message_start
    elif message['role'] == "assistant":
        messages_string += assistant_message_start
    elif message['role'] == "tool":
        messages_string += tools_message_start

    # content can be a string or an array of objects
    curr_content = message['content']
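For reference, a hypothetical ChatML-flavored adapter showing where the two new keys fit. Every value is illustrative; any key left out falls back to the defaults visible in the code above.

```python
# Hypothetical adapter dict (normally loaded from the JSON file passed via
# --chatcompletionsadapter); only "tools_start"/"tools_end" are new here.
adapter_obj = {
    "system_start": "<|im_start|>system\n",
    "system_end": "<|im_end|>\n",
    "user_start": "<|im_start|>user\n",
    "user_end": "<|im_end|>\n",
    "assistant_start": "<|im_start|>assistant\n",
    "assistant_end": "<|im_end|>\n",
    # New in this commit: wrappers for messages with role "tool".
    "tools_start": "<|im_start|>tool\n",
    "tools_end": "<|im_end|>\n",
}
```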
@@ -894,13 +927,64 @@ def transform_genparams(genparams, api_format):
            elif item['type']=="image_url":
                if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
                    images_added.append(item['image_url']['url'].split(",", 1)[1])
    # If this is the last message, add any tool calls after the message content and before the message end token
    if message['role'] == "user" and message_index == len(messages_array):
        # Check if the user is passing an openai tools array; if so, add it to the end of the prompt before the assistant prompt, unless tool_choice has been set to None
        tools_array = genparams.get('tools', [])
        if tools_array and len(tools_array) > 0 and genparams.get('tool_choice',None) != None:
            response_array = [{"id": "insert an id for the response", "type": "function", "function": {"name": "insert the name of the function you want to call", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}]
            json_formatting_instruction = " Use this style of JSON object formatting to give your answer if you think the user is asking you to perform an action: " + json.dumps(response_array, indent=0)
            tools_string = json.dumps(tools_array, indent=0)
            messages_string += tools_string
            specified_function = None
            if isinstance(genparams.get('tool_choice'), dict):
                try:
                    specified_function = genparams.get('tool_choice').get('function').get('name')
                    json_formatting_instruction = f"The user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
                except Exception as e:
                    # In case of any issues, just revert back to no specified function
                    pass
            messages_string += json_formatting_instruction

            # Set temperature low automatically if function calling
            genparams["temperature"] = 0.2
            genparams["using_openai_tools"] = True

            # Set grammar to the llamacpp example grammar to force a json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
            genparams["grammar"] = r"""
root   ::= arr
value  ::= object | array | string | number | ("true" | "false" | "null") ws
arr    ::=
  "[\n" ws (
        value
    (",\n" ws value)*
  )? "]"
object ::=
  "{" ws (
        string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws
array  ::=
  "[" ws (
        value
    ("," ws value)*
  )? "]" ws
string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
  )* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
ws ::= | " " | "\n" [ \t]{0,20}
"""
    if message['role'] == "system":
        messages_string += system_message_end
    elif message['role'] == "user":
        messages_string += user_message_end
    elif message['role'] == "assistant":
        messages_string += assistant_message_end
    elif message['role'] == "tool":
        messages_string += tools_message_end

messages_string += assistant_message_start
genparams["prompt"] = messages_string
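To exercise the specified-function branch above, a client sends tool_choice as an object rather than a string. A hypothetical request fragment (note that the guard shown above skips the injection only when tool_choice is absent or JSON null):

```python
# Hypothetical genparams fragment steering the model toward one function.
genparams = {
    "messages": [{"role": "user", "content": "Weather in Paris, please."}],
    "tools": [{"type": "function", "function": {"name": "get_weather"}}],
    "tool_choice": {"type": "function", "function": {"name": "get_weather"}},
}

# The branch above would then derive:
specified_function = genparams["tool_choice"]["function"]["name"]  # "get_weather"
# ...and rewrite json_formatting_instruction to demand exactly that name,
# while the appended grammar constrains the raw output to a JSON array.
```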
@@ -913,6 +997,7 @@ def transform_genparams(genparams, api_format):
    genparams["stop_sequence"].append(assistant_message_start.strip())
    genparams["trim_stop"] = True


elif api_format==5:
    firstimg = genparams.get('image', "")
    genparams["images"] = [firstimg]
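The stop-sequence line is what keeps a response from running on into a fabricated next turn; a minimal illustration using the default adapter value seen earlier (the pre-existing contents of stop_sequence are assumed):

```python
# assistant_message_start defaults to "\n### Response:\n" (see the adapter
# defaults above); its stripped form becomes a stop string, so generation
# halts if the model tries to open another assistant turn on its own.
stop_sequence = ["### Instruction:"]  # assumed pre-existing entry
stop_sequence.append("\n### Response:\n".strip())
assert stop_sequence[-1] == "### Response:"
```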
@@ -963,9 +1048,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
is_quiet = args.quiet
currfinishreason = "null"

def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
    # flag instance as non-idle for a while
    washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
    if not washordereq:
        global last_non_horde_req_time
@@ -1013,9 +1097,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
    render_special=genparams.get('render_special', False),
    banned_tokens=genparams.get('banned_tokens', []),
    bypass_eos_token=genparams.get('bypass_eos', False),
    )

genout = {"text": "", "status": -1, "stopreason": -1}
if stream_flag:
    loop = asyncio.get_event_loop()
    executor = ThreadPoolExecutor()
@@ -1024,9 +1108,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
genout = run_blocking()

recvtxt = genout['text']
currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")

# flag instance as non-idle for a while
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
if not washordereq:
    global last_non_horde_req_time
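For readers skimming the handler, the finish_reason mapping introduced here is a one-liner; a sketch of the same logic (the meaning of stopreason == 1 is inferred from context):

```python
# Restatement of the mapping above: stopreason == 1 appears to mean the model
# hit EOS or a stop sequence ("stop"); anything else is reported as a token
# budget cutoff ("length").
def finish_reason(stopreason: int) -> str:
    return "stop" if stopreason == 1 else "length"

assert finish_reason(1) == "stop"
assert finish_reason(0) == "length"
```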
@@ -1035,27 +1119,32 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
    utfprint("\nOutput: " + recvtxt)

if api_format == 1:
    res = {"data": {"seqs": [recvtxt]}}
elif api_format == 3:
    res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
           "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
           "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
elif api_format == 4:
    using_openai_tools = genparams.get('using_openai_tools', False)
    tool_calls = []
    if using_openai_tools:
        tool_calls = extract_json_from_string(recvtxt)
        if tool_calls and len(tool_calls)>0:
            recvtxt = None
    res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
           "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
           "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
elif api_format == 5:
    res = {"caption": end_trim_to_sentence(recvtxt)}
else:
    res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}

try:
    return res
except Exception as e:
    print(f"Generate: Error while generating: {e}")


async def send_oai_sse_event(self, data):
    if data=="[DONE]":
        self.wfile.write(f'data: {data}'.encode())
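Putting the pieces together, once extract_json_from_string succeeds, an api_format==4 response might look like the sketch below; all values are illustrative. One deviation from the OpenAI spec worth noting: function.arguments arrives here as a JSON object, since that is what the injected instruction asks the model to produce, whereas OpenAI proper encodes it as a JSON string.

```python
# Hypothetical chat.completion payload after tool calls were extracted;
# "content" is None because recvtxt was blanked in the branch above.
response = {
    "id": "chatcmpl-1",
    "object": "chat.completion",
    "created": 1,
    "model": "koboldcpp/MyModel",  # illustrative friendlymodelname
    "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
    "choices": [{
        "index": 0,
        "message": {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": "call_1",  # the model invents this, per the injected instruction
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": {"city": "Paris"},  # object here, string in OpenAI proper
                },
            }],
        },
        "finish_reason": "stop",
    }],
}
```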