diff --git a/koboldcpp.py b/koboldcpp.py index e56a88417..2dd1c1e3d 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -2956,6 +2956,22 @@ def format_jinja(messages, tools, chat_template_kwargs=None): for m in messages: if m.get("content") is None: del m["content"] + for m in messages: # Fix tool_calls arguments and content if parsable + if m.get("tool_calls"): + for tc in m["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments") + if isinstance(args, str): + try: + func["arguments"] = json.loads(args) + except Exception: + pass + # Fix tool content + if m.get("role") == "tool" and isinstance(m.get("content"), str): + try: + m["content"] = json.loads(m["content"]) + except Exception: + pass jinja_env.globals['strftime_now'] = strftime_now jinja_env.globals['raise_exception'] = raise_exception jinja_env.filters["tojson"] = tojson