mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
added support for structured output in chat completions
This commit is contained in:
parent
e8b3aeaa28
commit
3e8b84b8e5
1 changed files with 57 additions and 28 deletions
85
koboldcpp.py
85
koboldcpp.py
|
@ -1267,6 +1267,14 @@ def generate(genparams, stream_flag=False):
|
||||||
ban_eos_token = genparams.get('ban_eos_token', False)
|
ban_eos_token = genparams.get('ban_eos_token', False)
|
||||||
stream_sse = stream_flag
|
stream_sse = stream_flag
|
||||||
grammar = genparams.get('grammar', '')
|
grammar = genparams.get('grammar', '')
|
||||||
|
#translate grammar if its json
|
||||||
|
try:
|
||||||
|
grammarjson = json.loads(grammar)
|
||||||
|
decoded = convert_json_to_gbnf(grammarjson)
|
||||||
|
if decoded:
|
||||||
|
grammar = decoded
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
grammar_retain_state = genparams.get('grammar_retain_state', False)
|
grammar_retain_state = genparams.get('grammar_retain_state', False)
|
||||||
genkey = genparams.get('genkey', '')
|
genkey = genparams.get('genkey', '')
|
||||||
trimstop = genparams.get('trim_stop', True)
|
trimstop = genparams.get('trim_stop', True)
|
||||||
|
@ -2051,6 +2059,32 @@ def transform_genparams(genparams, api_format):
|
||||||
tools_message_start = adapter_obj.get("tools_start", "")
|
tools_message_start = adapter_obj.get("tools_start", "")
|
||||||
tools_message_end = adapter_obj.get("tools_end", "")
|
tools_message_end = adapter_obj.get("tools_end", "")
|
||||||
images_added = []
|
images_added = []
|
||||||
|
jsongrammar = r"""
|
||||||
|
root ::= arr
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
arr ::=
|
||||||
|
"[\n" ws (
|
||||||
|
value
|
||||||
|
(",\n" ws value)*
|
||||||
|
)? "]"
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\\x7F\x00-\x1F] |
|
||||||
|
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||||
|
)* "\"" ws
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
|
||||||
|
ws ::= | " " | "\n" [ \t]{0,20}
|
||||||
|
"""
|
||||||
|
|
||||||
# tools handling
|
# tools handling
|
||||||
tools_array = genparams.get('tools', [])
|
tools_array = genparams.get('tools', [])
|
||||||
|
@ -2075,6 +2109,24 @@ def transform_genparams(genparams, api_format):
|
||||||
tool_json_formatting_instruction = f"\nThe user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
|
tool_json_formatting_instruction = f"\nThe user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
|
||||||
except Exception:
|
except Exception:
|
||||||
# In case of any issues, just revert back to no specified function
|
# In case of any issues, just revert back to no specified function
|
||||||
|
print("Tools parsing not valid - discarded")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# handle structured outputs
|
||||||
|
respformat = genparams.get('response_format', None)
|
||||||
|
if respformat:
|
||||||
|
try:
|
||||||
|
rt = respformat.get('type')
|
||||||
|
if rt.lower() == "json_schema":
|
||||||
|
schema = respformat.get('json_schema').get('schema')
|
||||||
|
decoded = convert_json_to_gbnf(schema)
|
||||||
|
if decoded:
|
||||||
|
genparams["grammar"] = decoded
|
||||||
|
elif rt.lower() == "json_object":
|
||||||
|
genparams["grammar"] = jsongrammar
|
||||||
|
except Exception:
|
||||||
|
# In case of any issues, just do normal gen
|
||||||
|
print("Structured Output not valid - discarded")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
message_index = 0
|
message_index = 0
|
||||||
|
@ -2115,13 +2167,15 @@ def transform_genparams(genparams, api_format):
|
||||||
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
|
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
|
||||||
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool? (One word response: yes or no):")
|
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool? (One word response: yes or no):")
|
||||||
# note: message string already contains the instruct start tag!
|
# note: message string already contains the instruct start tag!
|
||||||
|
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
|
||||||
temp_poll = {
|
temp_poll = {
|
||||||
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{user_end}",
|
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{user_end}",
|
||||||
"max_length":6,
|
"max_length":4,
|
||||||
"temperature":0.1,
|
"temperature":0.1,
|
||||||
"top_k":1,
|
"top_k":1,
|
||||||
"rep_pen":1,
|
"rep_pen":1,
|
||||||
"ban_eos_token":False
|
"ban_eos_token":False,
|
||||||
|
"grammar":pollgrammar
|
||||||
}
|
}
|
||||||
temp_poll_result = generate(genparams=temp_poll)
|
temp_poll_result = generate(genparams=temp_poll)
|
||||||
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
|
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
|
||||||
|
@ -2138,32 +2192,7 @@ def transform_genparams(genparams, api_format):
|
||||||
genparams["using_openai_tools"] = True
|
genparams["using_openai_tools"] = True
|
||||||
|
|
||||||
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
|
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
|
||||||
genparams["grammar"] = r"""
|
genparams["grammar"] = jsongrammar
|
||||||
root ::= arr
|
|
||||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
|
||||||
arr ::=
|
|
||||||
"[\n" ws (
|
|
||||||
value
|
|
||||||
(",\n" ws value)*
|
|
||||||
)? "]"
|
|
||||||
object ::=
|
|
||||||
"{" ws (
|
|
||||||
string ":" ws value
|
|
||||||
("," ws string ":" ws value)*
|
|
||||||
)? "}" ws
|
|
||||||
array ::=
|
|
||||||
"[" ws (
|
|
||||||
value
|
|
||||||
("," ws value)*
|
|
||||||
)? "]" ws
|
|
||||||
string ::=
|
|
||||||
"\"" (
|
|
||||||
[^"\\\x7F\x00-\x1F] |
|
|
||||||
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
|
||||||
)* "\"" ws
|
|
||||||
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
|
|
||||||
ws ::= | " " | "\n" [ \t]{0,20}
|
|
||||||
"""
|
|
||||||
if message['role'] == "system":
|
if message['role'] == "system":
|
||||||
messages_string += system_message_end
|
messages_string += system_message_end
|
||||||
elif message['role'] == "user":
|
elif message['role'] == "user":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue