tool response type coercion

This commit is contained in:
Concedo 2026-04-09 12:59:57 +08:00
parent 77d0ddb486
commit f6199d42e1

View file

@ -2929,6 +2929,68 @@ def is_ipv6_supported():
except Exception:
return False
def coerce_tool_argtypes(tool_calls: list, tool_list: list) -> list:
if not tool_calls or not tool_list:
return tool_calls
schema_map = {} #lookup correct type for the tool
for tool in tool_list:
try:
if tool.get("type") == "function":
func = tool.get("function", {})
name = func.get("name", "")
props = func.get("parameters", {}).get("properties", {})
else:
name = tool.get("name", "")
props = tool.get("parameters", {}).get("properties", {})
if name:
schema_map[name] = props
except Exception:
continue
type_coercers = {
"integer": lambda v: int(v) if not isinstance(v, int) else v,
"number": lambda v: float(v) if not isinstance(v, (int, float)) else v,
"boolean": lambda v: True if (isinstance(v, str) and v.lower() in ("true", "1", "yes")) else (False if (isinstance(v, str) and v.lower() in ("false", "0", "no")) else v),
"string": lambda v: v, # default is already string
}
result = []
for call in tool_calls:
try:
# Handle both {name, arguments} and OpenAI {type, function: {name, arguments}} formats
if "function" in call:
name = call["function"].get("name", "")
arguments = call["function"].get("arguments", {})
else:
name = call.get("name", "")
arguments = call.get("arguments", {})
props = schema_map.get(name, {})
if props and isinstance(arguments, dict):
coerced = {}
for key, val in arguments.items():
prop_type = props.get(key, {}).get("type")
coercer = type_coercers.get(prop_type)
if coercer is not None and val is not None:
try:
coerced[key] = coercer(val)
except (ValueError, AttributeError):
coerced[key] = val
else:
coerced[key] = val
# Write back
if "function" in call:
call = {**call, "function": {**call["function"], "arguments": coerced}}
else:
call = {**call, "arguments": coerced}
except Exception:
pass
result.append(call)
return result
def toolcall_to_normalized_json(text,start_tag,end_tag): #convert weird formats into standard tool call json
text = text.strip()
def parse_qwen35(text: str) -> str:
@ -3045,7 +3107,7 @@ def toolcall_to_normalized_json(text,start_tag,end_tag): #convert weird formats
return text #fallback
def repack_toolcall_tags(text: str):
def repack_toolcall_tags(text: str, original_tools:list):
tool_calls = []
if not text:
return tool_calls
@ -3076,6 +3138,7 @@ def repack_toolcall_tags(text: str):
# fallback ONLY if no tags were found at all
if not found:
tool_calls = extract_json_from_string(text)
tool_calls = coerce_tool_argtypes(tool_calls, original_tools)
return tool_calls
def format_jinja(messages_orig, tools, chat_template_kwargs=None):
@ -4365,7 +4428,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
using_openai_tools = genparams.get('using_openai_tools', False)
if using_openai_tools:
# first, check and potentially segment multiple tags for multi-tool calls
tool_calls = repack_toolcall_tags(recvtxt)
tool_calls = repack_toolcall_tags(recvtxt,genparams.get('tools', []))
if tool_calls and len(tool_calls)>0:
flat = []
for obj in tool_calls: