fix assistant prefill logic (+1 squashed commits)

Squashed commits:

[f4963baf5] fix prefills
This commit is contained in:
Concedo 2026-04-05 23:17:05 +08:00
parent 53b3bf46e4
commit 63ca37e62a

View file

@ -3104,29 +3104,26 @@ def format_jinja(messages_orig, tools, chat_template_kwargs=None):
func["arguments"] = json.loads(args)
except Exception:
pass
# Fix tool content for some templates
# if m.get("role") == "tool" and isinstance(m.get("content"), str):
# try:
# m["content"] = json.loads(m["content"])
# except Exception:
# pass
jinja_env.globals['strftime_now'] = strftime_now
jinja_env.globals['raise_exception'] = raise_exception
jinja_env.filters["tojson"] = tojson
jinja_compiled_template = jinja_env.from_string(cached_chat_template)
text = None
last_assist_msg = messages[-1]["content"]
messages_for_render = []
assist_should_prefill = False
chat_template_kwargs = chat_template_kwargs or {}
assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content
last_assist_msg = ""
if messages:
last_assist_msg = messages[-1]["content"]
assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content
last_assist_msg = "" if not assist_should_prefill else last_assist_msg
messages_for_render = messages[:-1] if assist_should_prefill else messages
if tools and len(tools)>0:
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
text = jinja_compiled_template.render(messages=messages_for_render, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
else:
text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
if assist_should_prefill and text: # handle prefill continuations
lastindex = text.rfind(last_assist_msg)
if lastindex != -1:
text = text[:lastindex + len(last_assist_msg)]
text = jinja_compiled_template.render(messages=messages_for_render, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
if assist_should_prefill and text and last_assist_msg: # handle prefill continuations
text = text + last_assist_msg
return text if text else None
except Exception as e:
print(f"Jinja formatting failed: {e}")