mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 08:00:25 +00:00
tool response type coercion
This commit is contained in:
parent
77d0ddb486
commit
f6199d42e1
1 changed files with 65 additions and 2 deletions
67
koboldcpp.py
67
koboldcpp.py
|
|
@ -2929,6 +2929,68 @@ def is_ipv6_supported():
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def coerce_tool_argtypes(tool_calls: list, tool_list: list) -> list:
|
||||
if not tool_calls or not tool_list:
|
||||
return tool_calls
|
||||
|
||||
schema_map = {} #lookup correct type for the tool
|
||||
for tool in tool_list:
|
||||
try:
|
||||
if tool.get("type") == "function":
|
||||
func = tool.get("function", {})
|
||||
name = func.get("name", "")
|
||||
props = func.get("parameters", {}).get("properties", {})
|
||||
else:
|
||||
name = tool.get("name", "")
|
||||
props = tool.get("parameters", {}).get("properties", {})
|
||||
if name:
|
||||
schema_map[name] = props
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
type_coercers = {
|
||||
"integer": lambda v: int(v) if not isinstance(v, int) else v,
|
||||
"number": lambda v: float(v) if not isinstance(v, (int, float)) else v,
|
||||
"boolean": lambda v: True if (isinstance(v, str) and v.lower() in ("true", "1", "yes")) else (False if (isinstance(v, str) and v.lower() in ("false", "0", "no")) else v),
|
||||
"string": lambda v: v, # default is already string
|
||||
}
|
||||
|
||||
result = []
|
||||
for call in tool_calls:
|
||||
try:
|
||||
# Handle both {name, arguments} and OpenAI {type, function: {name, arguments}} formats
|
||||
if "function" in call:
|
||||
name = call["function"].get("name", "")
|
||||
arguments = call["function"].get("arguments", {})
|
||||
else:
|
||||
name = call.get("name", "")
|
||||
arguments = call.get("arguments", {})
|
||||
|
||||
props = schema_map.get(name, {})
|
||||
if props and isinstance(arguments, dict):
|
||||
coerced = {}
|
||||
for key, val in arguments.items():
|
||||
prop_type = props.get(key, {}).get("type")
|
||||
coercer = type_coercers.get(prop_type)
|
||||
if coercer is not None and val is not None:
|
||||
try:
|
||||
coerced[key] = coercer(val)
|
||||
except (ValueError, AttributeError):
|
||||
coerced[key] = val
|
||||
else:
|
||||
coerced[key] = val
|
||||
# Write back
|
||||
if "function" in call:
|
||||
call = {**call, "function": {**call["function"], "arguments": coerced}}
|
||||
else:
|
||||
call = {**call, "arguments": coerced}
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
result.append(call)
|
||||
|
||||
return result
|
||||
|
||||
def toolcall_to_normalized_json(text,start_tag,end_tag): #convert weird formats into standard tool call json
|
||||
text = text.strip()
|
||||
def parse_qwen35(text: str) -> str:
|
||||
|
|
@ -3045,7 +3107,7 @@ def toolcall_to_normalized_json(text,start_tag,end_tag): #convert weird formats
|
|||
|
||||
return text #fallback
|
||||
|
||||
def repack_toolcall_tags(text: str):
|
||||
def repack_toolcall_tags(text: str, original_tools:list):
|
||||
tool_calls = []
|
||||
if not text:
|
||||
return tool_calls
|
||||
|
|
@ -3076,6 +3138,7 @@ def repack_toolcall_tags(text: str):
|
|||
# fallback ONLY if no tags were found at all
|
||||
if not found:
|
||||
tool_calls = extract_json_from_string(text)
|
||||
tool_calls = coerce_tool_argtypes(tool_calls, original_tools)
|
||||
return tool_calls
|
||||
|
||||
def format_jinja(messages_orig, tools, chat_template_kwargs=None):
|
||||
|
|
@ -4365,7 +4428,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
using_openai_tools = genparams.get('using_openai_tools', False)
|
||||
if using_openai_tools:
|
||||
# first, check and potentially segment multiple tags for multi-tool calls
|
||||
tool_calls = repack_toolcall_tags(recvtxt)
|
||||
tool_calls = repack_toolcall_tags(recvtxt,genparams.get('tools', []))
|
||||
if tool_calls and len(tool_calls)>0:
|
||||
flat = []
|
||||
for obj in tool_calls:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue