mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-08 09:59:50 +00:00
fix assistant prefill logic (+1 squashed commits)
Squashed commits: [f4963baf5] fix prefills
This commit is contained in:
parent
53b3bf46e4
commit
63ca37e62a
1 changed files with 12 additions and 15 deletions
27
koboldcpp.py
27
koboldcpp.py
|
|
@ -3104,29 +3104,26 @@ def format_jinja(messages_orig, tools, chat_template_kwargs=None):
|
|||
func["arguments"] = json.loads(args)
|
||||
except Exception:
|
||||
pass
|
||||
# Fix tool content for some templates
|
||||
# if m.get("role") == "tool" and isinstance(m.get("content"), str):
|
||||
# try:
|
||||
# m["content"] = json.loads(m["content"])
|
||||
# except Exception:
|
||||
# pass
|
||||
jinja_env.globals['strftime_now'] = strftime_now
|
||||
jinja_env.globals['raise_exception'] = raise_exception
|
||||
jinja_env.filters["tojson"] = tojson
|
||||
jinja_compiled_template = jinja_env.from_string(cached_chat_template)
|
||||
text = None
|
||||
last_assist_msg = messages[-1]["content"]
|
||||
messages_for_render = []
|
||||
assist_should_prefill = False
|
||||
chat_template_kwargs = chat_template_kwargs or {}
|
||||
assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content
|
||||
last_assist_msg = ""
|
||||
if messages:
|
||||
last_assist_msg = messages[-1]["content"]
|
||||
assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content
|
||||
last_assist_msg = "" if not assist_should_prefill else last_assist_msg
|
||||
messages_for_render = messages[:-1] if assist_should_prefill else messages
|
||||
if tools and len(tools)>0:
|
||||
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
|
||||
text = jinja_compiled_template.render(messages=messages_for_render, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
|
||||
else:
|
||||
text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
|
||||
|
||||
if assist_should_prefill and text: # handle prefill continuations
|
||||
lastindex = text.rfind(last_assist_msg)
|
||||
if lastindex != -1:
|
||||
text = text[:lastindex + len(last_assist_msg)]
|
||||
text = jinja_compiled_template.render(messages=messages_for_render, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs)
|
||||
if assist_should_prefill and text and last_assist_msg: # handle prefill continuations
|
||||
text = text + last_assist_msg
|
||||
return text if text else None
|
||||
except Exception as e:
|
||||
print(f"Jinja formatting failed: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue