Rudimentary support of openai chat completions tools calls (#981)

* Rudimentary support of openai chat completions tools calls

- Most small models are not smart enough to do this — especially a combined tool call plus role-play response — but at least this allows experimentation along these lines with koboldcpp

* try to also support specified function and tool choice set to none

Allow tools start and end messages to be configured in adapter

Try to force grammar to specific function call if specified (untested)

* ensure tools get listed right after user content and before end of user message content

* omit grammars approach try prompting instead

-use more extensive json parsing and direct instructions to models to try to obtain the desired result

-seems to work relatively well with Mistral-7B-Instruct-v.0.3.Q4_K_M.gguf and neuralhermes-2.5-mistral-7b.Q4_K_M.gguf

- Open question: is this approach too opinionated? Should the instructions instead be something that can be passed with the prompt template?

* add back llamacpp recommended json grammar

Go back to adding grammar but use "official" llamacpp grammar only not a custom one just for openai

* Tidy up, remove unnecessary globals

* clarity

* fix missing local variable error

This worked to fix the error I mentioned on my last comment

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
teddybear082 2024-07-13 23:22:45 -04:00 committed by GitHub
parent 5caf5f9770
commit c08309e773
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -829,6 +829,33 @@ currfinishreason = "null"
using_gui_launcher = False
using_outdated_flags = False
# Used to parse json for openai tool calls
def extract_json_from_string(input_string):
parsed_json = None
try: # First check if model exported perfect json
parsed_json = json.loads(input_string)
return parsed_json
except Exception as e:
pass
try: # Next check if all we need is to add brackets to make it perfect json
parsed_json = json.loads(f"[{input_string}]")
return parsed_json
except Exception as e:
pass
try:
# Now use regular expression to match JSON objects or arrays in case part is valid json and part is not
json_pattern = r'(\{.*?\}|\[.*?\])' # was json_pattern = r'(\{.*\}|\[.*\])'
potential_jsons = re.findall(json_pattern, input_string, re.DOTALL)
for potential_json in potential_jsons:
try:
parsed_json = json.loads(potential_json)
return parsed_json
except Exception as e:
continue
except Exception as e:
pass
return []
def transform_genparams(genparams, api_format):
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
#alias all nonstandard alternative names for rep pen.
@ -873,15 +900,21 @@ def transform_genparams(genparams, api_format):
user_message_end = adapter_obj.get("user_end", "")
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
tools_message_start = adapter_obj.get("tools_start", "")
tools_message_end = adapter_obj.get("tools_end", "")
images_added = []
message_index = 0
for message in messages_array:
message_index += 1
if message['role'] == "system":
messages_string += system_message_start
elif message['role'] == "user":
messages_string += user_message_start
elif message['role'] == "assistant":
messages_string += assistant_message_start
elif message['role'] == "tool":
messages_string += tools_message_start
# content can be a string or an array of objects
curr_content = message['content']
@ -894,13 +927,64 @@ def transform_genparams(genparams, api_format):
elif item['type']=="image_url":
if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
images_added.append(item['image_url']['url'].split(",", 1)[1])
# If last message, add any tools calls after message content and before message end token if any
if message['role'] == "user" and message_index == len(messages_array):
# Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
tools_array = genparams.get('tools', [])
if tools_array and len(tools_array) > 0 and genparams.get('tool_choice',None) != None:
response_array = [{"id": "insert an id for the response", "type": "function", "function": {"name": "insert the name of the function you want to call", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}]
json_formatting_instruction = " Use this style of JSON object formatting to give your answer if you think the user is asking you to perform an action: " + json.dumps(response_array, indent=0)
tools_string = json.dumps(tools_array, indent=0)
messages_string += tools_string
specified_function = None
if isinstance(genparams.get('tool_choice'), dict):
try:
specified_function = genparams.get('tool_choice').get('function').get('name')
json_formatting_instruction = f"The user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
except Exception as e:
# In case of any issues, just revert back to no specified function
pass
messages_string += json_formatting_instruction
# Set temperature low automatically if function calling
genparams["temperature"] = 0.2
genparams["using_openai_tools"] = True
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
genparams["grammar"] = r"""
root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws
arr ::=
"[\n" ws (
value
(",\n" ws value)*
)? "]"
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
ws ::= | " " | "\n" [ \t]{0,20}
"""
if message['role'] == "system":
messages_string += system_message_end
elif message['role'] == "user":
messages_string += user_message_end
elif message['role'] == "assistant":
messages_string += assistant_message_end
elif message['role'] == "tool":
messages_string += tools_message_end
messages_string += assistant_message_start
genparams["prompt"] = messages_string
@ -913,6 +997,7 @@ def transform_genparams(genparams, api_format):
genparams["stop_sequence"].append(assistant_message_start.strip())
genparams["trim_stop"] = True
elif api_format==5:
firstimg = genparams.get('image', "")
genparams["images"] = [firstimg]
@ -963,9 +1048,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
is_quiet = args.quiet
currfinishreason = "null"
def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
#flag instance as non-idle for a while
def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
# flag instance as non-idle for a while
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
if not washordereq:
global last_non_horde_req_time
@ -1013,9 +1097,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
render_special=genparams.get('render_special', False),
banned_tokens=genparams.get('banned_tokens', []),
bypass_eos_token=genparams.get('bypass_eos', False),
)
)
genout = {"text":"","status":-1,"stopreason":-1}
genout = {"text": "", "status": -1, "stopreason": -1}
if stream_flag:
loop = asyncio.get_event_loop()
executor = ThreadPoolExecutor()
@ -1024,9 +1108,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
genout = run_blocking()
recvtxt = genout['text']
currfinishreason = ("length" if (genout['stopreason']!=1) else "stop")
currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")
#flag instance as non-idle for a while
# flag instance as non-idle for a while
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
if not washordereq:
global last_non_horde_req_time
@ -1035,27 +1119,32 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
utfprint("\nOutput: " + recvtxt)
if api_format==1:
res = {"data": {"seqs":[recvtxt]}}
elif api_format==3:
if api_format == 1:
res = {"data": {"seqs": [recvtxt]}}
elif api_format == 3:
res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
"usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
"choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
elif api_format==4:
"usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
"choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
elif api_format == 4:
using_openai_tools = genparams.get('using_openai_tools', False)
tool_calls = []
if using_openai_tools:
tool_calls = extract_json_from_string(recvtxt)
if tool_calls and len(tool_calls)>0:
recvtxt = None
res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
"usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
"choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": currfinishreason}]}
elif api_format==5:
"usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
"choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
elif api_format == 5:
res = {"caption": end_trim_to_sentence(recvtxt)}
else:
res = {"results": [{"text": recvtxt, "finish_reason":currfinishreason}]}
res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}
try:
return res
except Exception as e:
print(f"Generate: Error while generating: {e}")
async def send_oai_sse_event(self, data):
if data=="[DONE]":
self.wfile.write(f'data: {data}'.encode())