make tool calling work with jinja. but still need to fix qwen omni first (+1 squashed commits)

Squashed commits:

[e394da61e] make tool calling work with jinja. but still need to fix qwen omni first
This commit is contained in:
LostRuins Concedo 2025-11-09 15:09:29 +08:00
parent 4fc022a51f
commit 60a74bdd89
2 changed files with 23 additions and 6 deletions

View file

@ -2334,7 +2334,7 @@ def is_ipv6_supported():
except Exception:
return False
def format_jinja(messages):
def format_jinja(messages,tools):
try:
def strftime_now(format='%Y-%m-%d %H:%M:%S'):
return datetime.now().strftime(format)
@ -2346,16 +2346,30 @@ def format_jinja(messages):
jinja_env.globals['strftime_now'] = strftime_now
jinja_env.filters["tojson"] = tojson
jinja_compiled_template = jinja_env.from_string(cached_chat_template)
text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="")
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="")
return text if text else None
except Exception as e:
print("Jinja formatting failed: {e}")
return None
def remove_outer_tags(inputstr):
try:
stripped = inputstr.strip()
match = re.match(r'^<([^\s<>]+)>(.*?)</\1>\s*$', stripped, re.DOTALL) # Try angle brackets first
if match:
return match.group(2).strip()
match = re.match(r'^\[([^\s<>]+)\](.*?)\[/\1]\s*$', stripped, re.DOTALL) # Then try square brackets
if match:
return match.group(2).strip()
return stripped # If no match, return original string
except Exception:
return stripped
# Used to parse json for openai tool calls
def extract_json_from_string(input_string):
parsed_json = None
input_string = remove_outer_tags(input_string) #if we detected wrapper tags, remove them
try: # First check if model exported perfect json
parsed_json = json.loads(input_string)
if not isinstance(parsed_json, list):
@ -2665,10 +2679,13 @@ ws ::= | " " | "\n" [ \t]{0,20}
attachedimgid = 0
attachedaudid = 0
jinja_output = None
if use_jinja and cached_chat_template:
jinja_output = format_jinja(messages_array)
jinjatools = genparams.get('tools', [])
if use_jinja and cached_chat_template:
jinja_output = format_jinja(messages_array,jinjatools)
if jinja_output:
messages_string = jinja_output
if jinjatools and len(jinjatools)>0:
genparams["using_openai_tools"] = True
else:
for message in messages_array:
message_index += 1

View file

@ -34,15 +34,15 @@
#include <fstream>
#include <map>
#include <regex>
#include <sstream>
#include <numeric>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>
#include <limits>
#include <numeric>
#include <array>
#include <numeric>
#include <functional>
#include <filesystem>