diff --git a/koboldcpp.py b/koboldcpp.py index 06cd13f7e..55bfd65b1 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -3104,29 +3104,26 @@ def format_jinja(messages_orig, tools, chat_template_kwargs=None): func["arguments"] = json.loads(args) except Exception: pass - # Fix tool content for some templates - # if m.get("role") == "tool" and isinstance(m.get("content"), str): - # try: - # m["content"] = json.loads(m["content"]) - # except Exception: - # pass jinja_env.globals['strftime_now'] = strftime_now jinja_env.globals['raise_exception'] = raise_exception jinja_env.filters["tojson"] = tojson jinja_compiled_template = jinja_env.from_string(cached_chat_template) text = None - last_assist_msg = messages[-1]["content"] + messages_for_render = [] + assist_should_prefill = False chat_template_kwargs = chat_template_kwargs or {} - assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content + last_assist_msg = "" + if messages: + last_assist_msg = messages[-1]["content"] + assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content + last_assist_msg = "" if not assist_should_prefill else last_assist_msg + messages_for_render = messages[:-1] if assist_should_prefill else messages if tools and len(tools)>0: - text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs) + text = jinja_compiled_template.render(messages=messages_for_render, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs) else: - text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs) - - if assist_should_prefill and text: # handle prefill continuations - lastindex = text.rfind(last_assist_msg) - if lastindex != -1: - text = text[:lastindex + len(last_assist_msg)] + text = jinja_compiled_template.render(messages=messages_for_render, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs) + if assist_should_prefill and text and last_assist_msg: # handle prefill continuations + text = text + last_assist_msg return text if text else None except Exception as e: print(f"Jinja formatting failed: {e}")