rosie fixes: add format normalization for tools and tool call streaming fixes (#1842)

This commit is contained in:
LostRuins Concedo 2025-11-11 23:06:27 +08:00 committed by GitHub
parent 5125c0b879
commit 95291a93df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2362,7 +2362,7 @@ def is_ipv6_supported():
except Exception:
return False
def format_jinja(messages,tools):
def format_jinja(messages, tools):
try:
def strftime_now(format='%Y-%m-%d %H:%M:%S'):
return datetime.now().strftime(format)
@ -2374,7 +2374,11 @@ def format_jinja(messages,tools):
jinja_env.globals['strftime_now'] = strftime_now
jinja_env.filters["tojson"] = tojson
jinja_compiled_template = jinja_env.from_string(cached_chat_template)
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="")
text = None
if tools and len(tools)>0:
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="")
else:
text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="")
return text if text else None
except Exception as e:
print(f"Jinja formatting failed: {e}")
@ -2392,6 +2396,31 @@ def remove_outer_tags(inputstr):
return stripped # If no match, return original string
except Exception:
return stripped
def normalize_tool_call(obj): # Normalize various tool call formats to OpenAI format
if "type" in obj and "function" in obj: # Already in OpenAI format
return obj
if "name" in obj and ("arguments" in obj or "parameters" in obj):
args = obj.get("arguments", obj.get("parameters", {}))
return {
"type": "function",
"function": {
"name": obj["name"],
"arguments": args
}
}
if "function" in obj and isinstance(obj["function"], dict):
func = obj["function"]
if "name" in func:
return {
"type": "function",
"function": {
"name": func["name"],
"arguments": func.get("arguments", func.get("parameters", {}))
}
}
return obj
# Used to parse json for openai tool calls
def extract_json_from_string(input_string):
@ -3059,6 +3088,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if using_openai_tools:
tool_calls = extract_json_from_string(recvtxt)
if tool_calls and len(tool_calls)>0:
tool_calls = [normalize_tool_call(obj) for obj in tool_calls]
for tc in tool_calls:
tcarg = tc.get("function",{}).get("arguments",None)
tc["id"] = f"call_{random.randint(10000, 99999)}"
@ -3094,8 +3124,8 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
print(f"Generate: Error while generating: {e}")
async def send_oai_sse_event(self, data):
if data=="[DONE]":
self.wfile.write(f'data: {data}'.encode())
if data and data.strip()=="[DONE]":
self.wfile.write(f'data: {data.strip()}\n\n'.encode())
else:
self.wfile.write(f'data: {data}\n\n'.encode())
self.wfile.flush()
@ -4346,18 +4376,92 @@ Change Mode<br>
self.send_header("cache-control", "no-cache")
self.send_header("connection", "keep-alive")
self.end_headers(content_type='text/event-stream')
content_text = None
toolsdata_res = []
try:
toolsdata_res = gendat['choices'][0]['message']['tool_calls']
if toolsdata_res and len(toolsdata_res)>0:
toolsdata_res[0]["index"] = 0 # need to add an index for OWUI
toolsdata_res[0]["index"] = 0 # need to add an index for OWUI
except Exception:
toolsdata_res = []
toolsdata_p1 = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":int(time.time()),"model":friendlymodelname,"choices":[{"index":0,"finish_reason":None,"delta":{'role':'assistant','content':None, "tool_calls":toolsdata_res}}]})
toolsdata_p2 = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":int(time.time()),"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"tool_calls","delta":{}}]})
self.wfile.write(f'data: {toolsdata_p1}\n\n'.encode())
self.wfile.write(f'data: {toolsdata_p2}\n\n'.encode())
self.wfile.write('data: [DONE]'.encode())
try:
content_text = gendat['choices'][0]['message'].get('content', None)
except Exception:
content_text = None
# Send role chunk first
chunk_role = json.dumps({
"id": "koboldcpp",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": friendlymodelname,
"choices": [{"index": 0, "finish_reason": None, "delta": {"role": "assistant"}}]
})
self.wfile.write(f"data: {chunk_role}\n\n".encode())
self.wfile.flush()
# Send content if present
if content_text:
chunk_content = json.dumps({
"id": "koboldcpp",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": friendlymodelname,
"choices": [{"index": 0, "finish_reason": None, "delta": {"content": content_text}}]
})
self.wfile.write(f"data: {chunk_content}\n\n".encode())
self.wfile.flush()
# Send tool calls incrementally in OpenAI format
if toolsdata_res and len(toolsdata_res) > 0:
for idx, tool_call in enumerate(toolsdata_res):
tc_meta = {
"index": idx,
"id": tool_call.get("id", f"call_{idx}"),
"type": "function",
"function": {
"name": tool_call.get("function", {}).get("name", ""),
"arguments": ""
}
}
chunk_meta = json.dumps({
"id": "koboldcpp",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": friendlymodelname,
"choices": [{"index": 0, "finish_reason": None, "delta": {"tool_calls": [tc_meta]}}]
})
self.wfile.write(f"data: {chunk_meta}\n\n".encode())
self.wfile.flush()
args_str = tool_call.get("function", {}).get("arguments", "{}")
if isinstance(args_str, dict):
args_str = json.dumps(args_str)
tc_args = {
"index": idx,
"function": {"arguments": args_str}
}
chunk_args = json.dumps({
"id": "koboldcpp",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": friendlymodelname,
"choices": [{"index": 0, "finish_reason": None, "delta": {"tool_calls": [tc_args]}}]
})
self.wfile.write(f"data: {chunk_args}\n\n".encode())
self.wfile.flush()
# Final chunk
chunk_final = json.dumps({
"id": "koboldcpp",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": friendlymodelname,
"choices": [{"index": 0, "finish_reason": "tool_calls", "delta": {}}]
})
self.wfile.write(f"data: {chunk_final}\n\n".encode())
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
self.close_connection = True
except Exception as ex: