From c08309e7737d59ca5b3d08037de2e60bc669d2fb Mon Sep 17 00:00:00 2001 From: teddybear082 <87204721+teddybear082@users.noreply.github.com> Date: Sat, 13 Jul 2024 23:22:45 -0400 Subject: [PATCH] Rudimentary support of openai chat completions tools calls (#981) * Rudimentary support of openai chat completions tools calls -Most small models are not smart enough to do this, especially a combined tool call + role play response, but at least this allows experimentation along these lines with koboldcpp * try to also support specified function and tool choice set to none Allow tools start and end messages to be configured in adapter Try to force grammar to specific function call if specified (untested) * ensure tools get listed right after user content and before end of user message content * omit grammars approach try prompting instead -use more extensive json parsing and direct instructions to models to try to obtain the desired result -seems to work relatively well with Mistral-7B-Instruct-v.0.3.Q4_K_M.gguf and neuralhermes-2.5-mistral-7b.Q4_K_M.gguf -question of whether this is too opinionated of an approach, should the instructions be things that can be passed with the prompt template? * add back llamacpp recommended json grammar Go back to adding grammar but use "official" llamacpp grammar only not a custom one just for openai * Tidy up, remove unnecessary globals * clarity * fix missing local variable error This worked to fix the error I mentioned on my last comment --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com> --- koboldcpp.py | 125 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 107 insertions(+), 18 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 5a8ddde29..a5a915d34 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -829,6 +829,33 @@ currfinishreason = "null" using_gui_launcher = False using_outdated_flags = False +# Used to parse json for openai tool calls +def extract_json_from_string(input_string): + parsed_json = None + try: # First check if model exported perfect json + parsed_json = json.loads(input_string) + return parsed_json + except Exception as e: + pass + try: # Next check if all we need is to add brackets to make it perfect json + parsed_json = json.loads(f"[{input_string}]") + return parsed_json + except Exception as e: + pass + try: + # Now use regular expression to match JSON objects or arrays in case part is valid json and part is not + json_pattern = r'(\{.*?\}|\[.*?\])' # was json_pattern = r'(\{.*\}|\[.*\])' + potential_jsons = re.findall(json_pattern, input_string, re.DOTALL) + for potential_json in potential_jsons: + try: + parsed_json = json.loads(potential_json) + return parsed_json + except Exception as e: + continue + except Exception as e: + pass + return [] + def transform_genparams(genparams, api_format): #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate #alias all nonstandard alternative names for rep pen. @@ -873,15 +900,21 @@ def transform_genparams(genparams, api_format): user_message_end = adapter_obj.get("user_end", "") assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n") assistant_message_end = adapter_obj.get("assistant_end", "") + tools_message_start = adapter_obj.get("tools_start", "") + tools_message_end = adapter_obj.get("tools_end", "") images_added = [] + message_index = 0 for message in messages_array: + message_index += 1 if message['role'] == "system": messages_string += system_message_start elif message['role'] == "user": messages_string += user_message_start elif message['role'] == "assistant": messages_string += assistant_message_start + elif message['role'] == "tool": + messages_string += tools_message_start # content can be a string or an array of objects curr_content = message['content'] @@ -894,13 +927,64 @@ def transform_genparams(genparams, api_format): elif item['type']=="image_url": if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"): images_added.append(item['image_url']['url'].split(",", 1)[1]) + # If last message, add any tools calls after message content and before message end token if any + if message['role'] == "user" and message_index == len(messages_array): + # Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None + tools_array = genparams.get('tools', []) + if tools_array and len(tools_array) > 0 and genparams.get('tool_choice',None) != None: + response_array = [{"id": "insert an id for the response", "type": "function", "function": {"name": "insert the name of the function you want to call", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}] + json_formatting_instruction = " Use this style of JSON object formatting to give your answer if you think the user is asking you to perform an action: " + json.dumps(response_array, indent=0) + tools_string = json.dumps(tools_array, indent=0) + messages_string += tools_string + specified_function = None + if isinstance(genparams.get('tool_choice'), dict): + try: + specified_function = genparams.get('tool_choice').get('function').get('name') + json_formatting_instruction = f"The user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0) + except Exception as e: + # In case of any issues, just revert back to no specified function + pass + messages_string += json_formatting_instruction + # Set temperature low automatically if function calling + genparams["temperature"] = 0.2 + genparams["using_openai_tools"] = True + + # Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf) + genparams["grammar"] = r""" +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws +string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) + )* "\"" ws +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws +ws ::= | " " | "\n" [ \t]{0,20} +""" if message['role'] == "system": messages_string += system_message_end elif message['role'] == "user": messages_string += user_message_end elif message['role'] == "assistant": messages_string += assistant_message_end + elif message['role'] == "tool": + messages_string += tools_message_end messages_string += assistant_message_start genparams["prompt"] = messages_string @@ -913,6 +997,7 @@ def transform_genparams(genparams, api_format): genparams["stop_sequence"].append(assistant_message_start.strip()) genparams["trim_stop"] = True + elif api_format==5: firstimg = genparams.get('image', "") genparams["images"] = [firstimg] @@ -963,9 +1048,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): is_quiet = args.quiet currfinishreason = "null" - def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat - - #flag instance as non-idle for a while + def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat + # flag instance as non-idle for a while washordereq = genparams.get('genkey', '').startswith('HORDEREQ_') if not washordereq: global last_non_horde_req_time @@ -1013,9 +1097,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): render_special=genparams.get('render_special', False), banned_tokens=genparams.get('banned_tokens', []), bypass_eos_token=genparams.get('bypass_eos', False), - ) + ) - genout = {"text":"","status":-1,"stopreason":-1} + genout = {"text": "", "status": -1, "stopreason": -1} if stream_flag: loop = asyncio.get_event_loop() executor = ThreadPoolExecutor() @@ -1024,9 +1108,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): genout = run_blocking() recvtxt = genout['text'] - currfinishreason = ("length" if (genout['stopreason']!=1) else "stop") + currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop") - #flag instance as non-idle for a while + # flag instance as non-idle for a while washordereq = genparams.get('genkey', '').startswith('HORDEREQ_') if not washordereq: global last_non_horde_req_time @@ -1035,27 +1119,32 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1: utfprint("\nOutput: " + recvtxt) - if api_format==1: - res = {"data": {"seqs":[recvtxt]}} - elif api_format==3: + if api_format == 1: + res = {"data": {"seqs": [recvtxt]}} + elif api_format == 3: res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname, - "usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200}, - "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]} - elif api_format==4: + "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200}, + "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]} + elif api_format == 4: + using_openai_tools = genparams.get('using_openai_tools', False) + tool_calls = [] + if using_openai_tools: + tool_calls = extract_json_from_string(recvtxt) + if tool_calls and len(tool_calls)>0: + recvtxt = None res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname, - "usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200}, - "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": currfinishreason}]} - elif api_format==5: + "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200}, + "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]} + elif api_format == 5: res = {"caption": end_trim_to_sentence(recvtxt)} else: - res = {"results": [{"text": recvtxt, "finish_reason":currfinishreason}]} + res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]} try: return res except Exception as e: print(f"Generate: Error while generating: {e}") - async def send_oai_sse_event(self, data): if data=="[DONE]": self.wfile.write(f'data: {data}'.encode())