From 3e8b84b8e5ff3dae7d957b15b5e7aa528aa61fbd Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Tue, 22 Apr 2025 22:23:36 +0800 Subject: [PATCH] added support for structured output in chat completions --- koboldcpp.py | 85 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 31cdc53f6..a08e0b91f 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -1267,6 +1267,14 @@ def generate(genparams, stream_flag=False): ban_eos_token = genparams.get('ban_eos_token', False) stream_sse = stream_flag grammar = genparams.get('grammar', '') + #translate grammar if its json + try: + grammarjson = json.loads(grammar) + decoded = convert_json_to_gbnf(grammarjson) + if decoded: + grammar = decoded + except Exception: + pass grammar_retain_state = genparams.get('grammar_retain_state', False) genkey = genparams.get('genkey', '') trimstop = genparams.get('trim_stop', True) @@ -2051,6 +2059,32 @@ def transform_genparams(genparams, api_format): tools_message_start = adapter_obj.get("tools_start", "") tools_message_end = adapter_obj.get("tools_end", "") images_added = [] + jsongrammar = r""" +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws +string ::= + "\"" ( + [^"\\\x7F\x00-\x1F] | + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) + )* "\"" ws +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws +ws ::= | " " | "\n" [ \t]{0,20} +""" # tools handling tools_array = genparams.get('tools', []) @@ -2075,6 +2109,24 @@ def transform_genparams(genparams, api_format): tool_json_formatting_instruction = f"\nThe user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0) except Exception: # In case of any issues, just revert back to no specified function + print("Tools parsing not valid - discarded") + pass + + # handle structured outputs + respformat = genparams.get('response_format', None) + if respformat: + try: + rt = respformat.get('type') + if rt.lower() == "json_schema": + schema = respformat.get('json_schema').get('schema') + decoded = convert_json_to_gbnf(schema) + if decoded: + genparams["grammar"] = decoded + elif rt.lower() == "json_object": + genparams["grammar"] = jsongrammar + except Exception: + # In case of any issues, just do normal gen + print("Structured Output not valid - discarded") pass message_index = 0 @@ -2115,13 +2167,15 @@ def transform_genparams(genparams, api_format): # if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool? (One word response: yes or no):") # note: message string already contains the instruct start tag! + pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"' temp_poll = { "prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{user_end}", - "max_length":6, + "max_length":4, "temperature":0.1, "top_k":1, "rep_pen":1, - "ban_eos_token":False + "ban_eos_token":False, + "grammar":pollgrammar } temp_poll_result = generate(genparams=temp_poll) if temp_poll_result and "yes" not in temp_poll_result['text'].lower(): @@ -2138,32 +2192,7 @@ def transform_genparams(genparams, api_format): genparams["using_openai_tools"] = True # Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf) - genparams["grammar"] = r""" -root ::= arr -value ::= object | array | string | number | ("true" | "false" | "null") ws -arr ::= - "[\n" ws ( - value - (",\n" ws value)* - )? "]" -object ::= - "{" ws ( - string ":" ws value - ("," ws string ":" ws value)* - )? "}" ws -array ::= - "[" ws ( - value - ("," ws value)* - )? "]" ws -string ::= - "\"" ( - [^"\\\x7F\x00-\x1F] | - "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) - )* "\"" ws -number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws -ws ::= | " " | "\n" [ \t]{0,20} -""" + genparams["grammar"] = jsongrammar if message['role'] == "system": messages_string += system_message_end elif message['role'] == "user":