diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp new file mode 100644 index 000000000..01787cc8d --- /dev/null +++ b/common/chat-auto-parser-generator.cpp @@ -0,0 +1,442 @@ +#include "chat-auto-parser.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "json-schema-to-grammar.h" +#include "nlohmann/json.hpp" + +#include +#include + +using json = nlohmann::ordered_json; + +// Helper to iterate over tools/functions +static void foreach_function(const json & tools, const std::function & fn) { + for (const auto & tool : tools) { + if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { + continue; + } + fn(tool); + } +} + +namespace autoparser { + +parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) : + p(p), + inputs(inputs), + reasoning_parser(p.eps()) {} + +common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs) { + // Run differential analysis to extract template structure + struct autoparser autoparser; + autoparser.analyze_template(tmpl); + return generate_parser(tmpl, inputs, autoparser); +} + +common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs, + const autoparser & autoparser) { + // Build the parser using the analysis results + auto parser = autoparser.build_parser(inputs); + + // Create the result structure + common_chat_params data; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = autoparser.preserved_tokens; + data.parser = parser.save(); + + // Build grammar if tools are present + bool has_tools = + autoparser.tools.format.mode != tool_format::NONE && inputs.tools.is_array() && !inputs.tools.empty(); + std::string trigger_marker = !autoparser.tools.format.section_start.empty() ? autoparser.tools.format.section_start : + autoparser.tools.format.per_call_start; + bool include_grammar = + has_tools && ((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) || + inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED); + + if (include_grammar) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); + }); + parser.build_grammar(builder, data.grammar_lazy); + }); + + // Set grammar triggers based on tool section markers (fall back to per-call markers) + if (data.grammar_lazy) { // only do triggers on lazy grammar + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker } + }; + } + } + + return data; +} + +common_peg_arena autoparser::build_parser(const templates_params & inputs) const { + if (!analysis_complete) { + throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); + } + return build_chat_peg_parser([&](common_chat_peg_builder & p) { + // If the template uses Python dict format (single-quoted strings in JSON structures), + // pre-register a json-string rule that accepts both quote styles. This must happen + // before any call to p.json() so that all JSON parsing inherits the flexible rule. + if (tools.format.uses_python_dicts) { + p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); }); + } + + parser_build_context ctx(p, inputs); + bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + bool enable_thinking = inputs.enable_thinking; + + ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE; + ctx.content = &content; + + // Build reasoning parser + ctx.reasoning_parser = reasoning.build_parser(ctx); + + bool has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty(); + + if (has_response_format) { + return ctx.reasoning_parser + p.space() + + p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + p.end(); + } + + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) { + return tools.build_parser(ctx); + } + + return content.build_parser(ctx); + }); +} + +common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) const { + auto & p = ctx.p; + + if (!ctx.extracting_reasoning) { + return p.eps(); + } + + bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN); + bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED); + + if (thinking_forced_open || thinking_forced_closed) { + // Thinking is forced open OR forced closed with enable_thinking=true + // In both cases, expect only the closing tag (opening was in template) + return p.reasoning(p.until(end)) + end; + } + if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { + // Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) + // Both use the same tag-based pattern if markers are available + if (!start.empty() && !end.empty()) { + return p.optional(start + p.reasoning(p.until(end)) + end); + } + } else if (mode == reasoning_mode::DELIMITER) { + return p.optional(p.reasoning(p.until(end)) + end); + } + + return p.eps(); +} + +common_peg_parser analyze_content::build_parser(parser_build_context & ctx) const { + auto & p = ctx.p; + + if (is_always_wrapped()) { + if (ctx.extracting_reasoning) { + return ctx.reasoning_parser + start + p.content(p.until(end)) + end + p.end(); + } + return p.content(p.until(start)) + start + p.content(p.until(end)) + end + p.end(); + } + return ctx.reasoning_parser + p.content(p.rest()) + p.end(); +} + +common_peg_parser analyze_content::build_optional_wrapped(parser_build_context & ctx) const { + auto & p = ctx.p; + + if (is_always_wrapped()) { + return p.optional(start + p.content(p.until(end)) + end); + } + return p.eps(); +} + +common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const { + switch (format.mode) { + case tool_format::JSON_NATIVE: + return build_tool_parser_json_native(ctx); + case tool_format::TAG_WITH_JSON: + return build_tool_parser_tag_json(ctx); + case tool_format::TAG_WITH_TAGGED: + return build_tool_parser_tag_tagged(ctx); + default: + GGML_ABORT("Unable to create tool parser"); + } +} + +common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + // Build effective field names with dot notation if function_field is set + std::string name_field = format.name_field; + std::string args_field = format.args_field; + + if (!format.function_field.empty() && format.function_field != "function" && + name_field.find('.') == std::string::npos) { + name_field = format.function_field + "." + name_field; + args_field = format.function_field + "." + args_field; + } + + auto tools_parser = p.standard_json_tools( + format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls, + inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, + format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + + // Handle content wrappers if present + if (ctx.content && ctx.content->is_always_wrapped()) { + auto wrapped_content = ctx.content->build_optional_wrapped(ctx); + return ctx.reasoning_parser + wrapped_content + tools_parser + p.end(); + } + + std::string tool_start = "{"; + if (!format.section_start.empty()) { + tool_start = format.section_start; + } else if (!format.per_call_start.empty()) { + tool_start = format.per_call_start; + } + + return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser + + p.end(); +} + +common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_choice = p.choice(); + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & func = tool.at("function"); + std::string name = func.at("name"); + const auto & schema = func.at("parameters"); + + // Build call_id parser based on position (if supported) + common_peg_parser call_id_section = p.eps(); + if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() && + !call_id.suffix.empty()) { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix; + } + + auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)); + if (!function.close.empty()) { + func_parser = func_parser + function.close; + } + tool_choice |= p.rule("tool-" + name, func_parser); + }); + + auto require_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_calls = p.eps(); + + if (!format.per_call_start.empty()) { + auto wrapped_call = format.per_call_start + tool_choice + format.per_call_end; + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); + } else { + tool_calls = p.trigger_rule("tool-call", wrapped_call); + } + if (!format.section_start.empty()) { + tool_calls = p.trigger_rule("tool-calls", + p.literal(format.section_start) + p.space() + tool_calls + p.space() + + (format.section_end.empty() ? p.end() : p.literal(format.section_end))); + } + } else { + std::string separator = ", "; // Default + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice + + p.zero_or_more(separator + tool_choice) + format.section_end); + } else { + tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice + format.section_end); + } + } + + if (!require_calls) { + tool_calls = p.optional(tool_calls); + } + + std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; + auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker); + return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls + + p.end(); +} + +common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_choice = p.choice(); + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & func = tool.at("function"); + std::string name = func.at("name"); + const auto & params = func.at("parameters"); + + if (!params.contains("properties") || !params.at("properties").is_object()) { + return; + } + + const auto & properties = params.at("properties"); + std::set required; + if (params.contains("required") && params.at("required").is_array()) { + params.at("required").get_to(required); + } + + // Build parser for each argument, separating required and optional + std::vector required_parsers; + std::vector optional_parsers; + for (const auto & [param_name, param_schema] : properties.items()) { + bool is_required = required.find(param_name) != required.end(); + std::string type = "object"; + auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object(); + if (type_obj.is_string()) { + type_obj.get_to(type); + } else if (type_obj.is_object()) { + if (type_obj.contains("type") && type_obj.at("type").is_string()) { + type_obj.at("type").get_to(type); + } + } + + auto arg = p.tool_arg( + p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) + + arguments.name_suffix) + + arguments.value_prefix + + (type == "string" ? p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, true)) : + p.tool_arg_json_value(p.schema( + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) + + p.space()) + + p.tool_arg_close(p.literal(arguments.value_suffix))); + + auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg); + if (is_required) { + required_parsers.push_back(named_arg); + } else { + optional_parsers.push_back(named_arg); + } + } + + // Build required arg sequence in definition order + common_peg_parser args_seq = p.eps(); + for (size_t i = 0; i < required_parsers.size(); i++) { + if (i > 0) { + args_seq = args_seq + p.space(); + } + args_seq = args_seq + required_parsers[i]; + } + + // Build optional args with flexible ordering + if (!optional_parsers.empty()) { + common_peg_parser any_opt = p.choice(); + for (const auto & opt : optional_parsers) { + any_opt |= opt; + } + args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size()); + } + + // Build call_id parser based on position (if supported) + common_peg_parser call_id_section = p.eps(); + bool have_call_id = false; + if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() && + !call_id.suffix.empty()) { + have_call_id = true; + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix); + } + + bool matched_atomic = false; + common_peg_parser func_parser = p.eps(); + if (!function.name_suffix.empty()) { + func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section + p.space() + args_seq; + matched_atomic = true; + } else if (have_call_id) { + func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section) + p.space() + args_seq; + matched_atomic = true; + } else if (!arguments.name_prefix.empty() && properties.size() > 0) { + func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq; + matched_atomic = true; + } else { + func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section + p.space() + args_seq; + } + + if (!function.close.empty()) { + func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close)); + } else if (!format.per_call_end.empty()) { + // When there's no func_close but there is a per_call_end marker, use peek() to ensure + // we only emit tool_close when we can actually see the closing marker. This prevents + // premature closing during partial parsing when we've seen e.g. "" (end) or "" prefix that failed to match. + func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end))); + } else { + func_parser = + func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper + } + if (!matched_atomic) { + func_parser = p.atomic(func_parser); + } + + tool_choice |= p.rule("tool-" + name, func_parser); + }); + + auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_calls = p.eps(); + + if (!format.per_call_start.empty()) { + auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end; + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); + } else { + tool_calls = p.trigger_rule("tool-call", wrapped_call); + } + if (!format.section_start.empty()) { + tool_calls = p.trigger_rule("tool-calls", + p.literal(format.section_start) + p.space() + tool_calls + p.space() + + (format.section_end.empty() ? p.end() : p.literal(format.section_end))); + } + } else { + std::string separator = ", "; // Default + + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", format.section_start + p.space() + tool_choice + + p.zero_or_more(separator + tool_choice) + p.space() + + format.section_end); + } else { + tool_calls = p.trigger_rule( + "tool-call", format.section_start + p.space() + tool_choice + p.space() + format.section_end); + } + } + + if (!require_tools) { + tool_calls = p.optional(tool_calls); + } + + std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; + auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker); + return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls + + p.end(); +} + +} // namespace autoparser diff --git a/common/chat-auto-parser-helpers.cpp b/common/chat-auto-parser-helpers.cpp new file mode 100644 index 000000000..1519d8bc6 --- /dev/null +++ b/common/chat-auto-parser-helpers.cpp @@ -0,0 +1,347 @@ +#include "chat-auto-parser-helpers.h" + +#include "chat-auto-parser.h" +#include "chat.h" +#include "log.h" +#include "nlohmann/json.hpp" + +#include +#include + +using json = nlohmann::ordered_json; + +std::string trim_whitespace(const std::string & str) { + size_t start = 0; + while (start < str.length() && std::isspace(static_cast(str[start]))) { + start++; + } + + if (start == str.length()) { + return ""; + } + + size_t end = str.length() - 1; + while (end > start && std::isspace(static_cast(str[end]))) { + end--; + } + + return str.substr(start, end - start + 1); +} + +std::string trim_leading_whitespace(const std::string & str) { + size_t start = 0; + while (start < str.length() && std::isspace(static_cast(str[start]))) { + start++; + } + + return str.substr(start); +} + +std::string trim_trailing_whitespace(const std::string & str) { + if (str.empty()) { + return ""; + } + + size_t end = str.length() - 1; + while (end > 0 && std::isspace(static_cast(str[end]))) { + end--; + } + + // If first char is also whitespace, return empty string + if (end == 0 && std::isspace(static_cast(str[0]))) { + return ""; + } + + return str.substr(0, end + 1); +} + +std::string trim_trailing_newlines(const std::string & str) { + size_t end = str.length(); + while (end > 0 && str[end - 1] == '\n') { + end--; + } + + return str.substr(0, end); +} + +static size_t common_prefix_len(const std::string & left, const std::string & right) { + size_t prefix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (prefix_len < min_len && left[prefix_len] == right[prefix_len]) { + prefix_len++; + } + return prefix_len; +} + +static size_t common_suffix_len(const std::string & left, const std::string & right) { + size_t suffix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (suffix_len < min_len && left[left.length() - 1 - suffix_len] == right[right.length() - 1 - suffix_len]) { + suffix_len++; + } + return suffix_len; +} + +diff_split calculate_diff_split(const std::string & left, const std::string & right) { + diff_split result; + + auto left_seg = segmentize_markers(left); + auto right_seg = segmentize_markers(right); + + if (left_seg.empty()) { + result.right = right; + return result; + } + if (right_seg.empty()) { + result.left = left; + return result; + } + + auto left_start = left_seg.begin(); + auto left_end = --left_seg.end(); + auto right_start = right_seg.begin(); + auto right_end = --right_seg.end(); + + auto test = [&] () { + return left_start != left_end && right_start != right_end; + }; + + bool left_fully_consumed = false; + bool right_fully_consumed = false; + + while (test()) { + bool advanced = false; + if (*left_start == *right_start) { + result.prefix.append(left_start->value); + left_start++; + right_start++; + advanced = true; + } + if (*left_end == *right_end) { + result.suffix = left_end->value + result.suffix; + if (left_start != left_end) { + left_end--; + } else { + left_fully_consumed = true; + } + if (right_start != right_end) { + right_end--; + } else { + right_fully_consumed = true; + } + advanced = true; + } + if (!advanced) { + break; + } + } + + if (left_start == left_end && right_start != right_end) { + if (*left_start == *right_end) { + result.suffix = right_end->value + result.suffix; + right_end--; + left_fully_consumed = true; + } else if (*left_start == *right_start) { + result.prefix.append(right_start->value); + right_start++; + left_fully_consumed = true; + } + } else if (right_start == right_end && left_start != left_end) { + if (*left_end == *right_start) { + result.suffix = left_end->value + result.suffix; + left_end--; + right_fully_consumed = true; + } else if (*left_start == *right_start) { + result.prefix.append(left_start->value); + left_start++; + right_fully_consumed = true; + } + } else if (left_start == left_end && right_start == right_end && *left_start == *right_start && left_start->type == segment_type::MARKER) { + result.prefix.append(right_start->value); + left_fully_consumed = true; + right_fully_consumed = true; + } + + auto eat_segment = [](std::string & str, segment & seg) -> std::string { return str.append(seg.value); }; + + bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT; + bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT; + + std::string remainder_left = std::accumulate(left_start, left_fully_consumed ? left_end : ++left_end, std::string(), eat_segment); + std::string remainder_right = std::accumulate(right_start, right_fully_consumed ? right_end : ++right_end, std::string(), eat_segment); + + size_t suffix_len = can_have_text_suffix ? common_suffix_len(remainder_left, remainder_right) : 0; + // avoid overlaps between prefix and suffix + size_t prefix_len = can_have_text_prefix ? common_prefix_len(remainder_left.substr(0, remainder_left.size() - suffix_len), + remainder_right.substr(0, remainder_right.size() - suffix_len)) : 0; + + result.prefix.append(remainder_left.substr(0, prefix_len)); + result.suffix = remainder_left.substr(remainder_left.length() - suffix_len, suffix_len) + result.suffix; + result.left = remainder_left.substr(prefix_len, remainder_left.length() - prefix_len - suffix_len); + result.right = remainder_right.substr(prefix_len, remainder_right.length() - prefix_len - suffix_len); + + if (result.left == "" && result.right == "") { + // degenerate case, no diff + result.prefix = left; + result.suffix = ""; + // pick prefix = all as representation + } + return result; +} + +// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right` +std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right) { + // Find the common prefix of left and right + size_t common_prefix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (common_prefix_len < min_len && left[common_prefix_len] == right[common_prefix_len]) { + common_prefix_len++; + } + + // If there's no common prefix, return empty string + if (common_prefix_len == 0) { + return ""; + } + + // Find the common prefix in the full string + std::string common_prefix = left.substr(0, common_prefix_len); + size_t pos = full.find(common_prefix); + + // If not found, return empty string + if (pos == std::string::npos) { + return ""; + } + + // Return everything before the common prefix + return full.substr(0, pos); +} + +// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right` +std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right) { + // Find the common suffix of left and right (compare from the end) + size_t common_suffix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (common_suffix_len < min_len && + left[left.length() - 1 - common_suffix_len] == right[right.length() - 1 - common_suffix_len]) { + common_suffix_len++; + } + + // If there's no common suffix, return empty string + if (common_suffix_len == 0) { + return ""; + } + + // Extract the common suffix + std::string common_suffix = left.substr(left.length() - common_suffix_len); + + // Find the last occurrence of the common suffix in the full string + size_t pos = full.rfind(common_suffix); + + // If not found, return empty string + if (pos == std::string::npos) { + return ""; + } + + // Return everything after the common suffix + return full.substr(pos + common_suffix_len); +} + +// TODO: segmentize will treat a JSON array inside tags as a tag: [{ "fun": { ... } }] will be three markers +// not too worried about that because it hasn't turned out as a problem anywhere, but noting here in case it will +// Might have to put some restrictions on tag contents as well (like "no { }") +std::vector segmentize_markers(const std::string & text) { + std::vector retval; + bool in_marker = false; + char marker_opener = '\0'; + + auto is_marker_opener = [](char c) -> bool { return c == '<' || c == '['; }; + auto is_marker_closer = [](char op, char c) -> bool { return (op == '<' && c == '>') || (op == '[' && c == ']'); }; + + size_t last_border = 0; + + for (size_t cur_pos = 0; cur_pos < text.length(); cur_pos++) { + if (!in_marker && is_marker_opener(text[cur_pos])) { + if (last_border < cur_pos) { + retval.push_back(segment(segment_type::TEXT, text.substr(last_border, cur_pos - last_border))); + } + last_border = cur_pos; + in_marker = true; + marker_opener = text[cur_pos]; + } else if (in_marker && is_marker_closer(marker_opener, text[cur_pos])) { + // no need to check because last_border will always be smaller + retval.push_back(segment(segment_type::MARKER, text.substr(last_border, cur_pos - last_border + 1))); + last_border = cur_pos + 1; + in_marker = false; + marker_opener = '\0'; + } + } + if (last_border < text.length()) { + retval.push_back(segment(segment_type::TEXT, text.substr(last_border))); + } + return retval; +} + +std::vector prune_whitespace_segments(const std::vector & segments) { + std::vector result; + for (const auto & seg : segments) { + if (!trim_whitespace(seg.value).empty()) { + result.push_back(seg); + } + } + return result; +} + +namespace autoparser { + +std::string apply_template(const common_chat_template & tmpl, const template_params & params) { + templates_params tmpl_params; + tmpl_params.messages = params.messages; + tmpl_params.tools = params.tools; + tmpl_params.add_generation_prompt = params.add_generation_prompt; + tmpl_params.enable_thinking = params.enable_thinking; + + if (params.extra_context) { + tmpl_params.extra_context = *params.extra_context; + } + tmpl_params.extra_context["enable_thinking"] = params.enable_thinking; + + try { + return common_chat_template_direct_apply(tmpl, tmpl_params); + } catch (const std::exception & e) { + LOG_DBG("Template application failed: %s\n", e.what()); + return ""; + } +} + +std::optional compare_variants( + const common_chat_template & tmpl, + const template_params & params_A, + const std::function & params_modifier) { + // Create variant B by copying A + template_params params_B = params_A; + + // Apply modifier to create variant B + if (params_modifier) { + params_modifier(params_B); + } + + // Apply template to both variants + std::string output_A = apply_template(tmpl, params_A); + std::string output_B = apply_template(tmpl, params_B); + + // Check for template application failures + if (output_A.empty() || output_B.empty()) { + return std::nullopt; + } + + // Calculate diff and return result with both outputs + compare_variants_result result; + result.diff = calculate_diff_split(output_A, output_B); + result.output_A = output_A; + result.output_B = output_B; + + return result; +} + +} // namespace autoparser + diff --git a/common/chat-auto-parser-helpers.h b/common/chat-auto-parser-helpers.h new file mode 100644 index 000000000..6e3df79db --- /dev/null +++ b/common/chat-auto-parser-helpers.h @@ -0,0 +1,73 @@ +#pragma once + +#include "chat-auto-parser.h" +#include +#include +#include + +std::string trim_whitespace(const std::string & str); +std::string trim_leading_whitespace(const std::string & str); +std::string trim_trailing_whitespace(const std::string & str); +std::string trim_trailing_newlines(const std::string & str); + +// calculate a diff split (longest common prefix, longest common suffix excluding prefix, +// mismatched part on the left, mismatched part on the right) between two strings +// account for markers - align prefix and suffix endings so that they end on markers +// * eg.: +// calculate_diff_split("
", "

Something

") -> +// { "prefix": "" (not: "<"), "suffix": "", "left": "
", "right": "

Something

" } +// calculate_diff_split("Something", "") -> +// { "prefix": "", "suffix": "", "left": "Something", "right": "" } +diff_split calculate_diff_split(const std::string & left, const std::string & right); + +// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right` +// Returns empty string if there's no common prefix +// * eg.: +// until_common_prefix("really want a FUNCTION call", "FUNCTION alpha", "FUNCTION beta") -> "really want a " +// until_common_prefix("", "", "") -> "" +// until_common_prefix("some text", "1234", "abcd") -> "" +// until_common_prefix("one arg two args three args four", "argument alpha", "argument beta") -> "one "" +std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right); + +// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right` +// Returns empty string if there's no common suffix +// Mirror function of `until_common_prefix` +// * eg.: +// after_common_suffix("really want a FUNCTION call", "first FUNCTION", "second FUNCTION") -> " call" +// after_common_suffix("one arg two-args three args four", "alpha-args", "beta-args") -> " three args four" +std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right); + +// Segmentize text into markers and non-marker fragments +// * eg.: +// segmentize_markers("The site title
Here's some content
" -> +// [ (MARKER, ""), (MARKER, ""), (MARKER, ""), (TEXT, "The site title"), (MARKER, ""), +// (MARKER, ""), (MARKER, "
"), (TEXT, "Here's some "), (MARKER, ""), (TEXT, "content"), (MARKER, ""), +// (MARKER, "
"), (MARKER, ""), (MARKER, "") +// ] +// segmentize_markers("<|tool_call|>[args]{ are here }[/args]<|tool_call_end|>") -> +// [ (MARKER, "<|tool_call|>"), (MARKER, "[args]"), (TEXT, "{ are here }"), (MARKER, "[/args]"), (MARKER, "<|tool_call_end|>") ] +std::vector segmentize_markers(const std::string & text); + +// Prune whitespace-only segments from a vector of segments +// * eg.: +// segmentize_markers("\n\n\n \n\n\n") -> +// X = [ (MARKER, ""), (TEXT, "\n"), (MARKER, ""), (TEXT, "\n"), (MARKER, ""), (TEXT, "\n \n"), +// (MARKER, ""), (TEXT, "\n"), (MARKER, ""), (TEXT, "\n"), (MARKER, "") ] +// prune_whitespace_segments(X) -> [ (MARKER, ""), (MARKER, ""), (MARKER, ""), (MARKER, ""), +// (MARKER, ""), (MARKER, "") ] +std::vector prune_whitespace_segments(const std::vector & segments); + +namespace autoparser { + +// Apply a template with the given parameters, returning the rendered string (empty on failure) +std::string apply_template(const common_chat_template & tmpl, const template_params & params); + +// Factorized differential comparison function +// Takes base params and a single modifier lambda to create variant B +// Returns compare_variants_result containing diff and both outputs, or std::nullopt on failure +std::optional compare_variants( + const common_chat_template & tmpl, + const template_params & params_A, + const std::function & params_modifier); + +} // namespace autoparser diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h new file mode 100644 index 000000000..52c6488f4 --- /dev/null +++ b/common/chat-auto-parser.h @@ -0,0 +1,433 @@ +#pragma once + +#include "chat.h" +#include "common.h" +#include "jinja/caps.h" +#include "peg-parser.h" + +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +class common_chat_peg_builder; + +// ============================================================================ +// Parameters for template application (low-level, used by diff analysis) +// ============================================================================ +struct template_params { + json messages; + json tools; + bool add_generation_prompt = false; + bool enable_thinking = true; + std::optional extra_context = std::nullopt; +}; + +struct diff_split { + std::string prefix; + std::string suffix; + std::string left; + std::string right; + + bool operator==(struct diff_split & other) const { + return prefix == other.prefix && suffix == other.suffix && left == other.left && right == other.right; + } +}; + +// Result of compare_variants containing diff and original outputs +struct compare_variants_result { + diff_split diff; + std::string output_A; + std::string output_B; +}; + +namespace autoparser { + +// ============================================================================ +// High-level params for parser generation +// ============================================================================ + +struct templates_params { + json messages; + json tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + json json_schema; + bool parallel_tool_calls = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; + bool stream = true; + std::string grammar; + bool add_generation_prompt = false; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + json extra_context; + bool add_bos = false; + bool add_eos = false; + bool is_inference = true; + bool add_inference = false; + bool mark_input = true; // whether to mark input strings in the jinja context +}; + +// ============================================================================ +// Analysis Result Enums +// ============================================================================ + +// Reasoning handling mode (derived from R1-R3 comparisons) +enum class reasoning_mode { + NONE, // No reasoning markers detected + TAG_BASED, // Standard tag-based: ... + DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter) + FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end) + FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but + // with both opened and closed tag for disabled thinking + TOOLS_ONLY // Only reason on tool calls, not on normal content +}; + +inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode) { + switch (mode) { + case reasoning_mode::NONE: + return os << "NONE"; + case reasoning_mode::TAG_BASED: + return os << "TAG_BASED"; + case reasoning_mode::DELIMITER: + return os << "DELIMITER"; + case reasoning_mode::FORCED_OPEN: + return os << "FORCED_OPEN"; + case reasoning_mode::FORCED_CLOSED: + return os << "FORCED_CLOSED"; + case reasoning_mode::TOOLS_ONLY: + return os << "TOOLS_ONLY"; + default: + return os << "UNKNOWN"; + } +} + +// Content wrapping mode (derived from C1 comparison) +enum class content_mode { + PLAIN, // No content markers + ALWAYS_WRAPPED, // Content always wrapped with markers + WRAPPED_WITH_REASONING, // Content wrapped only when reasoning present +}; + +inline std::ostream & operator<<(std::ostream & os, const content_mode & mode) { + switch (mode) { + case content_mode::PLAIN: + return os << "PLAIN"; + case content_mode::ALWAYS_WRAPPED: + return os << "ALWAYS_WRAPPED"; + case content_mode::WRAPPED_WITH_REASONING: + return os << "WRAPPED_WITH_REASONING"; + default: + return os << "UNKNOWN"; + } +} + +// Call ID position in tool calls (for non-JSON formats) +enum class call_id_position { + NONE, // No call ID support detected + PRE_FUNC_NAME, // Call ID before function name: [CALL_ID]id[FUNC]name{args} + BETWEEN_FUNC_AND_ARGS, // Call ID between function and args: [FUNC]name[CALL_ID]id{args} + POST_ARGS, // Call ID after arguments: [FUNC]name{args}[CALL_ID]id +}; + +inline std::ostream & operator<<(std::ostream & os, const call_id_position & pos) { + switch (pos) { + case call_id_position::NONE: + return os << "NONE"; + case call_id_position::PRE_FUNC_NAME: + return os << "PRE_FUNC_NAME"; + case call_id_position::BETWEEN_FUNC_AND_ARGS: + return os << "BETWEEN_FUNC_AND_ARGS"; + case call_id_position::POST_ARGS: + return os << "POST_ARGS"; + default: + return os << "UNKNOWN"; + } +} + +// Tool call format classification (derived from T1-T5, A1-A3 comparisons) +enum class tool_format { + NONE, // No tool support detected + JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}} + TAG_WITH_JSON, // Tag-based with JSON args: {...} + TAG_WITH_TAGGED, // Tag-based with tagged args: value +}; + +inline std::ostream & operator<<(std::ostream & os, const tool_format & format) { + switch (format) { + case tool_format::NONE: + return os << "NONE"; + case tool_format::JSON_NATIVE: + return os << "JSON_NATIVE"; + case tool_format::TAG_WITH_JSON: + return os << "TAG_WITH_JSON"; + case tool_format::TAG_WITH_TAGGED: + return os << "TAG_WITH_TAGGED"; + default: + return os << "UNKNOWN"; + } +} + +// ============================================================================ +// Sub-structs for tool analysis +// ============================================================================ + +struct tool_format_analysis { + tool_format mode = tool_format::NONE; + + std::string section_start; // e.g., "", "[TOOL_CALLS]", "" + std::string section_end; // e.g., "", "" + std::string per_call_start; // e.g., "<|tool_call_begin|>", "" (for multi-call templates) + std::string per_call_end; // e.g., "<|tool_call_end|>", "" + + bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "": { ... arguments ... } } + bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...] + bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings) + + std::string function_field = "function"; + std::string name_field = "name"; + std::string args_field = "arguments"; + std::string id_field; + std::string gen_id_field; + std::vector parameter_order; +}; + +struct tool_function_analysis { + std::string name_prefix; // e.g., "", "\"", ":0" + std::string close; // e.g., "", "" (for tag-based) +}; + +struct tool_arguments_analysis { + std::string start; // e.g., "<|tool_call_argument_begin|>", "" + std::string end; // e.g., "<|tool_call_argument_end|>", "" + std::string name_prefix; // e.g., "", "\"" + std::string name_suffix; // e.g., ">", "
", "\":" + std::string value_prefix; // e.g., "", "", "" + std::string value_suffix; // e.g., "", "", "" + std::string separator; // e.g., "", "\n", "," +}; + +struct tool_id_analysis { + call_id_position pos = call_id_position::NONE; + + std::string prefix; // e.g., "[CALL_ID]" (marker before call ID value) + std::string suffix; // e.g., "" (marker after call ID value, before next section) +}; + +// ============================================================================ +// Parser build context (shared interface for build_parser methods) +// ============================================================================ + +struct analyze_content; + +struct parser_build_context { + common_chat_peg_builder & p; + const templates_params & inputs; + common_peg_parser reasoning_parser; + bool extracting_reasoning = false; + const analyze_content * content = nullptr; + + parser_build_context(common_chat_peg_builder & p, const templates_params & inputs); +}; + +// ============================================================================ +// Base class for analyzers with parser building +// ============================================================================ + +struct analyze_base { + virtual ~analyze_base() = default; + virtual common_peg_parser build_parser(parser_build_context & ctx) const = 0; + + protected: + const common_chat_template * tmpl = nullptr; + + analyze_base() = default; + explicit analyze_base(const common_chat_template & tmpl) : tmpl(&tmpl) {} +}; + +// ============================================================================ +// Reasoning analyzer +// ============================================================================ + +struct analyze_reasoning : analyze_base { + reasoning_mode mode = reasoning_mode::NONE; + + std::string start; // e.g., "", "[THINK]", "<|START_THINKING|>", "" + std::string end; // e.g., "", "[BEGIN FINAL RESPONSE]", "<|END_THINKING|>" + + analyze_reasoning() = default; + analyze_reasoning(const common_chat_template & tmpl, bool supports_tools); + + common_peg_parser build_parser(parser_build_context & ctx) const override; + + private: + // Look for reasoning markers in rendered content + void compare_reasoning_presence(); + + // Compare generation prompt with enable_thinking=true vs false + void compare_thinking_enabled(); + + // Check if reasoning is always possible or only in tool calls + void compare_reasoning_scope(); +}; + +// ============================================================================ +// Content analyzer +// ============================================================================ + +struct analyze_content : analyze_base { + content_mode mode = content_mode::PLAIN; + + std::string start; // e.g., "", ">>>all\n", "" + std::string end; // e.g., "", "" + + bool requires_nonnull_content = false; + + analyze_content() = default; + analyze_content(const common_chat_template & tmpl, const analyze_reasoning & reasoning); + + common_peg_parser build_parser(parser_build_context & ctx) const override; + + bool is_always_wrapped() const; + common_peg_parser build_optional_wrapped(parser_build_context & ctx) const; +}; + +// ============================================================================ +// Tool analyzer +// ============================================================================ + +struct analyze_tools : analyze_base { + tool_format_analysis format; + tool_function_analysis function; + tool_arguments_analysis arguments; + tool_id_analysis call_id; + + analyze_tools() = default; + analyze_tools(const common_chat_template & tmpl, + const jinja::caps & caps, + const analyze_reasoning & reasoning); + + common_peg_parser build_parser(parser_build_context & ctx) const override; + + private: + // Extract tool calling 'haystack' for further analysis and delegate further analysis based on format + void analyze_tool_calls(const analyze_reasoning & reasoning); + + // Analyze format based on position of function and argument name in needle + void analyze_tool_call_format(const std::string & haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle, + const analyze_reasoning & reasoning); + + // Analyze specifics of JSON native format (entire tool call is a JSON object) + void analyze_tool_call_format_json_native(const std::string & clean_haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle); + + // Analyze specifics of non-JSON native format (tags for function name or for function name and arguments) + void analyze_tool_call_format_non_json(const std::string & clean_haystack, + const std::string & fun_name_needle); + + // Check for and extract specific per-call markers for non-native-JSON templates with parallel call support + void check_per_call_markers(); + + // Extract function name markers + void extract_function_markers(); + + // Delegates to separate functions for: separator analysis, argument name analysis, argument value analysis + void analyze_arguments(); + + // Extract argument name markers + void extract_argument_name_markers(); + + // Extract argument value markers + void extract_argument_value_markers(); + + // Extract argument separator, if specified (eg. ......) + void extract_argument_separator(); + + // Extract argument wrapper markers, if present (eg. '......') + void extract_args_markers(); + + // Extract call ID markers, if present + void extract_call_id_markers(); + + // Per-format tool parser builders + common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const; + common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const; + common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const; +}; + +// ============================================================================ +// Main autoparser class +// ============================================================================ + +struct autoparser { + jinja::caps jinja_caps; + analyze_reasoning reasoning; + analyze_content content; + analyze_tools tools; + bool analysis_complete = false; + + // Preserved tokens for tokenizer (union of all non-empty markers) + std::vector preserved_tokens; + + autoparser() = default; + + // Run full differential analysis on a template + void analyze_template(const common_chat_template & tmpl); + + // Build the PEG parser for this template + common_peg_arena build_parser(const templates_params & inputs) const; + + private: + // Collect tokens from entire analysis to preserve + void collect_preserved_tokens(); +}; + +// ============================================================================ +// Parser generator +// ============================================================================ + +class peg_generator { + public: + static common_chat_params generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs); + + static common_chat_params generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs, + const autoparser & autoparser); +}; + +} // namespace autoparser + +enum segment_type { TEXT, MARKER }; + +inline std::ostream & operator<<(std::ostream & os, const segment_type & type) { + switch (type) { + case segment_type::TEXT: + return os << "TEXT"; + case segment_type::MARKER: + return os << "MARKER"; + default: + return os << "UNKNOWN"; + } +} + +struct segment { + segment_type type; + std::string value; + + segment(segment_type type, std::string value) : type(type), value(std::move(value)) {} + + bool operator==(const segment & other) const { + return type == other.type && value == other.value; + } + + bool operator!=(const segment & other) const { + return !(*this == other); + } +}; diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp new file mode 100644 index 000000000..4068340a5 --- /dev/null +++ b/common/chat-diff-analyzer.cpp @@ -0,0 +1,1330 @@ +#include "chat-auto-parser.h" +#include "chat-auto-parser-helpers.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "log.h" +#include "nlohmann/json.hpp" +#include "peg-parser.h" + +#include + +#define ANSI_RESET "\033[0m" +#define ANSI_PURPLE "\033[1m\x1b[38;5;126m" +#define ANSI_ORANGE "\033[1m\x1b[38;5;214m" +#define ANSI_RED "\033[1m\x1b[38;5;196m" + +using json = nlohmann::ordered_json; + +namespace autoparser { + +static const std::string FUN_FIRST = "FFF_FIRST_FUN_F"; +static const std::string FUN_SECOND = "SSS_SECOND_FUN_S"; +static const std::string ARG_FIRST = "AA_ARG_FST_AA"; +static const std::string ARG_SECOND = "BB_ARG_SND_BB"; +static const std::string USER_MSG = "U_USER_MSG Hello END_U"; +static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A"; +static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R"; + +static std::vector> workarounds( + { // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to + // support reasoning on them + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("content.split('')") != std::string::npos && + tmpl.src.find("reasoning_content") == std::string::npos && + analysis.reasoning.mode == reasoning_mode::NONE) { + analysis.reasoning.mode = reasoning_mode::FORCED_OPEN; + analysis.reasoning.start = ""; + analysis.reasoning.end = ""; + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: old Qwen/Deepseek thinking template]\n" ANSI_RESET); + } + }, + // Granite 3.3, with separate reasoning and content markers + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("Write your thoughts between and write your response between " + "") != std::string::npos) { + analysis.reasoning.mode = reasoning_mode::TAG_BASED; + analysis.reasoning.start = ""; + analysis.reasoning.end = ""; + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + analysis.content.mode = content_mode::WRAPPED_WITH_REASONING; + analysis.content.start = ""; + analysis.content.end = ""; + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: Granite 3.3]\n" ANSI_RESET); + } + }, + // Cohere Command R+ - content wrapped in <|CHATBOT_TOKEN|>...<|END_OF_TURN_TOKEN|> + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("<|CHATBOT_TOKEN|>") != std::string::npos && + tmpl.src.find("<|END_OF_TURN_TOKEN|>") != std::string::npos && analysis.content.start.empty()) { + analysis.content.mode = content_mode::ALWAYS_WRAPPED; + analysis.content.start = "<|CHATBOT_TOKEN|>"; + analysis.content.end = "<|END_OF_TURN_TOKEN|>"; + analysis.preserved_tokens.push_back("<|CHATBOT_TOKEN|>"); + analysis.preserved_tokens.push_back("<|END_OF_TURN_TOKEN|>"); + LOG_DBG(ANSI_ORANGE "[Patch: Cohere Command R+]\n" ANSI_RESET); + } + }, + // Functionary - no tool call section delimiter + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("set has_code_interpreter = tools | selectattr(\"type\", \"equalto\", " + "\"code_interpreter\") | list | length > 0") != std::string::npos) { + analysis.content.mode = content_mode::PLAIN; + analysis.content.end = ""; + analysis.tools.function.name_prefix = ""; + analysis.tools.format.section_start = ""; + analysis.tools.format.section_end = ""; + analysis.tools.format.per_call_start = ""); + analysis.preserved_tokens.push_back("<|eom_id|>"); + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: Functionary 3.1]\n" ANSI_RESET); + } + }, + // DeepSeek-R1-Distill-Qwen + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find( + "{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>'") != + std::string::npos) { + analysis.tools.format.section_start = "<|tool▁calls▁begin|>"; + analysis.tools.format.section_end = "<|tool▁calls▁end|>"; + analysis.tools.format.per_call_start = "<|tool▁call▁begin|>function"; + analysis.tools.function.name_prefix = "<|tool▁sep|>"; + analysis.tools.format.per_call_end = "<|tool▁call▁end|>"; + analysis.tools.function.close = "```"; + } + } + }); + +// Common JSON structures +static json params_schema = { + { "type", "object" }, + { "properties", + { { ARG_FIRST, { { "type", "string" }, { "description", "First argument" } } }, + { ARG_SECOND, { { "type", "string" }, { "description", "Second argument" } } } } }, + { "required", json::array({}) } +}; + +static json tools = json::array({ + { { "type", "function" }, + { "function", + json{ { "name", FUN_FIRST }, { "description", "Test function foo" }, { "parameters", params_schema } } } }, + { { "type", "function" }, + { "function", + json{ { "name", FUN_SECOND }, { "description", "Test function bar" }, { "parameters", params_schema } } } } +}); + +static json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } +}; + +static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") { + return json{ + { "id", id }, + { "type", "function" }, + { "function", json{ { "name", name }, { "arguments", args } } } + }; +} + +static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001"); +static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001"); +static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001"); +static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001"); + +static json first_tool_call = + build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001"); +static json second_tool_call = + build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002"); +static json first_tool_call_alt_id = + build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999"); + +template +static std::string mode_to_str(T mode) { + std::ostringstream os; + os << mode; + return os.str(); +} + +void autoparser::analyze_template(const common_chat_template & tmpl) { + jinja_caps = tmpl.original_caps(); + reasoning = analyze_reasoning(tmpl, jinja_caps.supports_tool_calls); + content = analyze_content(tmpl, reasoning); + tools = analyze_tools(jinja_caps.supports_tool_calls ? analyze_tools(tmpl, jinja_caps, reasoning) : analyze_tools()); + collect_preserved_tokens(); + + for (auto & workaround : workarounds) { + workaround(tmpl, *this); + } + + LOG_DBG("\n--- Reasoning & Content Structure ---\n"); + LOG_DBG("reasoning_mode: %s\n", mode_to_str(reasoning.mode).c_str()); + LOG_DBG("reasoning_start: '%s'\n", reasoning.start.c_str()); + LOG_DBG("reasoning_end: '%s'\n", reasoning.end.c_str()); + LOG_DBG("content_mode: %s\n", mode_to_str(content.mode).c_str()); + LOG_DBG("content_start: '%s'\n", content.start.c_str()); + LOG_DBG("content_end: '%s'\n", content.end.c_str()); + + LOG_DBG("\n--- Tool Call Structure ---\n"); + LOG_DBG("tool_mode: %s\n", mode_to_str(tools.format.mode).c_str()); + LOG_DBG("supports_tools: %s\n", jinja_caps.supports_tools ? "true" : "false"); + LOG_DBG("supports_parallel_calls: %s\n", jinja_caps.supports_parallel_tool_calls ? "true" : "false"); + LOG_DBG("tool_section_start: '%s'\n", tools.format.section_start.c_str()); + LOG_DBG("tool_section_end: '%s'\n", tools.format.section_end.c_str()); + LOG_DBG("per_call_start: '%s'\n", tools.format.per_call_start.c_str()); + LOG_DBG("per_call_end: '%s'\n", tools.format.per_call_end.c_str()); + LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str()); + LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str()); + LOG_DBG("func_close: '%s'\n", tools.function.close.c_str()); + LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false"); + LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str()); + LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str()); + LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str()); + LOG_DBG("arg_value_suffix: '%s'\n", tools.arguments.value_suffix.c_str()); + LOG_DBG("name_field: '%s'\n", tools.format.name_field.c_str()); + LOG_DBG("args_field: '%s'\n", tools.format.args_field.c_str()); + LOG_DBG("id_field: '%s'\n", tools.format.id_field.c_str()); + LOG_DBG("gen_id_field: '%s'\n", tools.format.gen_id_field.c_str()); + LOG_DBG("parameter_order: '%s'\n", std::accumulate(tools.format.parameter_order.begin(), tools.format.parameter_order.end(), + std::string(""), [] (const std::string & a, const std::string & b) { return a.empty() ? b : a + ", " + b; } + ).c_str()); + + LOG_DBG(ANSI_PURPLE "=== Differential analysis complete ===\n" ANSI_RESET); + analysis_complete = true; +} + +void autoparser::collect_preserved_tokens() { + auto add_token = [this](const std::string & org_token) { + std::string token = trim_whitespace(org_token); + if (!token.empty()) { + // Avoid duplicates + if (std::find(preserved_tokens.begin(), preserved_tokens.end(), token) == preserved_tokens.end()) { + preserved_tokens.push_back(token); + } + } + }; + + add_token(reasoning.start); + add_token(reasoning.end); + add_token(content.start); + add_token(content.end); + add_token(tools.format.section_start); + add_token(tools.format.section_end); + add_token(tools.format.per_call_start); + add_token(tools.format.per_call_end); + add_token(tools.function.name_prefix); + add_token(tools.function.name_suffix); + add_token(tools.function.close); + add_token(tools.arguments.start); + add_token(tools.arguments.end); + add_token(tools.arguments.name_prefix); + add_token(tools.arguments.name_suffix); + add_token(tools.arguments.separator); + add_token(tools.arguments.value_prefix); + add_token(tools.arguments.value_suffix); + add_token(tools.call_id.prefix); + add_token(tools.call_id.suffix); +} + +analyze_reasoning::analyze_reasoning(const common_chat_template & tmpl, bool supports_tools) + : analyze_base(tmpl) { + LOG_DBG(ANSI_PURPLE "=== Starting differential analysis ===\n" ANSI_RESET); + LOG_DBG(ANSI_ORANGE "Phase 1: Reasoning analysis\n" ANSI_RESET); + + compare_reasoning_presence(); + compare_thinking_enabled(); + if (supports_tools) { + compare_reasoning_scope(); + } +} + +void analyze_reasoning::compare_reasoning_presence() { + json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } + }; + + json assistant_no_reasoning = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json assistant_with_reasoning = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG }, + { "reasoning_content", THINKING_CONTENT } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_reasoning }); + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_with_reasoning }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed, skipping reasoning detection\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + const std::string reasoning_content = THINKING_CONTENT; + + if (!diff.right.empty() && diff.right.find(reasoning_content) != std::string::npos) { + auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest()); + }); + auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + }); + // try the more aggressive parse first, if it fails, fall back to the delimiter one + auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); + if (!result.result.success()) { + result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); + } + if (result.result.success()) { + if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { + if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close + mode = reasoning_mode::TAG_BASED; + } else { + mode = reasoning_mode::FORCED_CLOSED; + } + start = trim_whitespace(result.tags["pre"]); + end = result.tags["post"]; + } else if (!result.tags["post"].empty()) { + mode = reasoning_mode::DELIMITER; + end = result.tags["post"]; + } + } + } +} + +void analyze_reasoning::compare_thinking_enabled() { + json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } + }; + + template_params params; + params.messages = json::array({ user_msg }); + params.add_generation_prompt = true; + params.enable_thinking = false; + + auto comparison = compare_variants(*tmpl, params, [&](template_params & p) { p.enable_thinking = true; }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET , __func__); + return; + } + + const auto & diff = comparison->diff; + + std::string left_trimmed = trim_whitespace(diff.left); + + if (left_trimmed.empty() && !diff.right.empty()) { + std::string right_trimmed = trim_whitespace(diff.right); + + if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { + if (start.empty()) { + start = right_trimmed; + mode = reasoning_mode::FORCED_OPEN; + } + } + } + + if (start.empty() && !end.empty()) { + mode = reasoning_mode::DELIMITER; + } + + // Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers, + // but enable_thinking=true produces only the start marker + if (!comparison->output_A.empty() && !comparison->output_B.empty()) { + auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.literal(start) + p.space() + p.literal(end) + p.rest(); + }); + auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest(); + }); + if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() && + parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) { + mode = reasoning_mode::FORCED_CLOSED; + } else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier + auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A); + if (result.result.success()) { + start = result.tags["pre"]; + mode = reasoning_mode::FORCED_CLOSED; + } + } + } + + if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close" + if (!diff.left.empty() && !diff.right.empty()) { + auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left)); + auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right)); + if (seg_A.size() == 1 && seg_B.size() == 1) { + mode = reasoning_mode::FORCED_CLOSED; + start = seg_B[0].value; + end = seg_A[0].value; + } + } + } +} + +void analyze_reasoning::compare_reasoning_scope() { + json assistant_reasoning_content = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG }, + { "reasoning_content", THINKING_CONTENT } + }; + + json assistant_reasoning_tools = json{ + { "role", "assistant" }, + { "content", nullptr }, + { "reasoning_content", THINKING_CONTENT }, + { "tool_calls", + json::array({ build_tool_call(FUN_FIRST, json{ { ARG_FIRST, "VVVV" }, { ARG_SECOND, "XXXX" } }) }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_reasoning_content }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_reasoning_tools }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + std::string reasoning_content = THINKING_CONTENT; + + // Check if reasoning only appears in variant B (with tools) + bool reasoning_in_A = comparison->output_A.find(reasoning_content) != std::string::npos; + bool reasoning_in_B = comparison->output_B.find(reasoning_content) != std::string::npos; + + if (!reasoning_in_A && reasoning_in_B) { + mode = reasoning_mode::TOOLS_ONLY; + LOG_DBG(ANSI_ORANGE "%s: Detected TOOLS_ONLY reasoning mode\n" ANSI_RESET, __func__); + + auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())); + }); + auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); + if (result.result.success()) { + start = result.tags["pre"]; + end = result.tags["post"]; + } else { + auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space()))); + }); + result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); + if (result.result.success()) { + end = result.tags["post"]; + } else { + LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); + mode = reasoning_mode::NONE; + } + } + } +} + +analyze_content::analyze_content(const common_chat_template & tmpl, const analyze_reasoning & reasoning) + : analyze_base(tmpl) { + LOG_DBG(ANSI_ORANGE "Phase 2: Content analysis\n" ANSI_RESET); + + json assistant_content_only = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json assistant_with_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ build_tool_call("test_func", json{ { "arg1", "value1" } }) }) } + }; + + json assistant_with_reasoning = json{ + { "role", "assistant" }, + { "content", "" }, + { "reasoning_content", THINKING_CONTENT } + }; + + template_params params_content_only; + params_content_only.messages = json::array({ user_msg, assistant_content_only }); + params_content_only.add_generation_prompt = false; + params_content_only.enable_thinking = true; + params_content_only.tools = tools; + + auto comparison_with_tools = compare_variants(tmpl, params_content_only, [&](template_params & p) { + p.messages = json::array({ user_msg, assistant_with_tools }); + }); + + auto comparison_with_reasoning = compare_variants(tmpl, params_content_only, [&](template_params & p) { + p.messages = json::array({ user_msg, assistant_with_reasoning }); + }); + + if (!comparison_with_tools || !comparison_with_reasoning) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + } + + const auto & diff_tools = comparison_with_tools->diff; + const auto & diff_reasoning = comparison_with_reasoning->diff; + + std::string response = ASSISTANT_MSG; + + bool found_plain_content = false; + if (trim_whitespace(diff_tools.left) == response) { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.space() + diff_reasoning.left + p.space() + p.optional(p.marker()) + p.space() + p.end(); + }); + if (parser.parse_and_extract(diff_reasoning.left).result.success()) { + // We only have the content text in the diff (possibly with a stray EOG marker), so no markers + mode = content_mode::PLAIN; + found_plain_content = true; + } else if (reasoning.mode != reasoning_mode::NONE && !reasoning.end.empty()) { + auto post_reasoning_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.literal(reasoning.end) + p.space() + p.literal(response); + }); + if (post_reasoning_parser.parse_anywhere_and_extract(diff_reasoning.left).result.success()) { + mode = content_mode::PLAIN; + found_plain_content = true; + } + } + } + if (!found_plain_content) { + std::string rdiff = diff_reasoning.left; + if (!reasoning.end.empty() && rdiff.find(reasoning.end) != std::string::npos) { + rdiff = rdiff.substr(rdiff.find(reasoning.end) + reasoning.end.length()); + } + // Take the more promising diff + std::string pure_content = rdiff.length() > diff_tools.left.length() ? rdiff : diff_tools.left; + auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.marker()) + p.space() + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + }); + auto result = parser_wrapped.parse_anywhere_and_extract(pure_content); + start = result.tags["pre"]; + end = result.tags["post"]; + // TODO: WRAPPED_WITH_REASONING + } + + // Determine content mode + if (!start.empty() || !end.empty()) { + mode = content_mode::ALWAYS_WRAPPED; + // TODO: END_DELIMITED content mode - delimited at end but not at start? + } +} + +bool analyze_content::is_always_wrapped() const { + return mode == content_mode::ALWAYS_WRAPPED && !start.empty() && !end.empty(); +} + +analyze_tools::analyze_tools(const common_chat_template & tmpl, + const jinja::caps & caps, + const analyze_reasoning & reasoning) + : analyze_base(tmpl) { + LOG_DBG(ANSI_ORANGE "Phase 3: Tool call analysis\n" ANSI_RESET); + + analyze_tool_calls(reasoning); + + if (format.mode != tool_format::NONE && format.mode != tool_format::JSON_NATIVE) { + if (caps.supports_parallel_tool_calls) { + check_per_call_markers(); + } + extract_function_markers(); + if (format.mode == tool_format::TAG_WITH_TAGGED) { + analyze_arguments(); + } + extract_argument_separator(); + extract_args_markers(); + extract_call_id_markers(); + } +} + +void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) { + json assistant_no_tools = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json assistant_with_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_tools }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_with_tools }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + std::string tool_section = diff.right; + + if (tool_section.empty()) { + return; + } + + analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning); +} + +void analyze_tools::analyze_tool_call_format(const std::string & haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle, + const analyze_reasoning & reasoning) { + if (fun_name_needle.empty() || arg_name_needle.empty() || haystack.empty()) { + return; + } + + enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES }; + + auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({ + p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")), + p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) }); + }); + auto result = parser.parse_anywhere_and_extract(haystack); + if (!result.result.success()) { + return json_quote_style::NONE; + } + return result.tags.count("sq") && !result.tags["sq"].empty() + ? json_quote_style::SINGLE_QUOTES + : json_quote_style::DOUBLE_QUOTES; + }; + + auto fun_quote = in_json_haystack(fun_name_needle); + auto arg_quote = in_json_haystack(arg_name_needle); + + if (fun_quote != json_quote_style::NONE) { + // no need to check further, we're in JSON land + format.mode = tool_format::JSON_NATIVE; + format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES); + } else if (arg_quote != json_quote_style::NONE) { + format.mode = tool_format::TAG_WITH_JSON; + format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES); + } else { + format.mode = tool_format::TAG_WITH_TAGGED; + } + + // first, remove any reasoning markers + std::string clean_haystack = haystack; + if (!reasoning.start.empty()) { + auto pos = haystack.find(reasoning.start); + if (pos != std::string::npos) { + clean_haystack = haystack.substr(0, pos) + haystack.substr(pos + reasoning.start.length()); + } + } + if (!reasoning.end.empty()) { + auto pos = clean_haystack.find(reasoning.end); + if (pos != std::string::npos) { + clean_haystack = clean_haystack.substr(0, pos) + clean_haystack.substr(pos + reasoning.end.length()); + } + } + + if (format.mode == tool_format::JSON_NATIVE) { + analyze_tool_call_format_json_native(clean_haystack, fun_name_needle, arg_name_needle); + } else { + analyze_tool_call_format_non_json(clean_haystack, fun_name_needle); + } + // always relax whitespace requirements on ending markers since they don't influence content + format.section_end = trim_whitespace(format.section_end); + format.per_call_end = trim_whitespace(format.per_call_end); +} + +void analyze_tools::analyze_tool_call_format_json_native(const std::string & clean_haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle) { + // we might not have the typical OpenAI tool calling structure + int json_start = clean_haystack.find_first_of('{'); + int json_end = clean_haystack.find_last_of('}'); + std::string cut = clean_haystack.substr(json_start, json_end - json_start + 1); + json call_struct = json::parse(cut); + auto register_field = [&](const std::string & prefix, const nlohmann::detail::iteration_proxy_value & subel) { + if (subel.value().is_string() && std::string(subel.value()).find("call0000") != std::string::npos) { + format.id_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } else if (subel.value().is_string() && std::string(subel.value()) == fun_name_needle) { + format.name_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } else if (subel.value().dump().find(arg_name_needle) != + std::string::npos) { // handle both string and JSON obj variants + format.args_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } else if (subel.key().find("id") != std::string::npos) { + // heuristics for generated id field + format.gen_id_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } + }; + for (const auto & el : call_struct.items()) { + if (el.key() == fun_name_needle) { + format.fun_name_is_key = true; + // When function name is the key, there's no name field and args are direct + format.name_field.clear(); + format.args_field.clear(); + // Don't register this element - the function name IS the key, not a field + } else { + if (el.value().is_object() && + el.value().dump().find(arg_name_needle) == std::string::npos) { // not the args object + format.function_field = el.key(); + for (const auto & subel : el.value().items()) { + register_field(el.key(), subel); + } + } + // Register this element as a potential field + register_field("", el); + } + } + auto array_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.literal("[") + p.space()) + p.literal(cut) + p.tag("post", p.space() + p.literal("]")); + }); + + auto ar_parse_res = array_parser.parse_anywhere_and_extract(clean_haystack); + if (ar_parse_res.result.success()) { + format.tools_array_wrapped = true; + json_start -= ar_parse_res.tags["pre"].length(); + json_end += ar_parse_res.tags["post"].length(); + } + json_end++; // we want to move past the closing char for end marker extraction + + std::vector> located_params; + if (!format.name_field.empty()) { + located_params.push_back({ clean_haystack.find(format.name_field), format.name_field }); + } + if (!format.args_field.empty()) { + located_params.push_back({ clean_haystack.find(format.args_field), format.args_field }); + } + if (!format.id_field.empty()) { + located_params.push_back({ clean_haystack.find(format.id_field), format.id_field }); + } + if (!format.gen_id_field.empty()) { + located_params.push_back({ clean_haystack.find(format.gen_id_field), format.gen_id_field }); + } + std::sort(located_params.begin(), located_params.end()); + for (auto & pair : located_params) { + format.parameter_order.push_back(pair.second); + } + // we can immediately extract tool calling markers too + format.section_start = trim_leading_whitespace(clean_haystack.substr(0, json_start)); + format.section_end = trim_whitespace(clean_haystack.substr(json_end)); + // When tools_array_wrapped is true, the closing bracket is part of the array structure, + // not a separate section end marker. Clear tool_section_end to avoid duplicate brackets. + if (format.tools_array_wrapped && format.section_end == "]") { + format.section_end.clear(); + } +} + +void analyze_tools::analyze_tool_call_format_non_json(const std::string & clean_haystack, + const std::string & fun_name_needle) { + // first, let's find out if the function is inside a tag or standalone + auto fun_marker_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("fun_marker", p.choice({ + p.tag("fun_pre", p.literal("<") + p.until_one_of({ ">", fun_name_needle })) + p.literal(fun_name_needle) + + p.tag("fun_post", p.negate(p.space() + p.literal("<")) + p.until(">") + p.literal(">")) + p.space(), + p.tag("fun_pre", p.literal("[") + p.until_one_of({ "]", fun_name_needle })) + p.literal(fun_name_needle) + + p.tag("fun_post", p.negate(p.space() + p.literal("[") + p.until("]") + p.literal("]")) + p.space()) })); + }); + auto fun_res = fun_marker_parser.parse_anywhere_and_extract(clean_haystack); + std::string fun_marker = fun_name_needle; + if (fun_res.result.success()) { + fun_marker = fun_res.tags["fun_marker"]; + } + // now, consume up to two markers, then treat everything up to the function marker as function name prefix + auto per_tool_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("sec_start", p.marker() + p.space()) + p.tag("call_start", p.marker() + p.space()) + + p.tag("fun_pre", p.until(fun_marker)) + fun_marker + p.tag("rest", p.rest()); + }); + auto section_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("sec_start", p.marker() + p.space()) + fun_marker + p.tag("rest", p.rest()); + }); + auto result = per_tool_parser.parse_anywhere_and_extract(clean_haystack); + tagged_parse_result result_end; + if (result.result.success()) { + auto double_closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("call_end", p.marker() + p.space()) + p.tag("sec_end", p.marker() + p.space()) + p.end(); + }); + result_end = double_closer_parser.parse_anywhere_and_extract(result.tags["rest"]); + function.name_prefix = fun_res.tags["fun_pre"] + function.name_prefix; + } else { + result = section_parser.parse_anywhere_and_extract(clean_haystack); + auto single_closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("sec_end", p.marker() + p.space()) + p.end(); + }); + result_end = single_closer_parser.parse_anywhere_and_extract(result.tags["rest"]); + } + format.per_call_start = result.tags["call_start"]; + format.per_call_end = result_end.tags["call_end"]; + format.section_start = result.tags["sec_start"]; + format.section_end = result_end.tags["sec_end"]; +} + +void analyze_tools::check_per_call_markers() { + json assistant_one_tool = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_two_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call, second_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_one_tool }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto one_vs_two = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_tools }); }); + + if (!one_vs_two) { + LOG_DBG(ANSI_ORANGE "%s: Generating double tool call comparison failed\n" ANSI_RESET, __func__); + return; + } + + diff_split filter_common_call_part = calculate_diff_split(one_vs_two->diff.suffix, one_vs_two->diff.right); + + std::string second_tool_content = trim_leading_whitespace(filter_common_call_part.right); + if (!format.section_start.empty() && + second_tool_content.find(format.section_start) == 0) { + format.per_call_start = format.section_start; + format.per_call_end = format.section_end; + format.section_start.clear(); + format.section_end.clear(); + } +} + +void analyze_tools::extract_function_markers() { + json assistant_nocall = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG }, + }; + + json assistant_foofoo = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_barbar = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ second_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_foofoo }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_barbar }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (diff.left.find(FUN_FIRST) != std::string::npos && diff.right.find(FUN_SECOND) != std::string::npos) { + std::string prefix_marker; + if (!format.per_call_start.empty()) { + prefix_marker = format.per_call_start; + } else { + prefix_marker = format.section_start; + } + if (!prefix_marker.empty() && diff.prefix.rfind(prefix_marker) != std::string::npos) { + function.name_prefix = + diff.prefix.substr(diff.prefix.rfind(prefix_marker) + prefix_marker.size()); + } + + // Extract name prefix/suffix from diff.left (stop at the next marker boundary) + auto name_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.until(FUN_FIRST)) + p.literal(FUN_FIRST) + + p.tag("post", p.zero_or_more(p.negate(p.marker()) + p.any())); + }); + auto name_result = name_parser.parse_and_extract(diff.left); + if (name_result.result.success()) { + function.name_prefix += name_result.tags["pre"]; + function.name_suffix = name_result.tags["post"]; + } + + // Extend name_suffix with content from diff.suffix before args begin + if (format.mode == tool_format::TAG_WITH_JSON) { + // For JSON: name_suffix extends to the first non-marker { or [, including any + // markers along the way. Only applies if there's at least one marker after + // the JSON content (matching the original "stop < seg_suf.size() - 1" guard). + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + auto non_json = p.marker() | (p.negate(p.literal("{")) + p.negate(p.literal("[")) + p.any()); + auto after_json = p.zero_or_more(p.negate(p.marker()) + p.any()) + p.marker(); + return p.tag("ext", p.zero_or_more(non_json)) + after_json; + }); + auto suf_result = suffix_parser.parse_and_extract(diff.suffix); + if (suf_result.result.success()) { + function.name_suffix += suf_result.tags["ext"]; + } + } else { + // For tagged: name_suffix extends to the first marker (arg marker) + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("ext", p.zero_or_more(p.negate(p.marker()) + p.any())); + }); + auto suf_result = suffix_parser.parse_and_extract(diff.suffix); + if (suf_result.result.success()) { + function.name_suffix += suf_result.tags["ext"]; + } + } + + // Extract the closer (between last arg and call/section end marker) + std::string suffix_marker; + if (!format.per_call_end.empty()) { + suffix_marker = format.per_call_end; + } else { + suffix_marker = format.section_end; + } + std::string closer_suffix; + if (suffix_marker.empty()) { + // we'll have to rely on an extra diff with no-calls version + auto notool_comp = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_nocall }); }); + auto nt_diff = notool_comp->diff; + closer_suffix = nt_diff.left.substr(nt_diff.left.find("YYYY") + 4); + } else { + closer_suffix = diff.suffix.substr(0, diff.suffix.find(suffix_marker)); + } + if (!closer_suffix.empty()) { + if (format.mode == tool_format::TAG_WITH_TAGGED) { + // After last arg value, skip the closing arg marker, rest is closer + auto closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.until("YYYY") + p.literal("YYYY") + p.space() + + p.marker() + p.space() + + p.tag("close", p.rest()); + }); + auto close_result = closer_parser.parse_and_extract(closer_suffix); + if (close_result.result.success()) { + function.close = close_result.tags["close"]; + } + } else if (format.mode == tool_format::TAG_WITH_JSON) { + // After last arg value, find end of JSON args, rest is closer + auto closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.until("YYYY") + p.literal("YYYY") + p.tag("post_val", p.rest()); + }); + auto close_result = closer_parser.parse_and_extract(closer_suffix); + if (close_result.result.success()) { + const auto & post = close_result.tags["post_val"]; + size_t pos = post.find_last_of("}]"); + if (pos != std::string::npos && pos < post.size() - 1) { + function.close = trim_leading_whitespace(post.substr(pos + 1)); + } + } + } + } + function.close = trim_leading_whitespace(function.close); + } +} + +void analyze_tools::analyze_arguments() { + LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET); + + extract_argument_name_markers(); + extract_argument_value_markers(); +} + +void analyze_tools::extract_argument_name_markers() { + json assistant_first_arg = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + json assistant_second_arg = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_other_arg }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_first_arg }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_second_arg }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (!diff.left.empty() && !diff.right.empty()) { + // Parse both sides to find ARG_FIRST/ARG_SECOND and extract the surrounding structure + auto left_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("pre", p.until(ARG_FIRST)) + p.literal(ARG_FIRST) + + p.tag("suffix", p.until_one_of({"\"", "X"})); + }); + auto right_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("pre", p.until(ARG_SECOND)) + p.literal(ARG_SECOND) + + p.tag("suffix", p.until_one_of({"\"", "Y"})); + }); + auto left_result = left_parser.parse_anywhere_and_extract(diff.left); + auto right_result = right_parser.parse_anywhere_and_extract(diff.right); + + if (left_result.result.success() && right_result.result.success() && + !left_result.tags["pre"].empty() && + left_result.tags["pre"] == right_result.tags["pre"] && + left_result.tags["suffix"] == right_result.tags["suffix"]) { + // Name is inside a structure (e.g., JSON key): prefix is the shared wrapper + arguments.name_prefix = trim_whitespace(left_result.tags["pre"]); + arguments.name_suffix = trim_leading_whitespace(left_result.tags["suffix"]); + } else if (diff.left.substr(0, ARG_FIRST.length()) == ARG_FIRST && diff.right.substr(0, ARG_SECOND.length()) == ARG_SECOND) { + // Name is directly in the diff: prefix comes from last marker in diff.prefix + auto pre_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + auto last_marker = p.marker() + p.zero_or_more(p.negate(p.marker()) + p.any()) + p.end(); + return p.zero_or_more(p.negate(last_marker) + p.any()) + p.tag("name_prefix", last_marker); + }); + auto pre_result = pre_parser.parse_and_extract(diff.prefix); + arguments.name_prefix = pre_result.result.success() + ? pre_result.tags["name_prefix"] : diff.prefix; + + // Suffix extends from after ARG_FIRST to the first marker (+ optional whitespace). + // The marker could be in diff.left itself or in diff.suffix, so we concatenate. + std::string after_first = diff.left.substr(ARG_FIRST.length()) + diff.suffix; + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("suffix", p.zero_or_more(p.negate(p.marker()) + p.any()) + + p.marker() + p.space()); + }); + auto suf_result = suffix_parser.parse_anywhere_and_extract(after_first); + if (suf_result.result.success()) { + arguments.name_suffix = suf_result.tags["suffix"]; + } + } + } +} + +void analyze_tools::extract_argument_value_markers() { + json assistant_val_X = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + json assistant_val_Y = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg_other_val }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_val_X }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_val_Y }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (diff.left == "XXXX" && diff.right == "YYYY") { + std::string arg_name_ending = ARG_FIRST + arguments.name_suffix; + std::string prefix = diff.prefix; + if (prefix.rfind(arg_name_ending) != std::string::npos) { + prefix = prefix.substr(prefix.rfind(arg_name_ending) + arg_name_ending.size()); + } + if (!prefix.empty()) { + // Find the last marker + any trailing non-marker text to end + auto prefix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + auto last_marker = p.marker() + p.zero_or_more(p.negate(p.marker()) + p.any()) + p.end(); + return p.zero_or_more(p.negate(last_marker) + p.any()) + p.tag("val_prefix", last_marker); + }); + auto pre_result = prefix_parser.parse_and_extract(prefix); + arguments.value_prefix = pre_result.result.success() ? pre_result.tags["val_prefix"] : prefix; + } + + std::string value_suffix = diff.suffix; + if (!function.close.empty()) { + size_t func_close_pos = value_suffix.find(function.close); + if (func_close_pos != std::string::npos) { + value_suffix = value_suffix.substr(0, func_close_pos); + } + } else if (!format.per_call_end.empty() || !format.section_end.empty()) { + std::string end_marker = + !format.per_call_end.empty() ? format.per_call_end : format.section_end; + size_t end_marker_pos = value_suffix.find(end_marker); + if (end_marker_pos != std::string::npos) { + value_suffix = value_suffix.substr(0, end_marker_pos); + } + } + value_suffix = trim_leading_whitespace(value_suffix); + if (!value_suffix.empty()) { + arguments.value_suffix = value_suffix; + } + } +} + +void analyze_tools::extract_argument_separator() { + json assistant_one_arg = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + json assistant_two_args = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_one_arg }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_args }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (!diff.right.empty()) { + std::string separator = until_common_prefix(diff.right, ARG_FIRST, ARG_SECOND); + arguments.separator = separator; + } +} + +void analyze_tools::extract_args_markers() { + json assistant_no_args = json{ + { "role", "assistant"}, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_zero_args }) } + }; + + json assistant_with_args = json{ + { "role", "assistant"}, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_args }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_with_args }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (format.mode != tool_format::JSON_NATIVE) { + std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; + std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end; + // these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones + size_t prefix_pos = prefix_marker.empty() ? 0 : diff.prefix.rfind(prefix_marker); + size_t suffix_pos = suffix_marker.empty() ? diff.suffix.size() : diff.suffix.find(suffix_marker); + if (prefix_pos == std::string::npos) { + prefix_pos = 0; + } + if (suffix_pos == std::string::npos) { + suffix_pos = diff.suffix.size(); + } + std::string prefix_cut = diff.prefix.substr(prefix_pos + prefix_marker.size()); + std::string suffix_cut = diff.suffix.substr(0, suffix_pos); + std::string args_start = until_common_prefix(prefix_cut, "{}", "{\"first\":"); + std::string args_end = after_common_suffix(suffix_cut, "{}", "\"XXXX\"}"); + + if (!args_start.empty() || !args_end.empty()) { + size_t find_fun = args_start.find(FUN_FIRST); + if (find_fun != std::string::npos) { + args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size()); + } + arguments.start = args_start; + arguments.end = args_end; + } + } +} + +void analyze_tools::extract_call_id_markers() { + json assistant_id1 = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_id2 = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_alt_id }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_id1 }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_id2 }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed for call_id detection\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (diff.left.empty() && diff.right.empty()) { + return; + } + + std::string id_value_1 = "call00001"; + std::string id_value_2 = "call99999"; + + size_t common_id_prefix_len = 0; + for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) { + if (id_value_1[i] == id_value_2[i]) { + common_id_prefix_len++; + } else { + break; + } + } + std::string common_id_part = id_value_1.substr(0, common_id_prefix_len); + + // Check if the function name is in the prefix (normal case: BETWEEN_FUNC_AND_ARGS or POST_ARGS) + // or in the suffix (call_id is PRE_FUNC_NAME) + std::string func_name = FUN_FIRST; + size_t func_name_in_prefix = diff.prefix.rfind(func_name); + size_t func_name_in_suffix = diff.suffix.find(func_name); + + // Helper: find the last marker in a string (returns just the marker, not trailing text) + auto find_last_marker = [](const std::string & str) -> std::string { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + auto last = p.marker() + p.zero_or_more(p.negate(p.marker()) + p.any()) + p.end(); + return p.zero_or_more(p.negate(last) + p.any()) + p.tag("m", p.marker()); + }); + auto res = parser.parse_anywhere_and_extract(str); + return res.result.success() ? res.tags["m"] : ""; + }; + + // Helper: find the first marker in a string + auto find_first_marker = [](const std::string & str) -> std::string { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("m", p.marker()); + }); + auto res = parser.parse_anywhere_and_extract(str); + return res.result.success() ? res.tags["m"] : ""; + }; + + if (func_name_in_prefix != std::string::npos && func_name_in_suffix == std::string::npos) { + // Function name is only in prefix - call_id is BETWEEN_FUNC_AND_ARGS or POST_ARGS + // Check if args indicator "{" is in prefix or suffix + size_t args_in_prefix = diff.prefix.find('{', func_name_in_prefix); + size_t args_in_suffix = diff.suffix.find('{'); + + if (args_in_suffix != std::string::npos && + (args_in_prefix == std::string::npos || args_in_prefix > diff.prefix.length())) { + // Args are in suffix, so call_id is BETWEEN_FUNC_AND_ARGS + call_id.pos = call_id_position::BETWEEN_FUNC_AND_ARGS; + + // Find call_id_prefix: marker immediately preceding common_id_part (no intervening markers) + std::string after_func = diff.prefix.substr(func_name_in_prefix + func_name.length()); + auto id_prefix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("prefix", p.marker()) + + p.zero_or_more(p.negate(p.marker()) + p.negate(p.literal(common_id_part)) + p.any()) + + p.literal(common_id_part); + }); + auto id_res = id_prefix_parser.parse_anywhere_and_extract(after_func); + if (id_res.result.success()) { + call_id.prefix = id_res.tags["prefix"]; + } else { + // Fallback: use the last marker in after_func + call_id.prefix = find_last_marker(after_func); + } + + // Extract call_id_suffix: the first marker in the suffix before args "{" + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.zero_or_more(p.negate(p.marker()) + p.negate(p.literal("{")) + p.any()) + + p.tag("suffix", p.marker()); + }); + auto suf_res = suffix_parser.parse_anywhere_and_extract(diff.suffix); + if (suf_res.result.success()) { + call_id.suffix = suf_res.tags["suffix"]; + } + } else if (args_in_prefix != std::string::npos) { + // Args are in prefix, so call_id is POST_ARGS + call_id.pos = call_id_position::POST_ARGS; + + // Extract last marker between args closing brace and the ID + std::string after_args = diff.prefix.substr(args_in_prefix); + size_t closing_brace = after_args.rfind('}'); + if (closing_brace != std::string::npos) { + std::string between_args_and_id = after_args.substr(closing_brace + 1); + call_id.prefix = find_last_marker(between_args_and_id); + } + + // call_id_suffix: first marker in diff.suffix + call_id.suffix = find_first_marker(diff.suffix); + } + } else if (func_name_in_suffix != std::string::npos && func_name_in_prefix == std::string::npos) { + // Function name is only in suffix - call_id is PRE_FUNC_NAME + call_id.pos = call_id_position::PRE_FUNC_NAME; + + // call_id_prefix: last marker in diff.prefix + call_id.prefix = find_last_marker(diff.prefix); + + // call_id_suffix: first marker in the portion of diff.suffix before func_name + std::string before_func = diff.suffix.substr(0, func_name_in_suffix); + call_id.suffix = find_first_marker(before_func); + } + + // When call_id is detected, per_call_end may have been incorrectly set to include + // the call_id_suffix and sample args. Clear it if it starts with call_id_suffix. + if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() && + format.per_call_end.find(call_id.suffix) == 0) { + format.per_call_end.clear(); + } +} + +} // namespace autoparser diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp deleted file mode 100644 index ba359fdbf..000000000 --- a/common/chat-parser-xml-toolcall.cpp +++ /dev/null @@ -1,879 +0,0 @@ -#include "chat.h" -#include "chat-parser.h" -#include "common.h" -#include "json-partial.h" -#include "json-schema-to-grammar.h" -#include "log.h" -#include "regex-partial.h" - -using json = nlohmann::ordered_json; - -class xml_toolcall_syntax_exception : public std::runtime_error { - public: - xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {} -}; - -template -inline void sort_uniq(std::vector &vec) { - std::sort(vec.begin(), vec.end()); - vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); -} - -template -inline bool all_space(const T &str) { - return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); }); -} - -static size_t utf8_truncate_safe(const std::string_view s) { - size_t len = s.size(); - if (len == 0) return 0; - size_t i = len; - for (size_t back = 0; back < 4 && i > 0; ++back) { - --i; - unsigned char c = s[i]; - if ((c & 0x80) == 0) { - return len; - } else if ((c & 0xC0) == 0xC0) { - size_t expected_len = 0; - if ((c & 0xE0) == 0xC0) expected_len = 2; - else if ((c & 0xF0) == 0xE0) expected_len = 3; - else if ((c & 0xF8) == 0xF0) expected_len = 4; - else return i; - if (len - i >= expected_len) { - return len; - } else { - return i; - } - } - } - return len - std::min(len, size_t(3)); -} - -inline void utf8_truncate_safe_resize(std::string &s) { - s.resize(utf8_truncate_safe(s)); -} - -inline std::string_view utf8_truncate_safe_view(const std::string_view s) { - return s.substr(0, utf8_truncate_safe(s)); -} - -static std::optional try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) { - if (literal1.size() == 0) return builder.try_find_literal(literal2); - const auto saved_pos = builder.pos(); - while (auto res = builder.try_find_literal(literal1)) { - builder.consume_spaces(); - const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos()); - if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) { - if (res->prelude.size() != res->groups[0].begin - saved_pos) { - res->prelude = builder.str({saved_pos, res->groups[0].begin}); - } - builder.move_to(builder.pos() + match_len); - res->groups[0].end = builder.pos(); - GGML_ASSERT(res->groups[0].begin != res->groups[0].end); - return res; - } - builder.move_to(res->groups[0].begin + 1); - } - builder.move_to(saved_pos); - return std::nullopt; -} - -/** - * make a GBNF that accept any strings except those containing any of the forbidden strings. - */ -std::string make_gbnf_excluding(std::vector forbids) { - constexpr auto charclass_escape = [](unsigned char c) -> std::string { - if (c == '\\' || c == ']' || c == '^' || c == '-') { - std::string s = "\\"; - s.push_back((char)c); - return s; - } - if (isprint(c)) { - return std::string(1, (char)c); - } - char buf[16]; - snprintf(buf, 15, "\\x%02X", c); - return std::string(buf); - }; - constexpr auto build_expr = [charclass_escape](auto self, const std::vector& forbids, int l, int r, int depth) -> std::string { - std::vector>> children; - int i = l; - while (i < r) { - const std::string &s = forbids[i]; - if ((int)s.size() == depth) { - ++i; - continue; - } - unsigned char c = (unsigned char)s[depth]; - int j = i; - while (j < r && (int)forbids[j].size() > depth && - (unsigned char)forbids[j][depth] == c) { - ++j; - } - children.push_back({c, {i, j}}); - i = j; - } - std::vector alts; - if (!children.empty()) { - std::string cls; - for (auto &ch : children) cls += charclass_escape(ch.first); - alts.push_back(std::string("[^") + cls + "]"); - } - for (auto &ch : children) { - std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1); - if (!childExpr.empty()) { - std::string quoted_ch = "\""; - if (ch.first == '\\') quoted_ch += "\\\\"; - else if (ch.first == '"') quoted_ch += "\\\""; - else if (isprint(ch.first)) quoted_ch.push_back(ch.first); - else { - char buf[16]; - snprintf(buf, 15, "\\x%02X", ch.first); - quoted_ch += buf; - } - quoted_ch += "\""; - std::string branch = quoted_ch + std::string(" ") + childExpr; - alts.push_back(branch); - } - } - if (alts.empty()) return ""; - std::ostringstream oss; - oss << "( "; - for (size_t k = 0; k < alts.size(); ++k) { - if (k) oss << " | "; - oss << alts[k]; - } - oss << " )"; - return oss.str(); - }; - if (forbids.empty()) return "( . )*"; - sort(forbids.begin(), forbids.end()); - std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0); - if (expr.empty()) { - std::string cls; - for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]); - expr = std::string("( [^") + cls + "] )"; - } - if (forbids.size() == 1) - return expr + "*"; - else - return std::string("( ") + expr + " )*"; -} - -/** - * Build grammar for xml-style tool call - * form.scope_start and form.scope_end can be empty. - * Requires data.format for model-specific hacks. - */ -void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) { - GGML_ASSERT(!form.tool_start.empty()); - GGML_ASSERT(!form.tool_sep.empty()); - GGML_ASSERT(!form.key_start.empty()); - GGML_ASSERT(!form.val_end.empty()); - GGML_ASSERT(!form.tool_end.empty()); - - std::string key_val_sep = form.key_val_sep; - if (form.key_val_sep2) { - key_val_sep += "\n"; - key_val_sep += *form.key_val_sep2; - } - GGML_ASSERT(!key_val_sep.empty()); - - if (tools.is_array() && !tools.empty()) { - data.grammar = build_grammar([&](const common_grammar_builder &builder) { - auto string_arg_val = form.last_val_end ? - builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) : - builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end})); - - std::vector tool_rules; - for (const auto & tool : tools) { - if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { - LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str()); - continue; - } - const auto & function = tool.at("function"); - if (!function.contains("name") || !function.at("name").is_string()) { - LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str()); - continue; - } - if (!function.contains("parameters") || !function.at("parameters").is_object()) { - LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str()); - continue; - } - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - struct parameter_rule { - std::string symbol_name; - bool is_required; - }; - std::vector arg_rules; - if (!parameters.contains("properties") || !parameters.at("properties").is_object()) { - LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str()); - continue; - } else { - std::vector requiredParameters; - if (parameters.contains("required")) { - try { parameters.at("required").get_to(requiredParameters); } - catch (const std::runtime_error&) { - LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str()); - } - } - sort_uniq(requiredParameters); - for (const auto & [key, value] : parameters.at("properties").items()) { - std::string quoted_key = key; - bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key); - if (form.key_start.back() == '"' && key_val_sep[0] == '"') { - quoted_key = gbnf_format_literal(key); - quoted_key = quoted_key.substr(1, quoted_key.size() - 2); - } - arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key, - gbnf_format_literal(form.key_start) + " " + - gbnf_format_literal(quoted_key) + " " + - gbnf_format_literal(key_val_sep) + " " + - ((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ? - (form.raw_argval ? - string_arg_val : - "( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )" - ) : - builder.add_schema(name + "-arg-" + key, value) - ) - ), required}); - } - } - - auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end)); - decltype(next_arg_with_sep) next_arg = "\"\""; - for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) { - std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep; - next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ? - include_this_arg : "( " + include_this_arg + " ) | " + next_arg - ); - include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg; - next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ? - include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep - ); - } - - std::string quoted_name = name; - if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') { - quoted_name = gbnf_format_literal(name); - quoted_name = quoted_name.substr(1, quoted_name.size() - 2); - } - quoted_name = gbnf_format_literal(quoted_name); - // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name - if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) { - quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+"; - } - tool_rules.push_back(builder.add_rule(name + "-call", - gbnf_format_literal(form.tool_start) + " " + - quoted_name + " " + - gbnf_format_literal(form.tool_sep) + " " + - next_arg - )); - } - - auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | ")); - auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once); - auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end)); - auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end); - builder.add_rule("root", - (form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") + - tool_call_multiple_with_end + "?" + - (form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end)) - ); - }); - - // grammar trigger for tool call - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start }); - } -} - -/** - * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. - * Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser. - * form.scope_start, form.tool_sep and form.scope_end can be empty. - */ -inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) { - GGML_ASSERT(!form.tool_start.empty()); - GGML_ASSERT(!form.key_start.empty()); - GGML_ASSERT(!form.key_val_sep.empty()); - GGML_ASSERT(!form.val_end.empty()); - GGML_ASSERT(!form.tool_end.empty()); - - // Helper to choose return false or throw error - constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) { - LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str()); - if (recovery) { - builder.move_to(start_pos); - return false; - } else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output."); - }; - // Drop substring from needle to end from a JSON - constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") { - auto pos = json_str.rfind(needle); - if (pos == std::string::npos) { - return false; - } - for (auto i = pos + needle.size(); i < json_str.size(); ++i) { - unsigned char ch = static_cast(json_str[i]); - if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) { - return false; - } - } - if (pos != 0 && json_str[pos - 1] == '"') { - --pos; - } - json_str.resize(pos); - return true; - }; - // Helper to generate a partial argument JSON - constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) { - auto rest = builder.consume_rest(); - utf8_truncate_safe_resize(rest); - set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG"); - auto tool_str = arguments.dump(); - if (partial_json(tool_str)) { - if (builder.add_tool_call(function_name, "", tool_str)) { - return; - } - } - LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str()); - }; - // Helper to find a close (because there may be form.last_val_end or form.last_tool_end) - constexpr auto try_find_close = []( - common_chat_msg_parser & builder, - const std::string & end, - const std::optional & alt_end, - const std::string & end_next, - const std::optional & alt_end_next - ) { - auto saved_pos = builder.pos(); - auto tc = builder.try_find_literal(end); - auto val_end_size = end.size(); - if (alt_end) { - auto pos_1 = builder.pos(); - builder.move_to(saved_pos); - auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next); - if (alt_end_next) { - builder.move_to(saved_pos); - auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next); - if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) { - tc2 = tc3; - } - } - if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) { - tc = tc2; - tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size()); - builder.move_to(tc->groups[0].end); - val_end_size = alt_end->size(); - } else { - builder.move_to(pos_1); - } - } - return std::make_pair(val_end_size, tc); - }; - // Helper to find a val_end or last_val_end, returns matched pattern size - const auto try_find_val_end = [try_find_close, &builder, &form]() { - return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end); - }; - // Helper to find a tool_end or last_tool_end, returns matched pattern size - const auto try_find_tool_end = [try_find_close, &builder, &form]() { - return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt); - }; - - bool recovery = true; - const auto start_pos = builder.pos(); - if (!all_space(form.scope_start)) { - if (auto tc = builder.try_find_literal(form.scope_start)) { - if (all_space(tc->prelude)) { - if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin) - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start)); - } else { - builder.move_to(start_pos); - return false; - } - } else return false; - } - while (auto tc = builder.try_find_literal(form.tool_start)) { - if (!all_space(tc->prelude)) { - LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n", - gbnf_format_literal(form.tool_start).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - builder.move_to(tc->groups[0].begin - tc->prelude.size()); - break; - } - - // Find tool name - auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep); - if (!func_name) { - auto [sz, tc] = try_find_tool_end(); - func_name = tc; - } - if (!func_name) { - // Partial tool name not supported - throw common_chat_msg_partial_exception("incomplete tool_call"); - } - // If the model generate multiple tool call and the first tool call has no argument - if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) { - builder.move_to(func_name->groups[0].begin - func_name->prelude.size()); - auto [sz, tc] = try_find_tool_end(); - func_name = tc; - } - - // Parse tool name - builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end); - std::string function_name = string_strip(func_name->prelude); - // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name - if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) { - if (string_starts_with(function_name, "functions.")) { - static const std::regex re(":\\d+$"); - if (std::regex_search(function_name, re)) { - function_name = function_name.substr(10, function_name.rfind(":") - 10); - } - } - } - - // Argument JSON - json arguments = json::object(); - - // Helper to generate a partial argument JSON - const auto gen_partial_args = [&](auto set_partial_arg) { - gen_partial_json(set_partial_arg, arguments, builder, function_name); - }; - - // Parse all arg_key/arg_value pairs - while (auto tc = builder.try_find_literal(form.key_start)) { - if (!all_space(tc->prelude)) { - LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n", - gbnf_format_literal(form.key_start).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - builder.move_to(tc->groups[0].begin - tc->prelude.size()); - break; - } - if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) { - auto tool_call_arg = arguments.dump(); - if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') { - tool_call_arg.resize(tool_call_arg.size() - 1); - } - builder.add_tool_call(function_name, "", tool_call_arg); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start)); - } - - // Parse arg_key - auto key_res = builder.try_find_literal(form.key_val_sep); - if (!key_res) { - gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";}); - throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start)); - } - if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) { - gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";}); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep)); - } - auto &key = key_res->prelude; - recovery = false; - - // Parse arg_value - if (form.key_val_sep2) { - if (auto tc = builder.try_find_literal(*form.key_val_sep2)) { - if (!all_space(tc->prelude)) { - LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n", - gbnf_format_literal(tc->prelude).c_str(), - gbnf_format_literal(form.key_val_sep).c_str(), - gbnf_format_literal(*form.key_val_sep2).c_str() - ); - return return_error(builder, start_pos, false); - } - if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2)); - } - } else { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep)); - } - } - auto val_start = builder.pos(); - - // Test if arg_val is a partial JSON - std::optional value_json = std::nullopt; - if (!form.raw_argval || !*form.raw_argval) { - try { value_json = builder.try_consume_json(); } - catch (const std::runtime_error&) { builder.move_to(val_start); } - // TODO: Delete this when json_partial adds top-level support for null/true/false - if (builder.pos() == val_start) { - const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)"); - builder.consume_spaces(); - std::string_view sv = utf8_truncate_safe_view(builder.input()); - sv.remove_prefix(builder.pos()); - std::string rest = "a"; - if (sv.size() < 6) rest = sv; - if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) { - value_json = {123, {"123", "123"}}; - builder.consume_rest(); - } else { - builder.move_to(val_start); - } - } - } - - // If it is a JSON and followed by , parse as json - // cannot support streaming because it may be a plain text starting with JSON - if (value_json) { - auto json_end = builder.pos(); - builder.consume_spaces(); - if (builder.pos() == builder.input().size()) { - if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) { - arguments[key] = value_json->json; - auto json_str = arguments.dump(); - if (!value_json->healing_marker.json_dump_marker.empty()) { - GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker)); - json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker)); - } else { - GGML_ASSERT(json_str.back() == '}'); - json_str.resize(json_str.size() - 1); - } - builder.add_tool_call(function_name, "", json_str); - } else { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - } - LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str()); - throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations."); - } - builder.move_to(json_end); - auto [val_end_size, tc] = try_find_val_end(); - if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) { - if (tc->groups[0].end - tc->groups[0].begin != val_end_size) { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str()); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : "")); - } else arguments[key] = value_json->json; - } else builder.move_to(val_start); - } - - // If not, parse as plain text - if (val_start == builder.pos()) { - if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) { - auto &value_str = value_plain->prelude; - if (form.trim_raw_argval) value_str = string_strip(value_str); - if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;}); - throw common_chat_msg_partial_exception( - "Expected " + gbnf_format_literal(form.val_end) + - " after " + gbnf_format_literal(form.key_val_sep) + - (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "") - ); - } - arguments[key] = value_str; - } else { - if (form.trim_raw_argval) { - gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;}); - } else { - gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;}); - } - throw common_chat_msg_partial_exception( - "Expected " + gbnf_format_literal(form.val_end) + - " after " + gbnf_format_literal(form.key_val_sep) + - (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "") - ); - } - } - } - - // Consume closing tag - if (auto [tool_end_size, tc] = try_find_tool_end(); tc) { - if (!all_space(tc->prelude)) { - LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", - gbnf_format_literal(form.tool_end).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - return return_error(builder, start_pos, recovery); - } - if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) { - // Add the parsed tool call - if (!builder.add_tool_call(function_name, "", arguments.dump())) { - throw common_chat_msg_partial_exception("Failed to add XML-Style tool call"); - } - recovery = false; - continue; - } - } - - auto tool_call_arg = arguments.dump(); - if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') { - tool_call_arg.resize(tool_call_arg.size() - 1); - } - builder.add_tool_call(function_name, "", tool_call_arg); - throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end)); - } - if (auto tc = builder.try_find_literal(form.scope_end)) { - if (!all_space(tc->prelude)) { - LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", - gbnf_format_literal(form.scope_end).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - return return_error(builder, start_pos, recovery); - } - } else { - if (all_space(form.scope_end)) return true; - builder.consume_spaces(); - if (builder.pos() == builder.input().size()) - throw common_chat_msg_partial_exception("incomplete tool calls"); - LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", - gbnf_format_literal(form.scope_end).c_str(), - gbnf_format_literal(builder.consume_rest()).c_str() - ); - return return_error(builder, start_pos, recovery); - } - - return true; -} - -/** - * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. - * May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client. - * form.scope_start, form.tool_sep and form.scope_end can be empty. - */ -bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) { - auto pos = pos_; - auto tsize = result_.tool_calls.size(); - try { return parse_xml_tool_calls(*this, form); } - catch (const xml_toolcall_syntax_exception&) {} - move_to(pos); - result_.tool_calls.resize(tsize); - return false; -} - -/** - * Parse content uses reasoning and XML-Style tool call - * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed. - */ -inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "", const std::string & end_think = "") { - constexpr auto rstrip = [](std::string &s) { - s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base())); - }; - // Erase substring from l to r, along with additional spaces nearby - constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) { - while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast(str[l]))); - ++l; - while (++r < str.size() && std::isspace(static_cast(str[r]))); - if (l < r) str[l] = '\n'; - if (l + 1 < r) str[l + 1] = '\n'; - if (l != 0) l += 2; - str.erase(l, r - l); - return l; - }; - constexpr auto trim_suffix = [](std::string &content, std::initializer_list list) { - auto best_match = content.size(); - for (auto pattern: list) { - if (pattern.size() == 0) continue; - for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) { - auto match_len = content.size() - match_idx; - if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) { - best_match = match_idx; - } - } - } - if (content.size() > best_match) { - content.erase(best_match); - } - }; - const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) { - return trim_suffix(content, { - start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start, - form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "", - form.val_end, form.last_val_end ? form.last_val_end->c_str() : "", - form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "", - form.scope_end - }); - }; - - - // Trim leading spaces without affecting keyword matching - static const common_regex spaces_regex("\\s*"); - { - auto tc = builder.consume_regex(spaces_regex); - auto spaces = builder.str(tc.groups[0]); - auto s1 = spaces.size(); - trim_potential_partial_word(spaces); - auto s2 = spaces.size(); - builder.move_to(builder.pos() - (s1 - s2)); - } - - // Parse content - bool reasoning_unclosed = builder.syntax().thinking_forced_open; - std::string unclosed_reasoning_content(""); - for (;;) { - auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start); - std::string content; - std::string tool_call_start; - - if (tc) { - content = std::move(tc->prelude); - tool_call_start = builder.str(tc->groups[0]); - LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str()); - } else { - content = builder.consume_rest(); - utf8_truncate_safe_resize(content); - } - - // Handle unclosed think block - if (reasoning_unclosed) { - if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) { - unclosed_reasoning_content += content; - if (!(form.allow_toolcall_in_think && tc)) { - unclosed_reasoning_content += tool_call_start; - continue; - } - } else { - reasoning_unclosed = false; - std::string reasoning_content; - if (pos == std::string::npos) { - reasoning_content = std::move(content); - } else { - reasoning_content = content.substr(0, pos); - content.erase(0, pos + end_think.size()); - } - if (builder.pos() == builder.input().size() && all_space(content)) { - rstrip(reasoning_content); - trim_potential_partial_word(reasoning_content); - rstrip(reasoning_content); - if (reasoning_content.empty()) { - rstrip(unclosed_reasoning_content); - trim_potential_partial_word(unclosed_reasoning_content); - rstrip(unclosed_reasoning_content); - if (unclosed_reasoning_content.empty()) continue; - } - } - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(start_think); - builder.add_content(unclosed_reasoning_content); - builder.add_content(reasoning_content); - if (builder.pos() != builder.input().size() || !all_space(content)) - builder.add_content(end_think); - } else { - builder.add_reasoning_content(unclosed_reasoning_content); - builder.add_reasoning_content(reasoning_content); - } - unclosed_reasoning_content.clear(); - } - } - - // Handle multiple think block - bool toolcall_in_think = false; - for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) { - if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) { - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { - auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size()); - builder.add_reasoning_content(reasoning_content); - think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1); - } else { - think_start = think_end + end_think.size() - 1; - } - } else { - // This start is in thinking block, skip this tool call - // This start is in thinking block - if (form.allow_toolcall_in_think) { - unclosed_reasoning_content = content.substr(think_start + start_think.size()); - } else { - unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start; - } - reasoning_unclosed = true; - content.resize(think_start); - toolcall_in_think = true; - } - } - - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { - rstrip(content); - // Handle unclosed token from content: delete all token - if (auto pos = content.rfind(end_think); pos != std::string::npos) { - while (pos != std::string::npos) { - pos = erase_spaces(content, pos, pos + end_think.size() - 1); - pos = content.rfind(end_think, pos); - } - } - // Strip if needed - if (content.size() > 0 && std::isspace(static_cast(content[0]))) { - content = string_strip(content); - } - } - - // remove potential partial suffix - if (builder.pos() == builder.input().size() && builder.is_partial()) { - if (unclosed_reasoning_content.empty()) { - rstrip(content); - trim_potential_partial_word(content); - rstrip(content); - } else { - rstrip(unclosed_reasoning_content); - trim_potential_partial_word(unclosed_reasoning_content); - rstrip(unclosed_reasoning_content); - } - } - - // consume unclosed_reasoning_content if allow_toolcall_in_think is set - if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) { - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { - builder.add_reasoning_content(unclosed_reasoning_content); - } else { - if (content.empty()) { - content = start_think + unclosed_reasoning_content; - } else { - content += "\n\n" + start_think; - content += unclosed_reasoning_content; - } - } - unclosed_reasoning_content.clear(); - } - - // Add content - if (!content.empty()) { - // If there are multiple content blocks - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) { - builder.add_content("\n\n"); - } - builder.add_content(content); - } - - // This start is in thinking block and toolcall_in_think not set, skip this tool call - if (toolcall_in_think && !form.allow_toolcall_in_think) { - continue; - } - - // There is no tool call and all content is parsed - if (!tc) { - GGML_ASSERT(builder.pos() == builder.input().size()); - GGML_ASSERT(unclosed_reasoning_content.empty()); - if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed); - break; - } - - builder.move_to(tc->groups[0].begin); - if (builder.try_consume_xml_tool_calls(form)) { - auto end_of_tool = builder.pos(); - builder.consume_spaces(); - if (builder.pos() != builder.input().size()) { - builder.move_to(end_of_tool); - if (!builder.result().content.empty()) { - builder.add_content("\n\n"); - } - } - } else { - static const common_regex next_char_regex("."); - auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]); - rstrip(c); - builder.add_content(c); - } - } -} - -/** - * Parse content uses reasoning and XML-Style tool call - */ -void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) { - parse_msg_with_xml_tool_calls(*this, form, start_think, end_think); -} diff --git a/common/chat-parser-xml-toolcall.h b/common/chat-parser-xml-toolcall.h deleted file mode 100644 index b309fb667..000000000 --- a/common/chat-parser-xml-toolcall.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include "chat.h" - -#include - -#include -#include -#include - - -// Sample config: -// MiniMax-M2 (left): \n\nvalue\n...\n... -// GLM 4.5 (right): function_name\nkey\nvalue\n -struct xml_tool_call_format { - std::string scope_start; // \n // \n // can be empty - std::string tool_start; // - std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls - std::string key_start; // - std::string key_val_sep; // \"> // \n - std::string val_end; // \n // \n - std::string tool_end; // \n // \n - std::string scope_end; // // // can be empty - // Set this if there can be dynamic spaces inside key_val_sep. - // e.g. key_val_sep= key_val_sep2= for GLM4.5 - std::optional key_val_sep2 = std::nullopt; - // Set true if argval should only be raw string. e.g. Hello "world" hi - // Set false if argval should only be json string. e.g. "Hello \"world\" hi" - // Defaults to std::nullopt, both will be allowed. - std::optional raw_argval = std::nullopt; - std::optional last_val_end = std::nullopt; - std::optional last_tool_end = std::nullopt; - bool trim_raw_argval = false; - bool allow_toolcall_in_think = false; -}; - -// make a GBNF that accept any strings except those containing any of the forbidden strings. -std::string make_gbnf_excluding(std::vector forbids); - -/** - * Build grammar for xml-style tool call - * form.scope_start and form.scope_end can be empty. - * Requires data.format for model-specific hacks. - */ -void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form); diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp deleted file mode 100644 index 060578f0b..000000000 --- a/common/chat-parser.cpp +++ /dev/null @@ -1,1649 +0,0 @@ -#include "chat-parser.h" -#include "chat-peg-parser.h" -#include "common.h" -#include "log.h" -#include "peg-parser.h" -#include "regex-partial.h" - -#include -#include -#include -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, - const common_regex & prefix, - size_t rstrip_prefix = 0) { - static const std::vector> args_paths = { { "arguments" } }; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json{ - { "code", code + builder.healing_marker() } - }) - .dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json{ - { "code", code } - }) - .dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = - nullptr) { - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : - function_regex ? builder.try_find_regex(*function_regex, from) : - std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. - from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) - : input_(input), is_partial_(is_partial), syntax_(syntax) -{ - result_.role = "assistant"; - - while (true) { - std::string id = std::to_string(std::rand()); - if (input.find(id) == std::string::npos) { - healing_marker_ = id; - break; - } - } -} - -std::string common_chat_msg_parser::str(const common_string_range & rng) const { - GGML_ASSERT(rng.begin <= rng.end); - return input_.substr(rng.begin, rng.end - rng.begin); -} - -void common_chat_msg_parser::add_content(const std::string &content) { - result_.content += content; -} - -void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) { - result_.reasoning_content += reasoning_content; -} - -bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) { - if (name.empty()) { - return false; - } - - common_chat_tool_call tool_call; - tool_call.name = name; - tool_call.arguments = arguments; - tool_call.id = id; - - // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str()); - result_.tool_calls.emplace_back(tool_call); - - return true; -} -bool common_chat_msg_parser::add_tool_call(const json & tool_call) { - std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; - std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; - std::string arguments = ""; - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else { - arguments = tool_call.at("arguments"); - } - } - - return add_tool_call(name, id, arguments); -} - -bool common_chat_msg_parser::add_tool_calls(const json & arr) { - for (const auto & item : arr) { - if (!add_tool_call(item)) { - return false; - } - } - return true; -} - -bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) { - if (!tool_call.is_object() || tool_call.size() != 1) { - return false; - } - - // Get the tool name (the single key in the object) - auto it = tool_call.begin(); - std::string name = it.key(); - - if (name.empty()) { - return false; - } - - // Get the arguments (the nested object) - const json & args_json = it.value(); - std::string arguments = ""; - - if (args_json.is_object()) { - arguments = args_json.dump(); - } else if (args_json.is_string()) { - arguments = args_json; - } else if (!args_json.is_null()) { - // For other types, convert to string representation - arguments = args_json.dump(); - } - - return add_tool_call(name, "", arguments); -} -void common_chat_msg_parser::finish() { - if (!is_partial_ && pos_ != input_.size()) { - throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); - } -} - -bool common_chat_msg_parser::consume_spaces() { - const auto length = input_.size(); - auto consumed = false; - while (pos_ < length && std::isspace(input_[pos_])) { - ++pos_; - consumed = true; - } - return consumed; -} - -bool common_chat_msg_parser::try_consume_literal(const std::string & literal) { - auto pos = pos_; - for (auto i = 0u; i < literal.size(); ++i) { - if (pos >= input_.size()) { - return false; - } - if (input_[pos] != literal[i]) { - return false; - } - ++pos; - } - pos_ = pos; - return true; -} - -std::optional common_chat_msg_parser::try_find_literal(const std::string & literal) { - auto idx = input_.find(literal, pos_); - if (idx != std::string::npos) { - find_regex_result res; - res.prelude = input_.substr(pos_, idx - pos_); - auto end = idx + literal.size(); - res.groups.emplace_back(common_string_range{idx, end}); - move_to(end); - return res; - } - if (is_partial_) { - idx = string_find_partial_stop(input_, literal); - if (idx != std::string::npos && idx >= pos_) { - find_regex_result res; - res.prelude = input_.substr(pos_, idx - pos_); - auto end = input_.size(); - res.groups.emplace_back(common_string_range{idx, end}); - move_to(end); - return res; - } - } - return std::nullopt; -} - -void common_chat_msg_parser::consume_literal(const std::string & literal) { - if (!try_consume_literal(literal)) { - throw common_chat_msg_partial_exception(literal); - } -} - -bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) { - std::string pending_reasoning_prefix; - - if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) { - return false; - } - - auto set_reasoning_prefix = [&](size_t prefix_pos) { - if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) { - return; - } - if (prefix_pos + start_think.size() > input_.size()) { - pending_reasoning_prefix.clear(); - return; - } - // Capture the exact literal that opened the reasoning section so we can - // surface it back to callers. This ensures formats that force the - // reasoning tag open (e.g. DeepSeek R1) retain their original prefix - // instead of dropping it during parsing. - pending_reasoning_prefix = input_.substr(prefix_pos, start_think.size()); - }; - - auto handle_reasoning = [&](const std::string & reasoning, bool closed) { - auto stripped_reasoning = string_strip(reasoning); - if (stripped_reasoning.empty()) { - return; - } - if (syntax_.reasoning_in_content) { - add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : start_think); - add_content(stripped_reasoning); - if (closed) { - add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think); - } - } else { - if (!pending_reasoning_prefix.empty()) { - add_reasoning_content(pending_reasoning_prefix); - pending_reasoning_prefix.clear(); - } - add_reasoning_content(stripped_reasoning); - } - }; - - const size_t saved_pos = pos_; - const size_t saved_content_size = result_.content.size(); - const size_t saved_reasoning_size = result_.reasoning_content.size(); - - auto restore_state = [&]() { - move_to(saved_pos); - result_.content.resize(saved_content_size); - result_.reasoning_content.resize(saved_reasoning_size); - }; - - // Allow leading whitespace to be preserved as content when reasoning is present at the start - size_t cursor = pos_; - size_t whitespace_end = cursor; - while (whitespace_end < input_.size() && std::isspace(static_cast(input_[whitespace_end]))) { - ++whitespace_end; - } - - if (whitespace_end >= input_.size()) { - restore_state(); - if (syntax_.thinking_forced_open) { - auto rest = input_.substr(saved_pos); - if (!rest.empty()) { - handle_reasoning(rest, /* closed */ !is_partial()); - } - move_to(input_.size()); - return true; - } - return false; - } - - cursor = whitespace_end; - const size_t remaining = input_.size() - cursor; - const size_t start_prefix = std::min(start_think.size(), remaining); - const bool has_start_tag = input_.compare(cursor, start_prefix, start_think, 0, start_prefix) == 0; - - if (has_start_tag && start_prefix < start_think.size()) { - move_to(input_.size()); - return true; - } - - if (has_start_tag) { - if (whitespace_end > pos_) { - add_content(input_.substr(pos_, whitespace_end - pos_)); - } - set_reasoning_prefix(cursor); - cursor += start_think.size(); - } else if (syntax_.thinking_forced_open) { - cursor = whitespace_end; - } else { - restore_state(); - return false; - } - while (true) { - if (cursor >= input_.size()) { - move_to(input_.size()); - return true; - } - - size_t end_pos = input_.find(end_think, cursor); - if (end_pos == std::string::npos) { - std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor); - size_t partial_off = string_find_partial_stop(remaining_view, end_think); - size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off; - if (reasoning_end > cursor) { - handle_reasoning(input_.substr(cursor, reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial()); - } - move_to(input_.size()); - return true; - } - - if (end_pos > cursor) { - handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true); - } else { - handle_reasoning("", /* closed */ true); - } - - cursor = end_pos + end_think.size(); - - while (cursor < input_.size() && std::isspace(static_cast(input_[cursor]))) { - ++cursor; - } - - const size_t next_remaining = input_.size() - cursor; - if (next_remaining == 0) { - move_to(cursor); - return true; - } - - const size_t next_prefix = std::min(start_think.size(), next_remaining); - if (input_.compare(cursor, next_prefix, start_think, 0, next_prefix) == 0) { - if (next_prefix < start_think.size()) { - move_to(input_.size()); - return true; - } - set_reasoning_prefix(cursor); - cursor += start_think.size(); - continue; - } - - move_to(cursor); - return true; - } -} - -std::string common_chat_msg_parser::consume_rest() { - auto rest = input_.substr(pos_); - pos_ = input_.size(); - return rest; -} - -// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback. -std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { - auto m = regex.search(input_, from == std::string::npos ? pos_ : from); - if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { - return std::nullopt; - } - auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); - pos_ = m.groups[0].end; - - if (add_prelude_to_content) { - add_content(prelude); - } - if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { - if (is_partial()) { - throw common_chat_msg_partial_exception(regex.str()); - } - return std::nullopt; - } - return find_regex_result{prelude, m.groups}; -} - -common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) { - if (auto result = try_consume_regex(regex)) { - return *result; - } - throw common_chat_msg_partial_exception(regex.str()); -} - -std::optional common_chat_msg_parser::try_consume_regex(const common_regex & regex) { - auto m = regex.search(input_, pos_); - if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { - return std::nullopt; - } - if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { - if (is_partial()) { - throw common_chat_msg_partial_exception(regex.str()); - } - return std::nullopt; - } - if (m.groups[0].begin != pos_) { - // Didn't match at the current position. - return std::nullopt; - } - pos_ = m.groups[0].end; - - return find_regex_result { - /* .prelude = */ "", - m.groups, - }; -} - -std::optional common_chat_msg_parser::try_consume_json() { - auto it = input_.cbegin() + pos_; - const auto end = input_.cend(); - common_json result; - if (!common_json_parse(it, end, healing_marker_, result)) { - return std::nullopt; - } - pos_ = std::distance(input_.cbegin(), it); - if (result.healing_marker.marker.empty()) { - // No healing marker, just return the parsed json - return result; - } - if (!is_partial()) { - throw common_chat_msg_partial_exception("JSON"); - } - return result; -} - -common_json common_chat_msg_parser::consume_json() { - if (auto result = try_consume_json()) { - return *result; - } - throw common_chat_msg_partial_exception("JSON"); -} - -common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args( - const std::vector> & args_paths, - const std::vector> & content_paths -) { - if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) { - return *result; - } - throw common_chat_msg_partial_exception("JSON"); -} - -std::optional common_chat_msg_parser::try_consume_json_with_dumped_args( - const std::vector> & args_paths, - const std::vector> & content_paths -) { - auto partial = try_consume_json(); - if (!partial) { - return std::nullopt; - } - auto is_arguments_path = [&](const std::vector & path) { - return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end(); - }; - auto is_content_path = [&](const std::vector & path) { - return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end(); - }; - - if (partial->healing_marker.marker.empty()) { - if (args_paths.empty()) { - // No arguments to dump, and JSON was parsed fully. - return consume_json_result { - partial->json, - /* .is_partial = */ false, - }; - } - if (is_arguments_path({})) { - // Entire JSON is the arguments and was parsed fully. - return consume_json_result { - partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true), - /* .is_partial = */ false, - }; - } - } - - LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); - - auto found_healing_marker = false; - std::vector path; - std::function remove_unsupported_healings_and_dump_args = [&](const json & j) -> json { - if (is_arguments_path(path)) { - auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true); - if (is_partial() && !partial->healing_marker.marker.empty()) { - auto idx = arguments.find(partial->healing_marker.json_dump_marker); - if (idx != std::string::npos) { - arguments.resize(idx); - found_healing_marker = true; - } - if (arguments == "\"") { - // This happens because of completing `:"$magic` after `"arguments"` - arguments = ""; - } - } - return arguments; - } - if (is_content_path(path)) { - if (!j.is_string()) { - throw std::runtime_error("Content path must be a string"); - } - std::string str = j; - auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string - if (idx != std::string::npos) { - str.resize(idx); - found_healing_marker = true; - } - return str; - } - if (j.is_object()) { - auto obj = json::object(); - for (const auto & p : j.items()) { - const auto & key = p.key(); - const auto & value = p.value(); - const std::string key_str = key; // NOLINT - auto idx = key_str.find(healing_marker_); - if (idx != std::string::npos) { - found_healing_marker = true; - break; - } - path.push_back(key_str); - if (value.is_string()) { - const std::string value_str = value; - if (value_str.find(healing_marker_) != std::string::npos) { - found_healing_marker = true; - if (is_content_path(path)) { - if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) { - // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair. - obj[key] = remove_unsupported_healings_and_dump_args(value); - } - } - break; - } - obj[key] = value; - } else { - obj[key] = remove_unsupported_healings_and_dump_args(value); - } - path.pop_back(); - } - return obj; - } - if (j.is_array()) { - auto arr = json::array(); - for (const auto & value : j) { - if (value.is_string()) { - std::string str = value; - auto idx = str.find(healing_marker_); - if (idx != std::string::npos) { - // Don't heal array values that aren't in the arguments. - found_healing_marker = true; - break; - } - } - arr.push_back(remove_unsupported_healings_and_dump_args(value)); - } - return arr; - } - return j; - }; - - auto cleaned = remove_unsupported_healings_and_dump_args(partial->json); - LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); - return consume_json_result { - cleaned, - /* .is_partial = */ found_healing_marker, - }; -} - -void common_chat_msg_parser::clear_tools() { - result_.tool_calls.clear(); -} - -/** - * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below - * to reduce incremental compile time for parser changes. - */ -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} - -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_magistral(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("[THINK]", "[/THINK]"); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} - -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "<|tool_calls_section_begin|>"; - form.tool_start = "<|tool_call_begin|>"; - form.tool_sep = "<|tool_call_argument_begin|>{"; - form.key_start = "\""; - form.key_val_sep = "\":"; - form.val_end = ","; - form.tool_end = "}<|tool_call_end|>"; - form.scope_end = "<|tool_calls_section_end|>"; - form.raw_argval = false; - form.last_val_end = ""; - form.allow_toolcall_in_think = true; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} - -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} - -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} - -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} - -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - builder.add_tool_calls(tool_calls_data.json); - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_apertus(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - builder.consume_spaces(); - if (!builder.try_consume_literal("<|tools_suffix|>")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - for (const auto & value : tool_calls_data.json) { - if (value.is_object()) { - builder.add_tool_call_short_form(value); - } - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - - -static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> - static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); - static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); - - // Loop through all tool calls - while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(res->groups[0].end); - - // Parse JSON array format: [{"name": "...", "arguments": {...}}] - auto tool_calls_data = builder.consume_json(); - - // Consume end marker - builder.consume_spaces(); - if (!builder.try_consume_regex(tool_call_end_regex)) { - throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); - } - - // Process each tool call in the array - if (tool_calls_data.json.is_array()) { - for (const auto & tool_call : tool_calls_data.json) { - if (!tool_call.is_object()) { - throw common_chat_msg_partial_exception("Tool call must be an object"); - } - - if (!tool_call.contains("name")) { - throw common_chat_msg_partial_exception("Tool call missing 'name' field"); - } - - std::string function_name = tool_call.at("name"); - std::string arguments = "{}"; - - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else if (tool_call.at("arguments").is_string()) { - arguments = tool_call.at("arguments"); - } - } - - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - } else { - throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); - } - - // Consume any trailing whitespace after this tool call - builder.consume_spaces(); - } - - // Consume any remaining content after all tool calls - auto remaining = builder.consume_rest(); - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } -} - -static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_solar_open(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>"); - - // TODO: Tool calling - - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_exaone_moe_content(common_chat_msg_parser & builder) { - // 1) { "name": "...", "arguments": {...} } - // 2) { "id": "...", "type": "function", "function": { "name": "...", "arguments": {...} } } - static const common_regex tool_call_open(R"(]*>)"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - // Find all blocks - while (auto first = builder.try_find_regex(tool_call_open, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(first->groups[0].end); - builder.consume_spaces(); - - builder.try_consume_literal("```json"); - builder.try_consume_literal("```"); - builder.consume_spaces(); - - // Consume JSON object - auto data = builder.consume_json(); - - builder.consume_spaces(); - builder.try_consume_literal("```"); - builder.consume_spaces(); - - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - - // Extract name and arguments - std::string name; - std::string id; - nlohmann::ordered_json arguments; - - const auto extract_args = [&](const nlohmann::ordered_json & obj) -> bool { - if (!obj.contains("name") || !obj.contains("arguments")) { - return false; - } - name = obj.at("name").get(); - arguments = obj.at("arguments"); - if (obj.contains("id") && obj.at("id").is_string()) { - id = obj.at("id").get(); - } - return true; - }; - - if (!extract_args(data.json)) { - if (data.json.contains("function") && data.json.at("function").is_object()) { - auto fn = data.json.at("function"); - extract_args(fn); - if (id.empty() && data.json.contains("id") && data.json.at("id").is_string()) { - id = data.json.at("id").get(); - } - } - } - - // If name is empty, treat the JSON object as content - if (name.empty()) { - LOG_DBG("%s: tool call missing name, treating as content\n", __func__); - builder.add_content(data.json.dump()); - continue; - } - - std::string args_str = arguments.dump(); - if (!builder.add_tool_call(name, id, args_str)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_exaone_moe(common_chat_msg_parser & builder) { - LOG_DBG("%s: parsing exaone_moe\n", __func__); - // EXAONE MoE outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - } - } -} - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_MAGISTRAL: - common_chat_parse_magistral(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_SEED_OSS: - common_chat_parse_seed_oss(builder); - break; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: - common_chat_parse_nemotron_v2(builder); - break; - case COMMON_CHAT_FORMAT_APERTUS: - common_chat_parse_apertus(builder); - break; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: - common_chat_parse_lfm2(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - case COMMON_CHAT_FORMAT_SOLAR_OPEN: - common_chat_parse_solar_open(builder); - break; - case COMMON_CHAT_FORMAT_EXAONE_MOE: - common_chat_parse_exaone_moe(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) { - if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE || - syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE || - syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { - return common_chat_peg_parse(syntax.parser, input, is_partial, syntax); - } - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG_DBG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } - return msg; -} - -common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) { - if (parser.empty()) { - throw std::runtime_error("Failed to parse due to missing parser definition."); - } - - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str()); - - common_peg_parse_context ctx(input, is_partial); - auto result = parser.parse(ctx); - if (result.fail()) { - throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end)); - } - - common_chat_msg msg; - msg.role = "assistant"; - - if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) { - auto mapper = common_chat_peg_native_mapper(msg); - mapper.from_ast(ctx.ast, result); - } else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { - auto mapper = common_chat_peg_constructed_mapper(msg); - mapper.from_ast(ctx.ast, result); - } else { - // Generic mapper - auto mapper = common_chat_peg_mapper(msg); - mapper.from_ast(ctx.ast, result); - } - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } - return msg; -} diff --git a/common/chat-parser.h b/common/chat-parser.h deleted file mode 100644 index 3ed9c30a2..000000000 --- a/common/chat-parser.h +++ /dev/null @@ -1,133 +0,0 @@ -#pragma once - -#include "chat.h" -#include "chat-parser-xml-toolcall.h" -#include "json-partial.h" -#include "regex-partial.h" - -#include - -#include -#include -#include - -class common_chat_msg_partial_exception : public std::runtime_error { - public: - common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {} -}; - -class common_chat_msg_parser { - std::string input_; - bool is_partial_; - common_chat_parser_params syntax_; // TODO: rename to params - std::string healing_marker_; - - size_t pos_ = 0; - common_chat_msg result_; - - public: - common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax); - const std::string & input() const { return input_; } - size_t pos() const { return pos_; } - const std::string & healing_marker() const { return healing_marker_; } - const bool & is_partial() const { return is_partial_; } - const common_chat_msg & result() const { return result_; } - const common_chat_parser_params & syntax() const { return syntax_; } - - void move_to(size_t pos) { - if (pos > input_.size()) { - throw std::runtime_error("Invalid position!"); - } - pos_ = pos; - } - void move_back(size_t n) { - if (pos_ < n) { - throw std::runtime_error("Can't move back that far!"); - } - pos_ -= n; - } - - // Get the substring of the input at the given range - std::string str(const common_string_range & rng) const; - - // Appends to the result.content field - void add_content(const std::string & content); - - // Appends to the result.reasoning_content field - void add_reasoning_content(const std::string & reasoning_content); - - // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything. - bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); - - // Adds a tool call using the "name", "id" and "arguments" fields of the json object - bool add_tool_call(const nlohmann::ordered_json & tool_call); - - // Adds an array of tool calls using their "name", "id" and "arguments" fields. - bool add_tool_calls(const nlohmann::ordered_json & arr); - - // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } } - bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call); - - void finish(); - - bool consume_spaces(); - - void consume_literal(const std::string & literal); - - bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); - - std::string consume_rest(); - - struct find_regex_result { - std::string prelude; - std::vector groups; - }; - - std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); - - bool try_consume_literal(const std::string & literal); - - std::optional try_find_literal(const std::string & literal); - - find_regex_result consume_regex(const common_regex & regex); - - std::optional try_consume_regex(const common_regex & regex); - - std::optional try_consume_json(); - common_json consume_json(); - - struct consume_json_result { - nlohmann::ordered_json value; - bool is_partial; - }; - - /* - Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings. - - By default, object keys can't be truncated, nor can string values (their corresponding key is removed, - e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}` - - But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings - - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}` - - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}` - */ - consume_json_result consume_json_with_dumped_args( - const std::vector> & args_paths = {}, - const std::vector> & content_paths = {} - ); - std::optional try_consume_json_with_dumped_args( - const std::vector> & args_paths = {}, - const std::vector> & content_paths = {} - ); - - /** - * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. - * form.scope_start, form.tool_sep and form.scope_end can be empty. - */ - bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form); - - // Parse content uses reasoning and XML-Style tool call - void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "", const std::string & end_think = ""); - - void clear_tools(); -}; diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 5d43fa689..71c8b0a3e 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -1,13 +1,17 @@ #include "chat-peg-parser.h" +#include "chat-auto-parser.h" +#include "ggml.h" +#include "peg-parser.h" + #include -// using json = nlohmann::json; +// using json = nlohmann::ordered_json; static std::string_view trim_trailing_space(std::string_view sv, int max = -1) { int count = 0; while (!sv.empty() && std::isspace(static_cast(sv.back()))) { - if (max != -1 && count <= max) { + if (max != -1 && count >= max) { break; } sv.remove_suffix(1); @@ -16,109 +20,753 @@ static std::string_view trim_trailing_space(std::string_view sv, int max = -1) { return sv; } -void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { +static std::string_view trim_leading_space(std::string_view sv, int max = -1) { + int count = 0; + while (!sv.empty() && std::isspace(static_cast(sv.front()))) { + if (max != -1 && count >= max) { + break; + } + sv.remove_prefix(1); + count++; + } + return sv; +} + +static std::string_view trim(std::string_view sv) { + return trim_trailing_space(trim_leading_space(sv, 1)); +} + +// Count the number of unclosed '{' braces in a JSON-like string, +// properly skipping braces inside quoted strings. +static int json_brace_depth(const std::string & s) { + int depth = 0; + bool in_string = false; + bool escaped = false; + for (char c : s) { + if (escaped) { + escaped = false; + continue; + } + if (c == '\\' && in_string) { + escaped = true; + continue; + } + if (c == '"') { + in_string = !in_string; + continue; + } + if (!in_string) { + if (c == '{') { + depth++; + } else if (c == '}') { + depth--; + } + } + } + return depth; +} + +// JSON-escape a string and return the inner content (without surrounding quotes). +static std::string escape_json_string_inner(const std::string & s) { + std::string escaped = json(s).dump(); + if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') { + return escaped.substr(1, escaped.size() - 2); + } + return escaped; +} + +// Convert Python-style single-quoted strings to JSON double-quoted strings +// Only converts outer string delimiters, properly handling escape sequences: +// - {'key': 'value'} -> {"key": "value"} +// - {'code': 'print(\'hello\')'} -> {"code": "print('hello')"} +// - {'msg': 'He said "hi"'} -> {"msg": "He said \"hi\""} +static std::string normalize_quotes_to_json(const std::string & input) { + std::string result; + result.reserve(input.size() + 16); // May need extra space for escaping + + bool in_single_quoted = false; + bool in_double_quoted = false; + + for (size_t i = 0; i < input.size(); ++i) { + char c = input[i]; + + // Handle escape sequences + if (c == '\\' && i + 1 < input.size()) { + char next = input[i + 1]; + + if (in_single_quoted) { + // Inside a single-quoted string being converted to double quotes + if (next == '\'') { + // \' -> ' (escaped single quote becomes unescaped in double-quoted string) + result += '\''; + ++i; + continue; + } + if (next == '"') { + // \" stays as \" (already escaped, works in double-quoted string) + result += "\\\""; + ++i; + continue; + } + // Other escapes (\n, \\, etc.): pass through both characters + result += c; + result += next; + ++i; + continue; + } + + if (in_double_quoted) { + // Inside a double-quoted string - pass through escape sequences as-is + result += c; + result += next; + ++i; + continue; + } + + // Outside any string - just pass through the backslash + result += c; + continue; + } + + // Handle quote characters + if (c == '"') { + if (in_single_quoted) { + // Unescaped double quote inside single-quoted string -> must escape for JSON + result += "\\\""; + } else { + // Double quote as string delimiter or outside strings + in_double_quoted = !in_double_quoted; + result += c; + } + } else if (c == '\'') { + if (in_double_quoted) { + // Single quote inside double-quoted string -> pass through + result += c; + } else if (in_single_quoted) { + // Closing single quote -> convert to double quote + in_single_quoted = false; + result += '"'; + } else { + // Opening single quote -> convert to double quote + in_single_quoted = true; + result += '"'; + } + } else { + result += c; + } + } + + return result; +} + +void tag_based_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { arena.visit(result, [this](const common_peg_ast_node & node) { - map(node); + if (!node.tag.empty()) { + tags[node.tag] = std::string(node.text); + } }); } -void common_chat_peg_mapper::map(const common_peg_ast_node & node) { - bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; - bool is_content = node.tag == common_chat_peg_builder::CONTENT; +tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, bool is_partial) const { + common_peg_parse_context ctx(input, is_partial); + auto parse_result = arena.parse(ctx); - if (is_reasoning) { - result.reasoning_content = std::string(trim_trailing_space(node.text)); + tag_based_peg_mapper mapper; + mapper.from_ast(ctx.ast, parse_result); + + return { std::move(parse_result), std::move(mapper.tags) }; +} + +tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const { + if (input.empty()) { + return parse_and_extract(input, false); } + for (size_t i = 0; i < input.size(); i++) { + common_peg_parse_context ctx(input, false); + ctx.debug = debug; + auto parse_result = arena.parse(ctx, i); + if (parse_result.success() || i == input.size() - 1) { + tag_based_peg_mapper mapper; + mapper.from_ast(ctx.ast, parse_result); + return { std::move(parse_result), std::move(mapper.tags) }; + } + } + GGML_ABORT("Should not happen"); +} - if (is_content) { - result.content = std::string(trim_trailing_space(node.text)); +tagged_peg_parser build_tagged_peg_parser( + const std::function & fn) { + common_peg_parser_builder builder; + builder.set_root(fn(builder)); + return { builder.build() }; +} + +common_peg_parser common_chat_peg_builder::tag_with_safe_content(const std::string & tag_name, + const std::string & marker, + const common_peg_parser & p) { + if (marker.empty()) { + return zero_or_more(choice({ p, rule(tag_name, content(any())) })); + } + auto content_chunk = rule(tag_name, content(negate(literal(marker)) + any() + until(marker))); + return zero_or_more(choice({ p, content_chunk })); +} + +std::string & common_chat_peg_mapper::args_target() { + return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer; +} + +void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, + const common_peg_parse_result & parse_result_arg) { + arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); }); + // Flush any pending tool call that was started but never got a name + // This happens during partial parsing when the tool call is incomplete + if (pending_tool_call.has_value() && !pending_tool_call->name.empty()) { + if (!args_buffer.empty()) { + pending_tool_call->arguments = args_buffer; + } + if (closing_quote_pending && !pending_tool_call->arguments.empty()) { + pending_tool_call->arguments += "\""; + } + result.tool_calls.push_back(pending_tool_call.value()); + pending_tool_call.reset(); } } -void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) { - common_chat_peg_mapper::map(node); +void common_chat_peg_mapper::map(const common_peg_ast_node & node) { + // Handle reasoning/content tags + bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; + bool is_content = node.tag == common_chat_peg_builder::CONTENT; - bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN; - bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME; - bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID; - bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS; + if (is_reasoning) { // GPT OSS can have more than 1 reasoning block, so concatenate here + result.reasoning_content += std::string(node.text); + } + + if (is_content) { + // Concatenate content from multiple content nodes (e.g., when reasoning markers + // are preserved before content markers in reasoning_format=NONE mode) + result.content += std::string(node.text); + } + + // Handle tool-related tags (supporting both JSON and tagged formats) + bool is_tool_open = node.tag == common_chat_peg_builder::TOOL_OPEN; + bool is_tool_close = node.tag == common_chat_peg_builder::TOOL_CLOSE; + bool is_tool_name = node.tag == common_chat_peg_builder::TOOL_NAME; + bool is_tool_id = node.tag == common_chat_peg_builder::TOOL_ID; + bool is_tool_args = node.tag == common_chat_peg_builder::TOOL_ARGS; + bool is_arg_open = node.tag == common_chat_peg_builder::TOOL_ARG_OPEN; + bool is_arg_close = node.tag == common_chat_peg_builder::TOOL_ARG_CLOSE; + bool is_arg_name = node.tag == common_chat_peg_builder::TOOL_ARG_NAME; + bool is_arg_value = node.tag == common_chat_peg_builder::TOOL_ARG_VALUE; + bool is_arg_string_value = node.tag == common_chat_peg_builder::TOOL_ARG_STRING_VALUE; if (is_tool_open) { - result.tool_calls.emplace_back(); - current_tool = &result.tool_calls.back(); + pending_tool_call = common_chat_tool_call(); + current_tool = &pending_tool_call.value(); + arg_count = 0; + args_buffer.clear(); + closing_quote_pending = false; } if (is_tool_id && current_tool) { - current_tool->id = std::string(trim_trailing_space(node.text)); + auto text = trim_trailing_space(node.text); + if (text.size() >= 2 && text.front() == '"' && text.back() == '"') { + text = text.substr(1, text.size() - 2); + } + current_tool->id = std::string(text); } if (is_tool_name && current_tool) { current_tool->name = std::string(trim_trailing_space(node.text)); + // Now that we have the name, populate the arguments from the buffer + if (!args_buffer.empty()) { + current_tool->arguments = args_buffer; + args_buffer.clear(); + } else if (current_tool->arguments.empty()) { + current_tool->arguments = "{"; + } + // Add the tool call to results so streaming can see it + if (pending_tool_call.has_value()) { + result.tool_calls.push_back(pending_tool_call.value()); + pending_tool_call.reset(); + current_tool = &result.tool_calls.back(); + } } if (is_tool_args && current_tool) { - current_tool->arguments = std::string(trim_trailing_space(node.text)); - } -} - -void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) { - common_chat_peg_mapper::map(node); - - bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN; - bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME; - bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE; - bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN; - bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE; - bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME; - bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE; - bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE; - - if (is_tool_open) { - result.tool_calls.emplace_back(); - current_tool = &result.tool_calls.back(); - arg_count = 0; - } - - if (is_tool_name) { - current_tool->name = std::string(node.text); - current_tool->arguments = "{"; + // For JSON format: arguments come as a complete JSON object + // For tagged format: built up from individual arg_name/arg_value nodes + auto text = trim_trailing_space(node.text); + if (!text.empty() && text.front() == '{') { + args_target() = std::string(text); + } } if (is_arg_open) { - needs_closing_quote = false; + closing_quote_pending = false; } if (is_arg_name && current_tool) { + std::string arg_entry; if (arg_count > 0) { - current_tool->arguments += ","; + arg_entry = ","; } - current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":"; + arg_entry += json(trim(node.text)).dump() + ":"; ++arg_count; + + auto & target = args_target(); + if (target.empty()) { + target = "{"; + } + target += arg_entry; } - if (is_arg_string && current_tool) { - // Serialize to JSON, but exclude the end quote - std::string dumped = json(trim_trailing_space(node.text)).dump(); - current_tool->arguments += dumped.substr(0, dumped.size() - 1); - needs_closing_quote = true; + if ((is_arg_value || is_arg_string_value) && current_tool) { + std::string value_content = std::string(trim_trailing_space(trim_leading_space(node.text, 1), 1)); + + std::string value_to_add; + if (value_content.empty() && is_arg_string_value) { + // Empty string value - arg_close will add the closing quote + value_to_add = "\""; + closing_quote_pending = true; + } else if (!value_content.empty() && is_arg_string_value) { + // Schema declares this as string type - always treat as literal string value + if (!closing_quote_pending) { + value_to_add = "\""; + closing_quote_pending = true; + } + value_to_add += escape_json_string_inner(value_content); + } else if (!value_content.empty()) { + // For potential containers, normalize Python-style single quotes to JSON double quotes + bool is_potential_container = value_content[0] == '[' || value_content[0] == '{'; + if (is_potential_container) { + value_content = normalize_quotes_to_json(value_content); + } + + // Try to parse as JSON value (number, bool, null, object, array) + try { + json parsed = json::parse(value_content); + if (parsed.is_string()) { + // Don't add closing quote yet (added by arg_close) for monotonic streaming + std::string escaped = parsed.dump(); + if (!escaped.empty() && escaped.back() == '"') { + escaped.pop_back(); + } + value_to_add = escaped; + closing_quote_pending = true; + } else { + // Non-string values: use raw content to preserve whitespace for monotonicity + value_to_add = value_content; + } + } catch (...) { + if (node.is_partial && is_potential_container) { + // Partial container: pass through the already-normalized content + value_to_add = value_content; + } else { + // Not valid JSON - treat as string value + if (!closing_quote_pending) { + value_to_add = "\""; + closing_quote_pending = true; + } + value_to_add += escape_json_string_inner(value_content); + } + } + } + + args_target() += value_to_add; } if (is_arg_close && current_tool) { - if (needs_closing_quote) { - current_tool->arguments += "\""; - needs_closing_quote = false; + if (closing_quote_pending) { + args_target() += "\""; + closing_quote_pending = false; } } - if (is_arg_json && current_tool) { - current_tool->arguments += std::string(trim_trailing_space(node.text)); - } - if (is_tool_close && current_tool) { - if (needs_closing_quote) { - current_tool->arguments += "\""; - needs_closing_quote = false; + // Flush buffer to arguments if tool name was never seen + if (current_tool->name.empty() && !args_buffer.empty()) { + current_tool->arguments = args_buffer; + args_buffer.clear(); + } + // Close any pending string quote + if (closing_quote_pending) { + current_tool->arguments += "\""; + closing_quote_pending = false; + } + // Close any unclosed braces (accounts for nested objects) + for (int d = json_brace_depth(current_tool->arguments); d > 0; d--) { + current_tool->arguments += "}"; + } + // Add tool call to results if named; otherwise discard + if (pending_tool_call.has_value()) { + if (!current_tool->name.empty()) { + result.tool_calls.push_back(pending_tool_call.value()); + } + pending_tool_call.reset(); } - current_tool->arguments += "}"; } } + +common_peg_parser common_chat_peg_builder::standard_constructed_tools( + const std::map & markers, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls) { + if (!tools.is_array() || tools.empty()) { + return eps(); + } + + // Extract markers with defaults + auto get_marker = [&markers](const std::string & key, const std::string & default_val = "") -> std::string { + auto it = markers.find(key); + return it != markers.end() ? it->second : default_val; + }; + + std::string section_start = get_marker("tool_call_start_marker", ""); + std::string section_end = get_marker("tool_call_end_marker", ""); + std::string func_opener = get_marker("function_opener", ""); + std::string func_closer = get_marker("function_closer", ""); + std::string param_key_prefix = get_marker("parameter_key_prefix", ""); + std::string param_closer = get_marker("parameter_closer", ""); + + // Build tool choices for tagged format + auto tool_choices = choice(); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; + } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + // Build argument parsers + auto args = eps(); + if (params.contains("properties") && !params["properties"].empty()) { + auto arg_choice = choice(); + for (const auto & el : params["properties"].items()) { + const std::string & prop_name = el.key(); + + auto arg_name_parser = + choice({ literal(prop_name), literal("\"" + prop_name + "\""), literal("'" + prop_name + "'") }); + + auto arg_rule = tool_arg(tool_arg_open(literal(param_key_prefix)) + tool_arg_name(arg_name_parser) + + literal(param_key_suffix) + tool_arg_value(until(param_closer)) + + tool_arg_close(literal(param_closer))); + arg_choice |= arg_rule; + } + args = zero_or_more(arg_choice + space()); + } + + // Build function parser: args + auto tool_parser = tool(tool_open(literal(func_opener) + tool_name(literal(name)) + literal(func_name_suffix)) + + space() + tool_args(args) + space() + tool_close(literal(func_closer))); + + tool_choices |= rule("tool-" + name, tool_parser); + } + + // Build the section with markers + auto section = + parallel_tool_calls ? + trigger_rule("tool-call", literal(section_start) + space() + one_or_more(tool_choices + space()) + + literal(section_end)) : + trigger_rule("tool-call", literal(section_start) + space() + tool_choices + space() + literal(section_end)); + + return force_tool_calls ? section : optional(section); +} + +// Helper: Parse dot notation key into prefix and field name +static std::pair parse_key_spec(const std::string & key) { + auto dot_pos = key.find('.'); + if (dot_pos == std::string::npos) { + return {"", key}; // Top-level field + } + return {key.substr(0, dot_pos), key.substr(dot_pos + 1)}; +} + +// Mode 1: function_is_key — parse {"function_name": {...}} +common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key( + const nlohmann::json & tools, + const std::string & args_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key) { + + auto tool_choices = choice(); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; + } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + // Build inner object fields + std::vector inner_fields; + + if (!call_id_key.empty()) { + auto id_parser = atomic( + literal("\"" + call_id_key + "\"") + space() + literal(":") + space() + + literal("\"") + tool_id(json_string_content()) + literal("\"") + ); + inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space()))); + } + + if (!gen_call_id_key.empty()) { + auto gen_id_parser = atomic( + literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + inner_fields.push_back(optional(gen_id_parser + space() + optional(literal(",") + space()))); + } + + // Arguments — either wrapped in args_key or parsed directly + common_peg_parser args_parser = eps(); + if (args_key.empty()) { + args_parser = tool_args(schema(json(), "tool-" + name + "-schema", params)); + } else { + args_parser = literal("\"" + effective_args_key + "\"") + space() + literal(":") + space() + + tool_args(schema(json(), "tool-" + name + "-schema", params)); + } + inner_fields.push_back(args_parser); + + // Build inner object parser + common_peg_parser inner_object = eps(); + if (args_key.empty() && inner_fields.size() == 1) { + inner_object = inner_fields[0]; + } else { + inner_object = literal("{") + space(); + for (size_t i = 0; i < inner_fields.size(); i++) { + inner_object = inner_object + inner_fields[i]; + if (i < inner_fields.size() - 1) { + inner_object = inner_object + space(); + } + } + inner_object = inner_object + space() + literal("}"); + } + + auto tool_parser = tool( + tool_open(literal("{")) + space() + + literal("\"") + tool_name(literal(name)) + literal("\"") + + space() + literal(":") + space() + + inner_object + + space() + tool_close(literal("}")) + ); + + tool_choices |= rule("tool-" + name, tool_parser); + } + + return tool_choices; +} + +// Mode 2: Nested keys (dot notation like "function.name") +common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys( + const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key) { + + auto tool_choices = choice(); + + auto name_spec = parse_key_spec(effective_name_key); + auto args_spec = parse_key_spec(effective_args_key); + + std::string nested_prefix = !name_spec.first.empty() ? name_spec.first : args_spec.first; + std::string nested_name_field = !name_spec.first.empty() ? name_spec.second : effective_name_key; + std::string nested_args_field = !args_spec.first.empty() ? args_spec.second : effective_args_key; + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; + } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() + + literal("\"") + tool_name(literal(name)) + literal("\""); + auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() + + tool_args(schema(json(), "tool-" + name + "-schema", params)); + + auto nested_object = literal("{") + space() + + nested_name + space() + literal(",") + space() + + nested_args + + space() + literal("}"); + + // Format: { id?, "function": {...} } + auto tool_parser_body = tool_open(literal("{")) + space(); + + if (!call_id_key.empty()) { + auto id_spec = parse_key_spec(call_id_key); + if (id_spec.first.empty()) { + auto id_parser = atomic( + literal("\"" + call_id_key + "\"") + space() + literal(":") + space() + + literal("\"") + tool_id(json_string_content()) + literal("\"") + ); + tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space()); + } + } + + if (!gen_call_id_key.empty()) { + auto gen_id_spec = parse_key_spec(gen_call_id_key); + if (gen_id_spec.first.empty()) { + auto gen_id_parser = atomic( + literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + tool_parser_body = tool_parser_body + optional(gen_id_parser + space() + literal(",") + space()); + } + } + + auto nested_field = literal("\"" + nested_prefix + "\"") + space() + literal(":") + space() + nested_object; + tool_parser_body = tool_parser_body + nested_field + space() + tool_close(literal("}")); + + tool_choices |= rule("tool-" + name, tool(tool_parser_body)); + } + + return tool_choices; +} + +// Mode 3: Flat keys with optional ID fields and parameter ordering +common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( + const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key, + const std::vector & parameters_order) { + + auto tool_choices = choice(); + auto name_key_parser = literal("\"" + effective_name_key + "\""); + auto args_key_parser = literal("\"" + effective_args_key + "\""); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; + } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + auto tool_name_ = name_key_parser + space() + literal(":") + space() + + literal("\"") + tool_name(literal(name)) + literal("\""); + auto tool_args_ = args_key_parser + space() + literal(":") + space() + + tool_args(schema(json(), "tool-" + name + "-schema", params)); + + // Build ID parsers if keys are provided + common_peg_parser id_parser = eps(); + if (!call_id_key.empty()) { + id_parser = atomic( + literal("\"" + call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + } + + common_peg_parser gen_id_parser = eps(); + if (!gen_call_id_key.empty()) { + gen_id_parser = atomic( + literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + } + + // Create (parser, key) pairs for all fields, then sort by parameters_order + std::vector> parser_pairs; + parser_pairs.emplace_back(tool_name_, effective_name_key); + parser_pairs.emplace_back(tool_args_, effective_args_key); + if (!call_id_key.empty()) { + parser_pairs.emplace_back(optional(id_parser), call_id_key); + } + if (!gen_call_id_key.empty()) { + parser_pairs.emplace_back(optional(gen_id_parser), gen_call_id_key); + } + + std::sort(parser_pairs.begin(), parser_pairs.end(), + [¶meters_order](const auto & a, const auto & b) { + auto pos_a = std::find(parameters_order.begin(), parameters_order.end(), a.second); + auto pos_b = std::find(parameters_order.begin(), parameters_order.end(), b.second); + size_t idx_a = (pos_a == parameters_order.end()) ? parameters_order.size() : std::distance(parameters_order.begin(), pos_a); + size_t idx_b = (pos_b == parameters_order.end()) ? parameters_order.size() : std::distance(parameters_order.begin(), pos_b); + return idx_a < idx_b; + }); + + auto ordered_body = tool_open(literal("{")) + space(); + for (size_t i = 0; i < parser_pairs.size(); i++) { + ordered_body = ordered_body + parser_pairs[i].first; + if (i < parser_pairs.size() - 1) { + ordered_body = ordered_body + space() + literal(",") + space(); + } + } + ordered_body = ordered_body + space() + tool_close(literal("}")); + + tool_choices |= rule("tool-" + name, tool(ordered_body)); + } + + return tool_choices; +} + +common_peg_parser common_chat_peg_builder::standard_json_tools( + const std::string & section_start, + const std::string & section_end, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls, + const std::string & name_key, + const std::string & args_key, + bool array_wrapped, + bool function_is_key, + const std::string & call_id_key, + const std::string & gen_call_id_key, + const std::vector & parameters_order) { + if (!tools.is_array() || tools.empty()) { + return eps(); + } + + std::string effective_name_key = name_key.empty() ? "name" : name_key; + std::string effective_args_key = args_key.empty() ? "arguments" : args_key; + + // Dispatch to the appropriate builder based on the JSON layout mode + common_peg_parser tool_choices = eps(); + if (function_is_key) { + tool_choices = build_json_tools_function_is_key(tools, args_key, effective_args_key, call_id_key, gen_call_id_key); + } else { + auto name_spec = parse_key_spec(effective_name_key); + auto args_spec = parse_key_spec(effective_args_key); + if (!name_spec.first.empty() || !args_spec.first.empty()) { + tool_choices = build_json_tools_nested_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key); + } else { + tool_choices = build_json_tools_flat_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key, parameters_order); + } + } + + // Build the section with markers + auto tool_calls = tool_choices; + if (parallel_tool_calls) { + tool_calls = tool_calls + zero_or_more(space() + literal(",") + space() + tool_choices); + } + + if (array_wrapped) { + tool_calls = literal("[") + space() + tool_calls + space() + literal("]"); + } + + auto section = + trigger_rule("tool-call", literal(section_start) + space() + tool_calls + space() + literal(section_end)); + + return force_tool_calls ? section : optional(section); +} diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index b84cbed20..e130ceea5 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -3,22 +3,9 @@ #include "chat.h" #include "peg-parser.h" -class common_chat_peg_builder : public common_peg_parser_builder { - public: - static constexpr const char * REASONING_BLOCK = "reasoning-block"; - static constexpr const char * REASONING = "reasoning"; - static constexpr const char * CONTENT = "content"; - - common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } - common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } - common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } -}; - -inline common_peg_arena build_chat_peg_parser(const std::function & fn) { - common_chat_peg_builder builder; - builder.set_root(fn(builder)); - return builder.build(); -} +#include +#include +#include class common_chat_peg_mapper { public: @@ -26,80 +13,164 @@ class common_chat_peg_mapper { common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {} + virtual ~common_chat_peg_mapper() = default; + virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); virtual void map(const common_peg_ast_node & node); + private: + // Tool call handling state + std::optional pending_tool_call; // Tool call waiting for name + common_chat_tool_call * current_tool = nullptr; + int arg_count = 0; + bool closing_quote_pending = false; + std::string args_buffer; // Buffer to delay arguments until tool name is known + + // Returns a reference to the active argument destination string. + // Before tool_name is known, writes go to args_buffer; after, to current_tool->arguments. + std::string & args_target(); }; -class common_chat_peg_native_builder : public common_chat_peg_builder { - public: - static constexpr const char * TOOL = "tool"; - static constexpr const char * TOOL_OPEN = "tool-open"; - static constexpr const char * TOOL_CLOSE = "tool-close"; - static constexpr const char * TOOL_ID = "tool-id"; - static constexpr const char * TOOL_NAME = "tool-name"; - static constexpr const char * TOOL_ARGS = "tool-args"; +struct content_structure; +struct tool_call_structure; +class common_chat_peg_builder : public common_peg_parser_builder { + public: + // Tag constants (from former common_chat_peg_base_builder) + static constexpr const char * REASONING_BLOCK = "reasoning-block"; + static constexpr const char * REASONING = "reasoning"; + static constexpr const char * CONTENT = "content"; + + // Tag constants + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_ID = "tool-id"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARGS = "tool-args"; + static constexpr const char * TOOL_ARG = "tool-arg"; + static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; + static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; + static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; + static constexpr const char * TOOL_ARG_VALUE = "tool-arg-value"; + static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; // For schema-declared string types + + // Low-level tag methods (from former common_chat_peg_base_builder) + common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } + + common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } + + common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } + + common_peg_parser tag_with_safe_content(const std::string & tag_name, + const std::string & marker, + const common_peg_parser & p); + + // Low-level tag methods common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); } common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); } -}; - -class common_chat_peg_native_mapper : public common_chat_peg_mapper { - common_chat_tool_call * current_tool; - - public: - common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} - - void map(const common_peg_ast_node & node) override; -}; - -inline common_peg_arena build_chat_peg_native_parser(const std::function & fn) { - common_chat_peg_native_builder builder; - builder.set_root(fn(builder)); - return builder.build(); -} - -class common_chat_peg_constructed_builder : public common_chat_peg_builder { - public: - static constexpr const char * TOOL = "tool"; - static constexpr const char * TOOL_OPEN = "tool-open"; - static constexpr const char * TOOL_CLOSE = "tool-close"; - static constexpr const char * TOOL_NAME = "tool-name"; - static constexpr const char * TOOL_ARG = "tool-arg"; - static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; - static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; - static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; - static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; - static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value"; - - common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } - common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } - common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } - common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); } common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); } common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); } common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); } + common_peg_parser tool_arg_value(const common_peg_parser & p) { return tag(TOOL_ARG_VALUE, p); } + + // Use for schema-declared string types - won't be treated as potential JSON container common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } - common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); } + + // Legacy-compatible helper for building standard JSON tool calls + // Used by tests and manual parsers + // name_key/args_key: JSON key names for function name and arguments + // Empty or "name"/"arguments" will accept both common variations + // Supports dot notation for nested objects (e.g., "function.name") + // array_wrapped: if true, tool calls are wrapped in JSON array [...] + // function_is_key: if true, function name is the JSON key (e.g., {"func_name": {...}}) + // call_id_key: JSON key for string call ID (e.g., "id") + // gen_call_id_key: JSON key for generated integer call ID (e.g., "tool_call_id") + // parameters_order: order in which JSON fields should be parsed + common_peg_parser standard_json_tools(const std::string & section_start, + const std::string & section_end, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls, + const std::string & name_key = "", + const std::string & args_key = "", + bool array_wrapped = false, + bool function_is_key = false, + const std::string & call_id_key = "", + const std::string & gen_call_id_key = "", + const std::vector & parameters_order = {}); + + // Legacy-compatible helper for building XML/tagged style tool calls + // Used by tests and manual parsers + common_peg_parser standard_constructed_tools(const std::map & markers, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls); + + private: + // Implementation helpers for standard_json_tools — one per JSON tool call layout mode + common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools, + const std::string & args_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key); + + common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key); + + common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key, + const std::vector & parameters_order); }; -class common_chat_peg_constructed_mapper : public common_chat_peg_mapper { - common_chat_tool_call * current_tool; - int arg_count = 0; - bool needs_closing_quote = false; - - public: - common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} - - void map(const common_peg_ast_node & node) override; -}; - -inline common_peg_arena build_chat_peg_constructed_parser(const std::function & fn) { - common_chat_peg_constructed_builder builder; - builder.set_root(fn(builder)); - return builder.build(); +inline common_peg_arena build_chat_peg_parser( + const std::function & fn) { + common_chat_peg_builder builder; + builder.set_root(fn(builder)); + return builder.build(); } + +class tag_based_peg_mapper { + public: + std::map tags; + + void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); +}; + +struct tagged_parse_result { + common_peg_parse_result result; + std::map tags; +}; + +struct tagged_peg_parser { + common_peg_arena arena; + bool debug = false; + + tagged_peg_parser & withDebug() { + debug = true; + return *this; + } + + tagged_peg_parser & withoutDebug() { + debug = false; + return *this; + } + + tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const; + tagged_parse_result parse_anywhere_and_extract(const std::string & input) const; +}; + +tagged_peg_parser build_tagged_peg_parser( + const std::function & fn); + diff --git a/common/chat.cpp b/common/chat.cpp index 7ebe5c632..2d9273527 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,33 +1,29 @@ #include "chat.h" -#include "chat-parser.h" + +#include "chat-auto-parser.h" #include "chat-peg-parser.h" #include "common.h" -#include "json-partial.h" +#include "ggml.h" #include "json-schema-to-grammar.h" #include "log.h" #include "json-partial.cpp" #include "regex-partial.cpp" #include "chat-parser-xml-toolcall.cpp" -#include "jinja/parser.h" #include "jinja/value.h" #include "jinja/runtime.h" #include "jinja/caps.h" - -#include "jinja/lexer.cpp" -#include "jinja/parser.cpp" -#include "jinja/runtime.cpp" -#include "jinja/value.cpp" -#include "jinja/string.cpp" -#include "jinja/caps.cpp" +#include "peg-parser.h" #include #include -#include +#include +#include #include #include -#include + #include +#include #include #include #include @@ -35,14 +31,26 @@ using json = nlohmann::ordered_json; static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { - auto time = std::chrono::system_clock::to_time_t(now); - auto local_time = *std::localtime(&time); + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); std::ostringstream ss; ss << std::put_time(&local_time, format.c_str()); auto res = ss.str(); return res; } +static json safe_args_parse(const std::string & to_parse) { + std::string stripped = to_parse; + if (to_parse.at(0) == '"' && to_parse.at(to_parse.length() - 1) == '"') { + stripped = to_parse.substr(1, to_parse.length() - 1); + } + try { + return json::parse(stripped); + } catch (json::exception & e) { + return stripped; + } +} + static std::string string_diff(const std::string & last, const std::string & current) { if (last.empty()) { return current; @@ -125,7 +133,7 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { {"type", "function"}, {"function", { {"name", tool_call.name}, - {"arguments", tool_call.arguments}, + {"arguments", json::parse(tool_call.arguments)}, }}, }; if (!tool_call.id.empty()) { @@ -142,7 +150,8 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { return jmsg; } -std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) { +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, + const common_chat_msg & msg_new) { std::vector diffs; if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) { diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3); @@ -152,38 +161,56 @@ std::vector common_chat_msg_diff::compute_diffs(const comm // TODO: these can become expensive for long messages - how to optimize? if (msg_prv.reasoning_content != msg_new.reasoning_content) { - auto & diff = diffs.emplace_back(); + auto & diff = diffs.emplace_back(); diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content); } if (msg_prv.content != msg_new.content) { - auto & diff = diffs.emplace_back(); + auto & diff = diffs.emplace_back(); diff.content_delta = string_diff(msg_prv.content, msg_new.content); } if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) { - throw std::runtime_error("Invalid diff: now finding less tool calls!"); + std::string err = "Invalid diff: now finding less tool calls!\n"; + err += " Previous (" + std::to_string(msg_prv.tool_calls.size()) + "):\n"; + for (const auto & tc : msg_prv.tool_calls) { + err += " - name: '" + tc.name + "', args: '" + tc.arguments + "'\n"; + } + err += " Current (" + std::to_string(msg_new.tool_calls.size()) + "):\n"; + for (const auto & tc : msg_new.tool_calls) { + err += " - name: '" + tc.name + "', args: '" + tc.arguments + "'\n"; + } + err += " Current msg text content:\n" + msg_new.content + "\n"; + throw std::runtime_error(err); } if (!msg_prv.tool_calls.empty()) { - const auto idx = msg_prv.tool_calls.size() - 1; + const auto idx = msg_prv.tool_calls.size() - 1; const auto & pref = msg_prv.tool_calls[idx]; const auto & newf = msg_new.tool_calls[idx]; - if (pref.name != newf.name) { - throw std::runtime_error("Invalid diff: tool call mismatch!"); + // Allow tool name to change during incremental parsing: + // - empty -> non-empty (initial discovery) + // - prefix -> longer string (name grows as more input is parsed) + if (pref.name != newf.name && !pref.name.empty() && !newf.name.empty()) { + // Check if one is a prefix of the other (for incremental parsing where names grow or shrink) + bool is_prefix = (newf.name.rfind(pref.name, 0) == 0); + if (!is_prefix) { + LOG_ERR("Tool call mismatch: prev='%s' new='%s'\n", pref.name.c_str(), newf.name.c_str()); + throw std::runtime_error("Invalid diff: tool call mismatch!"); + } } const auto args_diff = string_diff(pref.arguments, newf.arguments); - if (!args_diff.empty() || pref.id != newf.id) { - auto & diff = diffs.emplace_back(); + if (!args_diff.empty() || pref.id != newf.id || pref.name != newf.name) { + auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; - if (pref.id != newf.id) { - diff.tool_call_delta.id = newf.id; + if (pref.id != newf.id || pref.name != newf.name) { + diff.tool_call_delta.id = newf.id; diff.tool_call_delta.name = newf.name; } diff.tool_call_delta.arguments = args_diff; } } for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) { - auto & diff = diffs.emplace_back(); + auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; diff.tool_call_delta = msg_new.tool_calls[idx]; } @@ -193,94 +220,14 @@ std::vector common_chat_msg_diff::compute_diffs(const comm using chat_template_caps = jinja::caps; -struct common_chat_template { - jinja::program prog; - std::string bos_tok; - std::string eos_tok; - std::string src; - chat_template_caps caps; - - common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { - jinja::lexer lexer; - auto lexer_res = lexer.tokenize(src); - this->prog = jinja::parse_from_tokens(lexer_res); - - this->src = lexer_res.source; - this->bos_tok = bos_token; - this->eos_tok = eos_token; - - this->caps = jinja::caps_get(prog); - // LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str()); - } - - const std::string & source() const { return src; } - const std::string & bos_token() const { return bos_tok; } - const std::string & eos_token() const { return eos_tok; } - - // TODO: this is ugly, refactor it somehow - json add_system(const json & messages, const std::string & system_prompt) const { - GGML_ASSERT(messages.is_array()); - auto msgs_copy = messages; - if (!caps.supports_system_role) { - if (msgs_copy.empty()) { - msgs_copy.insert(msgs_copy.begin(), json{ - {"role", "user"}, - {"content", system_prompt} - }); - } else { - auto & first_msg = msgs_copy[0]; - if (!first_msg.contains("content")) { - first_msg["content"] = ""; - } - first_msg["content"] = system_prompt + "\n\n" - + first_msg["content"].get(); - } - } else { - if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") { - msgs_copy.insert(msgs_copy.begin(), json{ - {"role", "system"}, - {"content", system_prompt} - }); - } else if (msgs_copy[0].at("role") == "system") { - msgs_copy[0]["content"] = system_prompt; - } - } - return msgs_copy; - } - - chat_template_caps original_caps() const { - return caps; - } - -}; - struct common_chat_templates { bool add_bos; bool add_eos; - bool has_explicit_template; // Model had builtin template or template overridde was specified. - std::unique_ptr template_default; // always set (defaults to chatml) + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) std::unique_ptr template_tool_use; }; -struct templates_params { - json messages; - json tools; - common_chat_tool_choice tool_choice; - json json_schema; - bool parallel_tool_calls; - common_reasoning_format reasoning_format; - bool stream; - std::string grammar; - bool add_generation_prompt = true; - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - json extra_context; - bool add_bos; - bool add_eos; - bool is_inference = true; - bool mark_input = true; // whether to mark input strings in the jinja context -}; - common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { if (tool_choice == "auto") { return COMMON_CHAT_TOOL_CHOICE_AUTO; @@ -295,23 +242,24 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin } bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) { - common_chat_templates_inputs dummy_inputs; + common_chat_templates_inputs inputs; + inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; common_chat_msg msg; - msg.role = "user"; + msg.role = "user"; msg.content = "test"; - dummy_inputs.messages = {msg}; - dummy_inputs.enable_thinking = false; - const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs); - dummy_inputs.enable_thinking = true; - const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs); - return rendered_no_thinking.prompt != rendered_with_thinking.prompt; + inputs.messages = { msg }; + inputs.enable_thinking = true; + inputs.add_generation_prompt = true; + inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + + auto params = common_chat_templates_apply(chat_templates, inputs); + return params.supports_thinking; } std::vector common_chat_msgs_parse_oaicompat(const json & messages) { std::vector msgs; try { - if (!messages.is_array()) { throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump()); } @@ -327,7 +275,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } msg.role = message.at("role"); - auto has_content = message.contains("content"); + auto has_content = message.contains("content"); auto has_tool_calls = message.contains("tool_calls"); if (has_content) { const auto & content = message.at("content"); @@ -348,7 +296,9 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + + content.dump() + + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } if (has_tool_calls) { @@ -368,8 +318,13 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa if (!fc.contains("name")) { throw std::invalid_argument("Missing tool call name: " + tool_call.dump()); } - tc.name = fc.at("name"); - tc.arguments = fc.at("arguments"); + tc.name = fc.at("name"); + const auto & args = fc.at("arguments"); + if (args.is_string()) { + tc.arguments = args; + } else { + tc.arguments = args.dump(); + } if (tool_call.contains("id")) { tc.id = tool_call.at("id"); } @@ -377,7 +332,9 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } } if (!has_content && !has_tool_calls) { - throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + throw std::invalid_argument( + "Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & " + "https://github.com/ggml-org/llama.cpp/issues/12279)"); } if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); @@ -483,12 +440,13 @@ json common_chat_tools_to_json_oaicompat(const std::vector & t auto result = json::array(); for (const auto & tool : tools) { result.push_back({ - {"type", "function"}, - {"function", { - {"name", tool.name}, - {"description", tool.description}, - {"parameters", json::parse(tool.parameters)}, - }}, + { "type", "function" }, + { "function", + { + { "name", tool.name }, + { "description", tool.description }, + { "parameters", json::parse(tool.parameters) }, + } }, }); } return result; @@ -506,16 +464,20 @@ json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { json tool_call; tool_call["index"] = diff.tool_call_index; if (!diff.tool_call_delta.id.empty()) { - tool_call["id"] = diff.tool_call_delta.id; + tool_call["id"] = diff.tool_call_delta.id; tool_call["type"] = "function"; } - json function = json::object(); - if (!diff.tool_call_delta.name.empty()) { - function["name"] = diff.tool_call_delta.name; + if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + if (!diff.tool_call_delta.arguments.empty()) { + function["arguments"] = diff.tool_call_delta.arguments; + } + tool_call["function"] = function; } - function["arguments"] = diff.tool_call_delta.arguments; - tool_call["function"] = function; - delta["tool_calls"] = json::array({tool_call}); + delta["tool_calls"] = json::array({ tool_call }); } return delta; } @@ -529,13 +491,13 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { common_chat_msg msg; - msg.role = "user"; + msg.role = "user"; msg.content = "test"; auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); common_chat_templates_inputs inputs; - inputs.messages = {msg}; + inputs.messages = { msg }; common_chat_templates_apply(tmpls.get(), inputs); return true; @@ -544,28 +506,28 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { return false; } } - llama_chat_message chat[] = {{"user", "test"}}; + llama_chat_message chat[] = { + { "user", "test" } + }; const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } -std::string common_chat_format_single( - const struct common_chat_templates * tmpls, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja) { - +std::string common_chat_format_single(const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; - inputs.add_bos = tmpls->add_bos; - inputs.add_eos = tmpls->add_eos; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; std::string fmt_past_msg; if (!past_msg.empty()) { - inputs.messages = past_msg; + inputs.messages = past_msg; inputs.add_generation_prompt = false; - fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; + fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; } std::ostringstream ss; // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -575,37 +537,39 @@ std::string common_chat_format_single( // format chat with new_msg inputs.messages.push_back(new_msg); inputs.add_generation_prompt = add_ass; - auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja, const std::map & chat_template_kwargs) { +std::string common_chat_format_example(const struct common_chat_templates * tmpls, + bool use_jinja, + const std::map & chat_template_kwargs) { common_chat_templates_inputs inputs; - inputs.use_jinja = use_jinja; - inputs.add_bos = tmpls->add_bos; - inputs.add_eos = tmpls->add_eos; + inputs.use_jinja = use_jinja; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; inputs.chat_template_kwargs = chat_template_kwargs; - auto add_simple_msg = [&](auto role, auto content) { + auto add_simple_msg = [&](auto role, auto content) { common_chat_msg msg; - msg.role = role; + msg.role = role; msg.content = content; inputs.messages.push_back(msg); }; - add_simple_msg("system", "You are a helpful assistant"); - add_simple_msg("user", "Hello"); + add_simple_msg("system", "You are a helpful assistant"); + add_simple_msg("user", "Hello"); add_simple_msg("assistant", "Hi there"); - add_simple_msg("user", "How are you?"); + add_simple_msg("user", "How are you?"); return common_chat_templates_apply(tmpls, inputs).prompt; } -#define CHATML_TEMPLATE_SRC \ - "{%- for message in messages -%}\n" \ +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ - "{%- endfor -%}\n" \ - "{%- if add_generation_prompt -%}\n" \ - " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ "{%- endif -%}" void common_chat_templates_free(struct common_chat_templates * tmpls) { @@ -623,19 +587,16 @@ std::string common_chat_templates_source(const struct common_chat_templates * tm return tmpls->template_tool_use->source(); } return ""; - } else { - LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str()); } + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str()); } return tmpls->template_default->source(); } -common_chat_templates_ptr common_chat_templates_init( - const struct llama_model * model, - const std::string & chat_template_override, - const std::string & bos_token_override, - const std::string & eos_token_override) -{ +common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override, + const std::string & eos_token_override) { std::string default_template_src; std::string template_tool_use_src; @@ -644,7 +605,7 @@ common_chat_templates_ptr common_chat_templates_init( GGML_ASSERT(model != nullptr); const auto * str = llama_model_chat_template(model, /* name */ nullptr); if (str) { - default_template_src = str; + default_template_src = str; has_explicit_template = true; } str = llama_model_chat_template(model, /* name */ "tool_use"); @@ -666,34 +627,40 @@ common_chat_templates_ptr common_chat_templates_init( // TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error // Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633 if (default_template_src.find("<|channel|>") != std::string::npos - // search for the error message and patch it - && default_template_src.find("in message.content or") != std::string::npos) { + // search for the error message and patch it + && default_template_src.find("in message.content or") != std::string::npos) { string_replace_all(default_template_src, - "{%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}", - "{%- if false %}"); + "{%- if \"<|channel|>analysis<|message|>\" in message.content or " + "\"<|channel|>final<|message|>\" in message.content %}", + "{%- if false %}"); } // TODO @aldehir : this is a temporary fix, pending Minja changes // Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664 if (default_template_src.find("[TOOL_CALLS]") != std::string::npos - // search for the error message and patch it - && default_template_src.find("if (message['content'] is none or") != std::string::npos) { + // search for the error message and patch it + && default_template_src.find("if (message['content'] is none or") != std::string::npos) { string_replace_all(default_template_src, - "{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}", - "{%- if false %}"); + "{%- if (message['content'] is none or message['content'] == '' or " + "message['content']|length == 0) and (message['tool_calls'] is not defined or " + "message['tool_calls'] is none or message['tool_calls']|length == 0) %}", + "{%- if false %}"); } std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; - bool add_bos = false; - bool add_eos = false; + bool add_bos = false; + bool add_eos = false; if (model) { - const auto * vocab = llama_model_get_vocab(model); - const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + const auto * vocab = llama_model_get_vocab(model); + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { if (token == LLAMA_TOKEN_NULL) { - if (default_template_src.find(jinja_variable_name) != std::string::npos - || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { - LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name); + if (default_template_src.find(jinja_variable_name) != std::string::npos || + template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN( + "common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't " + "work as intended.\n", + name); } return std::string(); } @@ -701,13 +668,13 @@ common_chat_templates_ptr common_chat_templates_init( }; token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); - add_bos = llama_vocab_get_add_bos(vocab); - add_eos = llama_vocab_get_add_eos(vocab); + add_bos = llama_vocab_get_add_bos(vocab); + add_eos = llama_vocab_get_add_eos(vocab); } common_chat_templates_ptr tmpls(new common_chat_templates()); tmpls->has_explicit_template = has_explicit_template; - tmpls->add_bos = add_bos; - tmpls->add_eos = add_eos; + tmpls->add_bos = add_bos; + tmpls->add_eos = add_eos; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { @@ -728,35 +695,12 @@ common_chat_templates_ptr common_chat_templates_init( const char * common_chat_format_name(common_chat_format format) { switch (format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; - case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; - case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral"; - case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1"; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; - case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; - case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; - case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; - case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; - case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; - case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2"; - case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5"; - case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2"; - case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; - case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open"; - case COMMON_CHAT_FORMAT_EXAONE_MOE: return "EXAONE MoE"; - case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; - case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; - case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + return "Content-only"; + case COMMON_CHAT_FORMAT_PEG_SIMPLE: + return "peg-simple"; + case COMMON_CHAT_FORMAT_PEG_NATIVE: + return "peg-native"; default: throw std::runtime_error("Unknown chat format"); } @@ -764,10 +708,14 @@ const char * common_chat_format_name(common_chat_format format) { const char * common_reasoning_format_name(common_reasoning_format format) { switch (format) { - case COMMON_REASONING_FORMAT_NONE: return "none"; - case COMMON_REASONING_FORMAT_AUTO: return "auto"; - case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; - case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + case COMMON_REASONING_FORMAT_NONE: + return "none"; + case COMMON_REASONING_FORMAT_AUTO: + return "auto"; + case COMMON_REASONING_FORMAT_DEEPSEEK: + return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: + return "deepseek-legacy"; default: throw std::runtime_error("Unknown reasoning format"); } @@ -776,11 +724,14 @@ const char * common_reasoning_format_name(common_reasoning_format format) { common_reasoning_format common_reasoning_format_from_name(const std::string & format) { if (format == "none") { return COMMON_REASONING_FORMAT_NONE; - } else if (format == "auto") { + } + if (format == "auto") { return COMMON_REASONING_FORMAT_AUTO; - } else if (format == "deepseek") { + } + if (format == "deepseek") { return COMMON_REASONING_FORMAT_DEEPSEEK; - } else if (format == "deepseek-legacy") { + } + if (format == "deepseek-legacy") { return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; } throw std::runtime_error("Unknown reasoning format: " + format); @@ -796,7 +747,8 @@ static void foreach_function(const json & tools, const std::function & fn) { +static void foreach_parameter(const json & function, + const std::function & fn) { if (!function.contains("parameters") || !function.at("parameters").is_object()) { return; } @@ -804,7 +756,7 @@ static void foreach_parameter(const json & function, const std::function required; if (params.contains("required") && params.at("required").is_array()) { params.at("required").get_to(required); @@ -815,19 +767,19 @@ static void foreach_parameter(const json & function, const std::function & messages_override = std::nullopt, - const std::optional & tools_override = std::nullopt, - const std::optional & additional_context = std::nullopt) -{ + const autoparser::templates_params & inputs, + const std::optional & messages_override, + const std::optional & tools_override, + const std::optional & additional_context) { jinja::context ctx(tmpl.source()); nlohmann::ordered_json inp = nlohmann::ordered_json{ {"messages", messages_override.has_value() ? *messages_override : inputs.messages}, {"bos_token", tmpl.bos_token()}, {"eos_token", tmpl.eos_token()}, + {"enable_thinking", inputs.enable_thinking}, }; if (tools_override.has_value() || !inputs.tools.empty()) { inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools; @@ -853,7 +805,7 @@ static std::string apply( // render jinja::runtime runtime(ctx); const jinja::value results = runtime.execute(tmpl.prog); - auto parts = runtime.gather_string_parts(results); + auto parts = jinja::runtime::gather_string_parts(results); std::string result = parts->as_string().str(); @@ -867,265 +819,8 @@ static std::string apply( return result; } -static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - auto tool_call_schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - auto tool_schema = json { - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - }}, - {"required", json::array({"name", "arguments"})}, - }; - if (function.contains("description")) { - tool_schema["description"] = function.at("description"); - } - if (inputs.parallel_tool_calls) { - tool_schema.at("properties")["id"] = { - {"type", "string"}, - {"minLength", 4}, - }; - tool_schema.at("required").push_back("id"); - } - tool_call_schemas.emplace_back(tool_schema); - }); - const auto tool_call = - inputs.parallel_tool_calls - ? json { - {"type", "object"}, - {"properties", { - {"tool_calls", { - {"type", "array"}, - {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { - {"anyOf", tool_call_schemas}, - }}, - {"minItems", 1}, - }}, - }}, - {"required", json::array({"tool_calls"})}, - } - : json { - {"type", "object"}, - {"properties", { - {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { - {"anyOf", tool_call_schemas}, - }}, - }}, - {"required", json::array({"tool_call"})}, - }; - const auto schema = - inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED - ? json { - {"anyOf", json::array({ - tool_call, - { - {"type", "object"}, - {"properties", { - {"response", inputs.json_schema.is_null() - ? json {{"type", "string"}} - : inputs.json_schema - }, - }}, - {"required", json::array({"response"})}, - }, - })} - } - : tool_call; - - data.grammar_lazy = false; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - builder.add_schema("root", schema); - }); - - auto tweaked_messages = tmpl.add_system( - inputs.messages, - "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); - - // ensure all messages has "content" field - for (auto & message : tweaked_messages) { - if (!message.contains("content") || message["content"].is_null()) { - message["content"] = ""; - } - } - - data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); - data.format = COMMON_CHAT_FORMAT_GENERIC; - return data; -} - -static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - // Important note: the model is probably trained to take a JSON stringified arguments value. - // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - {"id", { - {"type", "string"}, - // Nemo's template expects a 9-character alphanumeric ID. - {"pattern", "^[a-zA-Z0-9]{9}$"}, - }}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); - }); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); - data.preserved_tokens = { - "[TOOL_CALLS]", - }; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; - return data; -} - - -// Case-insensitive find -static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { - auto it = std::search( - haystack.begin() + pos, haystack.end(), - needle.begin(), needle.end(), - [](char a, char b) { return std::tolower(a) == std::tolower(b); } - ); - return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it); -} - -static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - const auto is_json_schema_provided = !inputs.json_schema.is_null(); - const auto is_grammar_provided = !inputs.grammar.empty(); - const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); - - // the logic requires potentially modifying the messages - auto tweaked_messages = inputs.messages; - - auto replace_json_schema_marker = [](json & messages) -> bool { - static std::string marker1 = "force json schema.\n"; - static std::string marker2 = "force json schema."; - - if (messages.empty() || messages.at(0).at("role") != "system") { - return false; - } - - std::string content = messages.at(0).at("content"); - - for (const auto & marker : {marker1, marker2}) { - const auto pos = ifind_string(content, marker); - if (pos != std::string::npos) { - content.replace(pos, marker.length(), ""); - // inject modified content back into the messages - messages.at(0).at("content") = content; - return true; - } - } - - return false; - }; - - // Lfm2 model does not natively work with json, but can generally understand the tools structure - // - // Example of the pytorch dialog structure: - // <|startoftext|><|im_start|>system - // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> - // <|im_start|>user - // What is the current status of candidate ID 12345?<|im_end|> - // <|im_start|>assistant - // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> - // <|im_start|>tool - // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> - // <|im_start|>assistant - // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> - // - // For the llama server compatibility with json tools semantic, - // the client can add "Follow json schema." line into the system message prompt to force the json output. - // - if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { - // server/utils.hpp prohibits that branch for the custom grammar anyways - throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar"); - } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { - LOG_INF("%s: Using tools to build a grammar\n", __func__); - - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - - builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); - }); - // model has no concept of tool selection mode choice, - // if the system prompt rendered correctly it will produce a tool call - // the grammar goes inside the tool call body - data.grammar_lazy = true; - data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; - data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; - data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; - } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { - LOG_INF("%s: Using tools without json schema or grammar\n", __func__); - // output those tokens - data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; - } else if (is_json_schema_provided) { - LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); - data.grammar = json_schema_to_grammar(inputs.json_schema); - } else if (is_grammar_provided) { - LOG_INF("%s: Using provided grammar\n", __func__); - data.grammar = inputs.grammar; - } else { - LOG_INF("%s: Using content relying on the template\n", __func__); - } - - data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); - LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); - - return data; -} - -static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const struct templates_params & inputs) { +static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja @@ -1143,8 +838,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ // If message contains `reasoning_content`, add it as a block of type `thinking` if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { content.push_back({ - {"type", "thinking"}, - {"thinking", msg.at("reasoning_content").get()}, + { "type", "thinking" }, + { "thinking", msg.at("reasoning_content").get() }, }); } @@ -1152,8 +847,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ if (msg.contains("content")) { if (msg.at("content").is_string()) { content.push_back({ - {"type", "text"}, - {"text", msg.at("content").get()}, + { "type", "text" }, + { "text", msg.at("content").get() }, }); } else if (msg.at("content").is_array()) { auto blocks = msg.at("content"); @@ -1161,32 +856,35 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ } } - auto adjusted = msg; + auto adjusted = msg; adjusted["content"] = content; adjusted.erase("reasoning_content"); adjusted_messages.push_back(adjusted); } - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; - auto include_grammar = true; + auto include_grammar = true; - data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { + data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { "[THINK]", "[/THINK]", "[TOOL_CALLS]", "[ARGS]", }; - auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { - auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto reasoning = + extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); // Response format parser if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { // Ministral wants to emit json surrounded by code fences - return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"; + return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + << "```"; } // Tool call parser @@ -1194,17 +892,16 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ auto tool_choice = p.choice(); foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - std::string name = function.at("name"); - const auto & schema = function.at("parameters"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); - tool_choice |= p.rule("tool-" + name, - p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") - + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) - ); + tool_choice |= + p.rule("tool-" + name, p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); }); - auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; - auto max_calls = inputs.parallel_tool_calls ? -1 : 1; + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; @@ -1223,820 +920,40 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ data.grammar = build_grammar([&](const common_grammar_builder & builder) { foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - auto schema = function.at("parameters"); + auto schema = function.at("parameters"); builder.resolve_refs(schema); }); parser.build_grammar(builder, data.grammar_lazy); }); data.grammar_triggers = { - {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"} + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]" } }; } return data; } -static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_MAGISTRAL; - data.preserved_tokens = { - "[THINK]", - "[/THINK]", - }; - - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - {"id", { - {"type", "string"}, - {"pattern", "^[a-zA-Z0-9]{9}$"}, - }}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); - }); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); - data.preserved_tokens.push_back("[TOOL_CALLS]"); - } else { - data.grammar_lazy = false; - if (!inputs.json_schema.is_null()) { - if (!inputs.grammar.empty()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } - data.grammar = json_schema_to_grammar(inputs.json_schema); - } else { - data.grammar = inputs.grammar; - } - } - - return data; -} - -static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - auto adjusted_messages = json::array(); - for (const auto & msg : inputs.messages) { - auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); - auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); - if (has_reasoning_content && has_tool_calls) { - auto adjusted_message = msg; - adjusted_message["tool_plan"] = msg.at("reasoning_content"); - adjusted_message.erase("reasoning_content"); - adjusted_messages.push_back(adjusted_message); - } else { - adjusted_messages.push_back(msg); - } - } - data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); - data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; - if (string_ends_with(data.prompt, "<|START_THINKING|>")) { - if (!inputs.enable_thinking) { - data.prompt += "<|END_THINKING|>"; - } else { - data.thinking_forced_open = true; - } - } else if (!inputs.enable_thinking && string_ends_with(data.prompt, "<|CHATBOT_TOKEN|>")) { - data.prompt += "<|START_THINKING|><|END_THINKING|>"; - } - - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"tool_call_id", { - {"type", "string"}, - // Command-R's template expects an integer string. - {"pattern", "^[0-9]{1,10}$"}, - }}, - {"tool_name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"parameters", function.at("parameters")}, - }}, - {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") + - "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") + - "(<\\|START_ACTION\\|>)[\\s\\S]*" - }); - data.preserved_tokens = { - "<|START_ACTION|>", - "<|END_ACTION|>", - "<|START_RESPONSE|>", - "<|END_RESPONSE|>", - "<|START_THINKING|>", - "<|END_THINKING|>", - }; - return data; -} - -static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { - if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { - throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); - } - const auto & parameters_properties = parameters.at("properties"); - const auto & parameters_required = parameters.at("required"); - for (const auto & prop : expected_properties) { - if (!parameters_properties.contains(prop)) { - throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT - } - if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { - throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT - } - } - if (parameters_properties.size() != expected_properties.size()) { - throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", ")); - } -} - -static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { - auto builtin_tools = json::array(); - common_chat_params data; - if (!inputs.tools.is_null()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - - auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { - if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py - expect_tool_parameters(name, parameters, {"query"}); - } else if (name == "python" || name == "code_interpreter") { - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py - expect_tool_parameters(name, parameters, {"code"}); - } else { - return false; - } - - std::vector kvs; - for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT - } - - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); - builtin_tools.push_back(name); - - return true; - }; - - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime - if (allow_python_tag_builtin_tools) { - handle_builtin_tool(name, parameters); - } - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"{\" space " - "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " - " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " - " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " " - "\"}\" space")); - }); - // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", - }); - if (!builtin_tools.empty()) { - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); - } - // Allow a few empty lines on top of the usual constrained json schema space rule. - builder.add_rule("root", string_join(tool_rules, " | ")); - data.additional_stops.push_back("<|eom_id|>"); - }); - data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() - ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS - : COMMON_CHAT_FORMAT_LLAMA_3_X; - } else { - data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - } - data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json { - {"date_string", format_time(inputs.now, "%d %b %Y")}, - {"tools_in_user_message", false}, - {"builtin_tools", builtin_tools}, - }); - return data; -} - -static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Generate the prompt using the apply() function with the template - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2; - - // Handle thinking tags appropriately based on inputs.enable_thinking - if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - // When tools are present, build grammar for the format, similar to CommandR, but without tool call ID - if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = true; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - { "type", "object" }, - { "properties", - { - { "name", - { - { "type", "string" }, - { "const", function.at("name") }, - } }, - { "arguments", function.at("parameters") }, - } }, - { "required", json::array({ "name", "arguments" }) }, - }); - }); - auto schema = json{ - { "type", "array" }, - { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, - { "minItems", 1 }, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - "\"\" " + builder.add_schema("tool_calls", schema) + - " \"\""); - }); - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? - "[\\s\\S]*?(\\s*)" : - "(?:[\\s\\S]*?\\s*)?") + - "()[\\s\\S]*" }); - } - return data; -} - -static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED; - - // Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking - bool supports_reasoning = (tmpl.source().find("") != std::string::npos); - - // Handle thinking tags appropriately based on inputs.enable_thinking - if (supports_reasoning && string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - data.preserved_tokens = { - "", - "", - }; - - if (supports_reasoning) { - data.preserved_tokens.insert(data.preserved_tokens.end(), {"", ""}); - } - - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); - auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; - auto include_grammar = true; - - auto parser = build_chat_peg_constructed_parser([&](auto & p) { - auto reasoning = p.eps(); - if (supports_reasoning && inputs.enable_thinking && extract_reasoning) { - auto reasoning_content = p.reasoning(p.until("")) + ("" | p.end()); - if (data.thinking_forced_open) { - reasoning = reasoning_content; - } - } - - // Response format parser - if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { - return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema)); - } - - // Tool call parser - if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { - auto tool_choice = p.choice(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - - auto schema_info = common_schema_info(); - schema_info.resolve_refs(parameters); - - auto tool_open = "\n"; - auto tool_close = p.literal("\n"); - auto args = p.sequence(); - auto arg_string = p.rule("xml-arg-string", p.until_one_of({ - "\n", - "\n" - })); - - foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) { - auto rule_name = "tool-" + name + "-arg-" + param_name; - - auto arg_open = "\n"; - auto arg_close = p.literal("\n"); - auto arg_value = p.eps(); - - if (schema_info.resolves_to_string(param_schema)) { - arg_value = p.tool_arg_string_value(arg_string) + "\n"; - } else { - arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema)); - } - - // Model may or my not close with - auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close))); - args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1); - }); - - tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close)); - }); - - auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; - auto max_calls = inputs.parallel_tool_calls ? -1 : 1; - auto tool_call = p.rule("tool-call", "\n" + tool_choice + "" + p.space()); - auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls)); - - return reasoning << p.content(p.until("")) << tool_calls; - } - - // Content only parser - include_grammar = false; - return reasoning << p.content(p.rest()); - }); - - data.parser = parser.save(); - - if (include_grammar) { - data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; - - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - auto schema = function.at("parameters"); - builder.resolve_refs(schema); - }); - parser.build_grammar(builder, data.grammar_lazy); - }); - - data.grammar_triggers = { - {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ""} - }; - } - - return data; -} - - -static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Generate the prompt using the apply() function with the template - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_APERTUS; - - // Handle thinking tags appropriately based on inputs.enable_thinking - if (string_ends_with(data.prompt, "<|inner_prefix|>")) { - if (!inputs.enable_thinking) { - data.prompt += "<|inner_suffix|>"; - } else { - data.thinking_forced_open = true; - } - } - - // When tools are present, build grammar for the <|tools_prefix|> format - if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = true; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - { "type", "object" }, - { "properties", - { - { function.at("name"), function.at("parameters") } - } }, - { "required", json::array({ function.at("name") }) }, - }); - }); - auto schema = json{ - { "type", "array" }, - { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, - { "minItems", 1 }, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? " : "") + - "\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\""); - }); - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? - "[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" : - "(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") + - "(<\\|tools_prefix\\|>)[\\s\\S]*" }); - data.preserved_tokens = { - "<|system_start|>", - "<|system_end|>", - "<|developer_start|>", - "<|developer_end|>", - "<|user_start|>", - "<|user_end|>", - "<|assistant_start|>", - "<|assistant_end|>", - "<|inner_prefix|>", - "<|inner_suffix|>", - "<|tools_prefix|>", - "<|tools_suffix|>", - }; - } - return data; -} - -static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - auto prompt = apply(tmpl, inputs); - - // Hacks to fix the official (broken) prompt. - // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, - // until the official template is fixed. - if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) { - // Don't leave the chat dangling after tool results - if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) { - prompt += "<|end▁of▁sentence|>"; - if (inputs.add_generation_prompt) { - prompt += "<|Assistant|>"; - } - } - // Fix up tool call delta example added by Minja - prompt = std::regex_replace( - prompt, - std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"), - "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); - } - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; - if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "( \"<|tool▁call▁begin|>\" )? \"function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " - "\"```<|tool▁call▁end|>\"")); - }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " - "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + - "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" - }); - data.preserved_tokens = { - "", - "", - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>", - "<|tool▁sep|>", - "<|tool▁call▁end|>", - "<|tool▁calls▁end|", - }; - }); - } - return data; -} - -static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Pass thinking context for DeepSeek V3.1 template - json additional_context = { - {"thinking", inputs.enable_thinking}, - }; - - auto prompt = apply(tmpl, inputs, - /* messages_override= */ inputs.messages, - /* tools_override= */ std::nullopt, - additional_context); - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - if (string_ends_with(data.prompt, "")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "( \"<|tool▁call▁begin|>\" )? \"" + name + "<|tool▁sep|>" - "\" " + builder.add_schema(name + "-args", parameters) + " " - "\"<|tool▁call▁end|>\"")); - }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " - "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + - "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" - }); - data.preserved_tokens = { - "", - "", - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>", - "<|tool▁sep|>", - "<|tool▁call▁end|>", - "<|tool▁calls▁end|>", - }; - }); - } - return data; -} - -static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_MINIMAX_M2; - - // Handle thinking tags based on prompt ending - if (string_ends_with(data.prompt, "\n")) { - if (!params.enable_thinking) { - // Close the thinking tag immediately if thinking is disabled - data.prompt += "\n\n"; - } else { - // Mark thinking as forced open (template started with ) - data.thinking_forced_open = true; - } - } - - // Preserve MiniMax-M2 special tokens - data.preserved_tokens = { - "", - "", - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form { - /* form.scope_start = */ "\n", - /* form.tool_start = */ "\n", - /* form.key_start = */ "", - /* form.val_end = */ "\n", - /* form.tool_end = */ "\n", - /* form.scope_end = */ "", - }; - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_KIMI_K2; - - data.preserved_tokens = { - "", - "", - "<|tool_calls_section_begin|>", - "<|tool_call_begin|>", - "<|tool_call_argument_begin|>", - "<|tool_call_end|>", - "<|tool_calls_section_end|>", - "<|im_end|>", - "<|im_system|>", - "<|im_middle|>", - }; - - data.additional_stops.insert(data.additional_stops.end(), { - "<|im_end|>", - "<|im_middle|>" - }); - // build grammar for tool call - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "<|tool_calls_section_begin|>"; - form.tool_start = "<|tool_call_begin|>"; - form.tool_sep = "<|tool_call_argument_begin|>{"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}<|tool_call_end|>"; - form.scope_end = "<|tool_calls_section_end|>"; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_APRIEL_1_5; - - data.preserved_tokens = { - "", - "", - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO; - - data.preserved_tokens = { - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "\n"; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { +static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; // Copy reasoning to the "thinking" field as expected by the gpt-oss template auto adjusted_messages = json::array(); for (const auto & msg : inputs.messages) { auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); - auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); if (has_reasoning_content && has_tool_calls) { - auto adjusted_message = msg; + auto adjusted_message = msg; adjusted_message["thinking"] = msg.at("reasoning_content"); - adjusted_message.erase("content"); adjusted_messages.push_back(adjusted_message); } else { adjusted_messages.push_back(msg); } } - auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); + auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages); // Check if we need to replace the return token with end token during // inference and without generation prompt. For more details see: @@ -2049,896 +966,323 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp } } - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_GPT_OSS; + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; // These special tokens are required to parse properly, so we include them // even if parse_tool_calls is false. data.preserved_tokens = { - "<|channel|>", - "<|constrain|>", - "<|message|>", - "<|start|>", - "<|end|>", + "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>", }; - if (!inputs.json_schema.is_null()) { - data.grammar_lazy = false; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schema = inputs.json_schema; - builder.resolve_refs(schema); + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + auto include_grammar = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && has_tools; - auto not_end = builder.add_rule("not-end", - "[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]"); - auto analysis = builder.add_rule("analysis", - "\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\""); - auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+"); - auto final = builder.add_rule("final", - "\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " + - builder.add_schema("response", schema) - ); + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + const std::string END = "<|end|>"; + const std::string START = "<|start|>"; + const std::string MESSAGE = "<|message|>"; + const std::string CHANNEL = "<|channel|>"; + const std::string CONSTRAIN = "<|constrain|>"; + const std::string START_ASSISTANT = START + "assistant"; + const std::string CHANNEL_ANALYSIS = CHANNEL + "analysis"; + const std::string CHANNEL_COMMENTARY = CHANNEL + "commentary"; + const std::string CHANNEL_FINAL = CHANNEL + "final"; - builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final); - }); - } + auto the_end = END | p.end(); - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - // tool calls can appear in commentary or analysis channels - auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )"); + const std::string analysis_header = CHANNEL_ANALYSIS + MESSAGE; + auto segment_content = p.until(END); + auto analysis_segment = extract_reasoning ? + p.literal(analysis_header) + p.reasoning(segment_content) + p.until(END) + the_end : + p.content(analysis_header + p.until(END) + the_end); - std::vector tool_rules_recipient_in_role; - std::vector tool_rules_recipient_in_channel; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); + auto channel_header_content = p.until_one_of({ " to=functions.", MESSAGE }); + auto content_header = p.choice({ p.literal(CHANNEL_COMMENTARY), p.literal(CHANNEL_FINAL) }); + auto content_segment = p.rule("content-segment", content_header + channel_header_content + MESSAGE + + p.content(segment_content) + the_end); - tool_rules_recipient_in_role.push_back( - builder.add_rule(name + "-call", - "\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " + - builder.add_schema(name + "-args", parameters) - ) - ); - - tool_rules_recipient_in_channel.push_back( - builder.add_rule(name + "-call", - "\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " + - builder.add_schema(name + "-args", parameters) - ) - ); - }); - - auto recipient_in_channel = builder.add_rule("recipient_in_channel", - channel + " \" to=functions.\" ( " + - string_join(tool_rules_recipient_in_channel, " | ") + " )" - ); - - if (data.grammar_lazy) { - auto recipient_in_role = builder.add_rule("recipient_in_role", - "\"<|start|>assistant\"? \" to=functions.\" ( " + - string_join(tool_rules_recipient_in_role, " | ") + " )" - ); - - builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel); - } else { - auto not_end = builder.add_rule("not-end", - "[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]"); - auto analysis = builder.add_rule("analysis", - "\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\""); - auto commentary = builder.add_rule("commentary", - "\"<|channel|>commentary<|message|>\" ( " + not_end + " )* \"<|end|>\""); - - auto recipient_in_role = builder.add_rule("recipient_in_role", - "\" to=functions.\" ( " + string_join(tool_rules_recipient_in_role, " | ") + " )" - ); - - builder.add_rule("root", - "( " + analysis + " \"<|start|>assistant\" )? " + - "( " + commentary + " \"<|start|>assistant\" )? " + - "( " + recipient_in_role + " | " + recipient_in_channel + " )" - ); - } - - // Trigger on tool calls that appear in the commentary channel - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|channel\\|>(?:commentary|analysis) to" - }); - - // Trigger tool calls that appear in the role section, either at the - // start or in the middle. - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "^ to" - }); - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|start\\|>assistant to" - }); - }); - } - - return data; -} - -static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - std::string prompt = apply(tmpl, inputs); - - // match the existing trimming behavior - if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) { - prompt.erase(0, tmpl.bos_token().size()); - } - if (inputs.add_eos && string_ends_with(prompt, tmpl.eos_token())) { - prompt.erase(prompt.size() - tmpl.eos_token().size()); - } - if (string_ends_with(prompt, "")) { - if (!inputs.enable_thinking) { - prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - // add GLM preserved tokens - data.preserved_tokens = { - "<|endoftext|>", - "[MASK]", - "[gMASK]", - "[sMASK]", - "", - "", - "<|system|>", - "<|user|>", - "<|assistant|>", - "<|observation|>", - "<|begin_of_image|>", - "<|end_of_image|>", - "<|begin_of_video|>", - "<|end_of_video|>", - "<|begin_of_audio|>", - "<|end_of_audio|>", - "<|begin_of_transcription|>", - "<|end_of_transcription|>", - "<|code_prefix|>", - "<|code_middle|>", - "<|code_suffix|>", - "/nothink", - "", - "", - "", - "", - "", - "", - "", - "" - }; - - // extra GLM 4.5 stop word - data.additional_stops.insert(data.additional_stops.end(), { - "<|user|>", - "<|observation|>" - }); - - // build grammar for tool call - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "\n", - /* form.tool_sep = */ "\n", - /* form.key_start = */ "", - /* form.key_val_sep = */ "\n", - /* form.val_end = */ "\n", - /* form.tool_end = */ "\n", - /* form.scope_end = */ "", - }; - build_grammar_xml_tool_call(data, inputs.tools, form); - - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_GLM_4_5; - return data; -} - -static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { - LOG_DBG("%s\n", __func__); - common_chat_params data; - const std::optional additional_context = json { - {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, - {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, - }; - data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context); - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); - }); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["}); - data.preserved_tokens = { - " functools[", - }; - data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2; - } else { - data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - } - return data; -} - -static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { - // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... - // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar - // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. - common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector first_tool_rules; - std::vector subsequent_tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - std::string args_pattern = "[\\s\\S]*"; - auto args_rule = builder.add_schema(name + "-args", parameters); - if (name == "python") { - args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*"); - } else { - args_pattern = "\\{" + args_pattern; - } - auto call_rule = builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule); - first_tool_rules.push_back(call_rule); - if (inputs.parallel_tool_calls) { - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>\" " + call_rule)); - } - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "((?:[\\s\\S]+?>>>)?" + regex_escape(name) + "\n)" + args_pattern, - }); - }); - data.preserved_tokens = { - "<|end_header_id|>", - }; - auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; - if (inputs.parallel_tool_calls) { - auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; - builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); - } else { - builder.add_rule("root", first_rule); - } - - }); - } - return data; -} - -static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { - // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt - common_chat_params data; - - if (!inputs.tools.is_null()) { - std::string python_code_argument_name; - auto has_raw_python = false; - - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - const auto & parameters = function.at("parameters"); - std::string name = function.at("name"); - if (name == "python" || name == "ipython") { - if (!parameters.contains("type")) { - throw std::runtime_error("Missing type in python tool"); - } - has_raw_python = true; - const auto & type = parameters.at("type"); - if (type == "object") { - auto properties = parameters.at("properties"); - for (auto it = properties.begin(); it != properties.end(); ++it) { - if (it.value().at("type") == "string") { - if (!python_code_argument_name.empty()) { - throw std::runtime_error("Multiple string arguments found in python tool"); - } - python_code_argument_name = it.key(); - } - } - if (python_code_argument_name.empty()) { - throw std::runtime_error("No string argument found in python tool"); - } - } else if (type != "string") { - throw std::runtime_error("Invalid type in python tool: " + type.dump()); - } - } - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); - }); - if (has_raw_python) { - tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); - } - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; - builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "\n")) { - if (!extra_context["enable_thinking"]) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (!inputs.tools.is_null()) { - // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - std::vector tool_call_alts; - std::vector escaped_names; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); - tool_call_alts.push_back(builder.add_rule( - name + "-function-tag", - "\"\" space " + - builder.add_schema(name + "-args", parameters) + " " - "\"\" space")); - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "", - }); - auto escaped_name = regex_escape(name); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - " alt_tags { - any_tool_call, - "\"\" space " + any_tool_call + " \"\"", - // The rest is just to accommodate common "good bad" outputs. - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - }; - auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); - tool_call_alts.push_back(wrappable_tool_call); - tool_call_alts.push_back( - "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); - auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); - // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "(\\s*)" : "") + ( - "\\s*(" - "(?:" - "||||)?" - "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" - ")" - ")" - ), - }); - data.preserved_tokens = { - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "```", - "```json", - "```xml", - }; - }); - } - - return data; -} - -static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Pass thinking context for Granite template - json additional_context = { - {"thinking", inputs.enable_thinking}, - }; - - data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context); - data.format = COMMON_CHAT_FORMAT_GRANITE; - - if (string_ends_with(data.prompt, "\n") || string_ends_with(data.prompt, "")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (!inputs.tools.is_null()) { - // Granite uses <|tool_call|> followed by JSON list - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name + -"-args", { - {"type", "object"}, - {"properties", { - {"name", {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - }))); - }); - - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); - auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\""); - - if (data.thinking_forced_open) { - builder.add_rule("root", "\"\" space \"\" space [^<]* \"\" space \"<|tool_call|>\" space " + tool_list); - } else { - builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list); - } - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "<|tool_call|>" - }); - - data.preserved_tokens = { - "", - "", - "", - "", - "<|tool_call|>", - }; - }); - } else { - // Handle thinking tags for non-tool responses - if (data.thinking_forced_open && inputs.enable_thinking) { - data.grammar_lazy = false; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - builder.add_rule("root", "\"\" space \"\" space .* \"\" space"); - }); - data.preserved_tokens = { - "", - "", - "", - "", - }; - } - } - - return data; -} - -static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Copy `reasoning_content` to `reasoning` - auto adjusted_messages = json::array(); - for (const auto & msg : inputs.messages) { - if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { - auto adjusted_message = msg; - adjusted_message["reasoning"] = msg.at("reasoning_content"); - adjusted_message.erase("reasoning_content"); - adjusted_messages.push_back(adjusted_message); - } else { - adjusted_messages.push_back(msg); - } - } - - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); - auto include_grammar = true; - - auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); - - // Check if we need to replace the flush token with end token during inference and without generation prompt. - if (inputs.is_inference && !inputs.add_generation_prompt) { - static constexpr std::string_view return_token = "<|flush|>"; - static constexpr std::string_view end_token = "<|end|>"; - if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) { - prompt.replace(pos, return_token.length(), end_token); - } - } - - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { - "<|think|>", - "<|content|>", - "<|begin|>", - "<|end|>", - "<|tool_calls|>", - "<|tool_call:begin|>", - "<|tool_call:end|>", - "<|tool_call:name|>", - "<|tool_call:args|>", - }; - - auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { - auto lit_think = p.atomic(p.literal("<|think|>")); - auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant")); - auto lit_content = p.atomic(p.literal("<|content|>")); - auto lit_end = p.atomic(p.literal("<|end|>")); - auto parser_until_end = p.until("<|end|>"); - - // reasoning <- "<|think|>" (!"<|end|>" .)* - auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end)); - - // content <- "<|content|>" (!"<|end|>" .)* - auto parser_content = p.rule("content", lit_content + p.content(parser_until_end)); - - // wrap_choice(items) <- item-choice wrapped* - // item-choice <- items[0] / ... / items[n] - // wrapped <- "<|end|><|begin|>assistant" item-choice - auto wrap_choice = [&](const std::vector & items) { - auto choice = p.choice(items); - return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice); - }; - - // wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ... - auto wrap_seq = [&](const std::vector & items) { - auto seq = p.sequence(); - for (auto i = 0u; i < items.size(); i++) { - if (i == 0) { - seq += items[i]; - continue; - } - seq += lit_end + lit_assistant_begin + items[i]; - } - return seq; - }; - - // Response format parser - if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { - auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema)); - return p.choice({ - wrap_seq({parser_reasoning, parser_response_format}), - wrap_seq({parser_response_format}) - }); + if (!inputs.json_schema.is_null()) { + auto final_header = p.literal(CHANNEL_FINAL); + auto constraint = p.optional(p.space() + p.literal(CONSTRAIN) + channel_header_content); + return p.optional(analysis_segment) + final_header + constraint + MESSAGE + + p.content(p.schema(p.json(), "response-format", inputs.json_schema)); } - auto lit_tool_call_begin = p.literal("<|tool_call:begin|>"); - auto lit_tool_call_name = p.literal("<|tool_call:name|>"); - auto lit_tool_call_args = p.literal("<|tool_call:args|>"); - auto lit_tool_call_end = p.literal("<|tool_call:end|>"); + auto segment = p.optional(START_ASSISTANT + p.space()) + p.choice({ content_segment, analysis_segment }); + auto contents = p.optional(segment + p.repeat(p.optional(p.space()) + segment, 0, -1)) + p.end(); // Tool call parser if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { - auto parser_tool_call = p.choice(); + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - std::string name = function.at("name"); - const auto & schema = function.at("parameters"); + std::string name = function.at("name"); + const auto & params = function.at("parameters"); - // tool(name, schema) <- name "<|tool_call:args|>" schema - parser_tool_call |= p.rule("tool-" + name, - p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args) - + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); + // Tool call can appear as: + // 1. In role header: " to=functions.NAME<|channel|>..." + // 2. In channel: "<|channel|>(analysis|commentary) to=functions.NAME..." + auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name)); + + auto channel = p.literal(CHANNEL_COMMENTARY) | p.literal(CHANNEL_ANALYSIS); + auto constraint = p.space() + p.optional(p.literal(CONSTRAIN) + channel_header_content); + auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params)); + + // Pattern 1: recipient in role header + // " to=functions.NAME<|channel|>(analysis|commentary)[constraint]<|message|>ARGS" + auto tool_in_role = p.tool(p.tool_open(func_name + channel) + constraint + MESSAGE + args); + + // Pattern 2: recipient in channel header + // "<|channel|>(analysis|commentary) to=functions.NAME[constraint]<|message|>ARGS" + auto tool_in_channel = p.tool(channel + p.tool_open(func_name + constraint + MESSAGE) + args); + + tool_choice |= tool_in_role | tool_in_channel; }); auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; auto max_calls = inputs.parallel_tool_calls ? -1 : 1; - // tool-calls <- "<|tool_calls|>" tool-call+ - // tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>" - // call-id <- [a-zA-Z0-9_-]+ - // tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema) - auto parser_tool_calls = p.trigger_rule("tool-calls", - p.atomic(p.literal("<|tool_calls|>")) - + p.repeat( - p.tool_open( - lit_tool_call_begin - + p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1)) - + lit_tool_call_name - + p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args)) - + parser_tool_call - + p.tool_close(lit_tool_call_end), - /* min = */ 1, - /* max = */ max_calls)); + auto role_start = p.optional(p.space() + p.literal(START_ASSISTANT)); + auto tool_call = p.rule("tool-call", p.repeat(role_start + tool_choice, min_calls, max_calls) + p.end()); - if (min_calls == 1) { - // If required, then try any combination of the reasoning, content, and tool call - return p.choice({ - wrap_seq({parser_reasoning, parser_content, parser_tool_calls}), - wrap_seq({parser_reasoning, parser_tool_calls}), - wrap_seq({parser_content, parser_tool_calls}), - wrap_seq({parser_tool_calls}) - }); - } - - return wrap_choice({parser_reasoning, parser_content, parser_tool_calls}); + return p.choice({ p.trigger_rule("single-tool", tool_call), p.trigger_rule("tools", p.one_or_more(segment) + tool_call) }); } - // Content only parser - include_grammar = false; - return wrap_choice({parser_reasoning, parser_content}); + return contents; }); data.parser = parser.save(); if (include_grammar) { data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; - - data.grammar = build_grammar([&](const common_grammar_builder & builder) { + data.grammar = build_grammar([&](const common_grammar_builder & builder) { foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - auto schema = function.at("parameters"); + auto schema = function.at("parameters"); builder.resolve_refs(schema); }); parser.build_grammar(builder, data.grammar_lazy); }); data.grammar_triggers = { - {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"} + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" }, + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "(?:<\\|end\\|>)(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" }, + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + "(?:<\\|start\\|>assistant\\s*)?(<\\|channel\\|>(?:commentary|analysis)\\s+to=functions)" } }; } return data; } -static common_chat_params common_chat_params_init_exaone_moe(const common_chat_template & tmpl, const struct templates_params & inputs) { +// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content} +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_EXAONE_MOE; - if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += "\n\n"; - } else { - data.thinking_forced_open = true; - } - } + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { + ">>>all", + }; + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + // Functionary v3.2 format: + // - Normal content: >>>all\n{content} + // - Tool calls: >>>function_name\n{json_args} + // Generation prompt ends with ">>>" so model outputs recipient immediately + + // Build content parser for >>>all\n{content} + // When tools are present, content stops before the next ">>>" (tool call) + // When no tools, content goes until end + auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>")); + auto content_until_end = p.literal(">>>all\n") + p.content(p.rest()); + + // If no tools or tool_choice is NONE, just parse content + if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + // When no tools, just match the prefix and capture everything after + return content_until_end + p.end(); + } + + // Build tool call parsers for each available function + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + // Tool format: >>>function_name\n{json_args} + auto tool_parser = p.tool( + p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) + ); + + tool_choice |= p.rule("tool-" + name, tool_parser); + }); + + auto content_only = content_until_end; + auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice)); + auto content_and_tools = content_until_tool + tools_only; + + if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) { + if (inputs.parallel_tool_calls) { + return p.choice({ content_and_tools, tools_only }) + p.end(); + } + return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end(); + } + if (inputs.parallel_tool_calls) { + return p.choice({ content_and_tools, content_only, tools_only }) + p.end(); + } + auto content_and_tool = content_until_tool + tool_choice; + return p.choice({ content_and_tool, content_only, tool_choice }) + p.end(); + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - // Expect: {"name": "", "arguments": {...}} - tool_rules.push_back(builder.add_rule( - name + "-call", - "\"\" space " + - builder.add_schema(name + "-obj", json{ - {"type", "object"}, - {"properties", { - {"name", json{{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - }) + - " space \"\" space")); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); }); - - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)?" : "") + - "()[\\s\\S]*" - }); - data.preserved_tokens = { - "", - "", - "", - "", - }; + parser.build_grammar(builder, data.grammar_lazy); }); + + // Grammar trigger for when the model starts outputting a tool call + // (after the initial ">>>" in the generation prompt but recipient other than "all") + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, ">>>(?!all)" } + }; } return data; } -static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) { +// Kimi K2 Thinking - uses unique tool call ID format: functions.: +// The ID contains both the function name and an incrementing counter +static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; - // This template does not support tools or reasoning - // we just need to transform the messages into the correct schema + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.preserved_tokens = { + "<|tool_calls_section_begin|>", + "<|tool_calls_section_end|>", + "<|tool_call_begin|>", + "<|tool_call_argument_begin|>", + "<|tool_call_end|>", + "", + "", + }; - templates_params inputs_new = inputs; - json & messages = inputs_new.messages; + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; - // default to chat_template_kwargs, or en-GB if not specified - std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB"); - std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB"); + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + // Kimi K2 Thinking format: + // - Reasoning: {reasoning} + // - Content: text after reasoning + // - Tool calls section: + // <|tool_calls_section_begin|> + // <|tool_call_begin|>functions.:<|tool_call_argument_begin|>{json_args}<|tool_call_end|> + // ... + // <|tool_calls_section_end|> + // The ID format is: functions.: where counter is 0, 1, 2, ... - GGML_ASSERT(messages.is_array()); - for (auto & message : messages) { - if (message.contains("role") && message["role"].get() != "user") { - continue; + // Tool call markers + const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>"; + const std::string SECTION_END = "<|tool_calls_section_end|>"; + const std::string CALL_BEGIN = "<|tool_call_begin|>"; + const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>"; + const std::string CALL_END = "<|tool_call_end|>"; + + const std::string THINK_START = ""; + const std::string THINK_END = ""; + + auto end = p.end(); + + // Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny. + // For example, it can call tools at the end of reasoning without closing reasoning... + auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning( + p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + + p.optional(p.literal(THINK_END))) : p.eps(); + + + // Content only parser (no tools) + if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return reasoning + p.content(p.rest()) + end; } - if (!message.contains("content")) { - message["content"] = json::array(); - } - if (message.contains("content") && !message["content"].is_array()) { - auto content_str = message["content"].get(); - // default to en-GB if not specified (to make common_chat_format_example works) - auto src_lang = message.contains("source_lang_code") - ? message["source_lang_code"].get() : default_src_lang; - auto tgt_lang = message.contains("target_lang_code") - ? message["target_lang_code"].get() : default_tgt_lang; - message["content"] = json::array({ - json{ - {"type", "text"}, - {"text", content_str}, - {"source_lang_code", src_lang}, - {"target_lang_code", tgt_lang}, - } - }); - } - } - data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt); - data.format = COMMON_CHAT_FORMAT_GENERIC; + // Build tool call parsers for each available function + // The ID format is: functions.: + // We need to match: functions.: + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); - return data; -} + // Match: functions.: + // Capture the full call id (functions.:) using tool_id tag + auto tool_id = p.tool_id(p.literal("functions.") + p.tool_name(p.literal(name)) + p.literal(":") + p.chars("[0-9]", 1, -1)); + auto tool_parser = p.tool( + p.tool_open(tool_id + p.literal(ARGS_BEGIN)) + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) + + p.tool_close(p.optional((p.literal(CALL_END)))) + ); -static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - data.grammar_lazy = false; - if (!inputs.json_schema.is_null()) { - if (!inputs.grammar.empty()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } - data.grammar = json_schema_to_grammar(inputs.json_schema); - } else { - data.grammar = inputs.grammar; - } - return data; -} + tool_choice |= p.rule("tool-" + name, tool_parser); + }); -static common_chat_params common_chat_params_init_seed_oss( - const common_chat_template & tmpl, - templates_params & params, - const common_chat_templates_inputs & inputs) -{ - common_chat_params data; - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_SEED_OSS; - if (string_ends_with(data.prompt, "")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } + // Tool calls section: <|tool_calls_section_begin|> tool_calls <|tool_calls_section_end|> + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; + // Use trigger_rule so grammar generator knows where to start generating rules + auto tool_calls = p.rule("tool-calls", + p.optional(p.literal(SECTION_BEGIN)) + + p.trigger_rule("tool-call", p.repeat(CALL_BEGIN + tool_choice, min_calls, max_calls) + + p.optional(p.literal(SECTION_END))) + ); - if (params.tools.is_array() && !params.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN })); + + return reasoning + content_before_tools + tool_calls + end; + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(params.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - // Create rule for Seed-OSS function call format - std::string param_rules; - if (parameters.contains("properties")) { - for (const auto & [key, value] : parameters.at("properties").items()) { - param_rules += "\"\"" + builder.add_schema(name + "-arg-" + key, value) + - "\"\""; - } - } - - tool_rules.push_back(builder.add_rule(name + "-call", - "\"\" space \"\" space " + - param_rules + - " \"\" space \"\"")); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); }); - - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "" }); - - data.preserved_tokens = { - "", "", "", "", - "", "", - }; - - builder.add_rule("root", string_join(tool_rules, " | ")); + parser.build_grammar(builder, data.grammar_lazy); }); + + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call_begin|>" } + }; } + return data; } -// various workarounds for known issues with certain templates or model behaviors -// TODO @ngxson : improve this (how?) namespace workaround { // if first message is system and template does not support it, merge it with next message @@ -2958,6 +1302,15 @@ static void system_message_not_supported(json & messages) { } } +static void requires_non_null_content(json & messages) { + GGML_ASSERT(messages.is_array()); + for (auto & message : messages) { + if (message.contains("tool_calls") && !message.contains("content")) { + message["content"] = ""; + } + } +} + static void func_args_not_string(json & messages) { GGML_ASSERT(messages.is_array()); for (auto & message : messages) { @@ -2978,71 +1331,21 @@ static void func_args_not_string(json & messages) { } } -static void move_tool_calls_to_content(json & messages, int indent_spaces = 2) { - GGML_ASSERT(messages.is_array()); - for (auto & message : messages) { - if (message.contains("tool_calls")) { - auto tool_calls_new = json{ - {"tool_calls", message.at("tool_calls")} - }; - message.erase("tool_calls"); - auto content = message.at("content"); - std::string content_new = content.is_null() ? "" : content.get(); - message["content"] = content_new + tool_calls_new.dump(indent_spaces, ' ', false, json::error_handler_t::replace); - } - } } -// TODO @ngxson : we may remove support for generic schema in the future -static void use_generic_schema(json & messages) { - GGML_ASSERT(messages.is_array()); - for (auto & message : messages) { - if (message.contains("tool_calls") && message.at("tool_calls").is_array()) { - auto & tool_calls = message.at("tool_calls"); - for (auto & tool_call : tool_calls) { - if (tool_call.contains("type") && tool_call.at("type") == "function" && - tool_call.contains("function") && tool_call.at("function").is_object()) { - // Copy values before erasing to avoid use-after-free - json name_value; - json arguments_value; - json id_value; - const auto & function = tool_call.at("function"); - if (function.contains("name")) { - name_value = function.at("name"); - } - if (function.contains("arguments")) { - arguments_value = function.at("arguments"); - } - if (tool_call.contains("id")) { - id_value = tool_call.at("id"); - } - // Now safely erase and assign in the correct order - tool_call.erase("type"); - tool_call.erase("function"); - tool_call.erase("id"); - // Reassign in desired order: name, arguments, id - if (!name_value.is_null()) { - tool_call["name"] = name_value; - } - if (!arguments_value.is_null()) { - tool_call["arguments"] = arguments_value; - } - if (!id_value.is_null()) { - tool_call["id"] = id_value; - } - } - } - } - } +static json common_chat_extra_context() { + json ctx = json::object(); + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::string datetime_str = format_time(now, "%b %d %Y"); + std::string date_str = format_time(now, "%d %b %Y"); + ctx["datetime"] = datetime_str; + ctx["date_string"] = date_str; + return ctx; } -} // namespace workaround - -static common_chat_params common_chat_templates_apply_jinja( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) -{ - templates_params params; +static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { + autoparser::templates_params params; params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use @@ -3063,7 +1366,14 @@ static common_chat_params common_chat_templates_apply_jinja( workaround::system_message_not_supported(params.messages); } - params.extra_context = json::object(); + if (tmpl.original_caps().supports_tool_calls) { + // some templates will require the content field in tool call messages + // to still be non-null, this puts an empty string everywhere where the + // content field is null + workaround::requires_non_null_content(params.messages); + } + + params.extra_context = common_chat_extra_context(); for (auto el : inputs.chat_template_kwargs) { params.extra_context[el.first] = json::parse(el.second); } @@ -3072,229 +1382,71 @@ static common_chat_params common_chat_templates_apply_jinja( params.json_schema = json::parse(inputs.json_schema); } - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - params.parallel_tool_calls = false; - } else { - params.parallel_tool_calls = inputs.parallel_tool_calls; - } + // if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + // LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + // params.parallel_tool_calls = false; + // } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + //} if (params.tools.is_array()) { if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { throw std::runtime_error("Cannot specify grammar with tools"); } if (caps.supports_tool_calls && !caps.supports_tools) { - LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); + LOG_WRN( + "Template supports tool calls but does not natively describe tools. The fallback behaviour used may " + "produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); } } - // DeepSeek V3.1: detect based on specific patterns in the template - if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos && - params.json_schema.is_null()) { - return common_chat_params_init_deepseek_v3_1(tmpl, params); - } - - // DeepSeek R1: use handler in all cases except json schema (thinking / tools). - if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { - return common_chat_params_init_deepseek_r1(tmpl, params); - } - - // Command R7B: : use handler in all cases except json schema (thinking / tools). - if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_command_r7b(tmpl, params); - } - - // Granite (IBM) - detects thinking / tools support - if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) { - workaround::func_args_not_string(params.messages); - workaround::use_generic_schema(params.messages); - workaround::move_tool_calls_to_content(params.messages); - return common_chat_params_init_granite(tmpl, params); - } - - // GLM 4.5: detect by and tags (check before Hermes since both use ) - if (src.find("[gMASK]") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - params.json_schema.is_null()) { - workaround::func_args_not_string(params.messages); - if (!params.extra_context.contains("clear_thinking")) { - // by default, do not clear reasoning_content (added since GLM-4.7) - params.extra_context["clear_thinking"] = false; - } - return common_chat_params_init_glm_4_5(tmpl, params); - } - - // Qwen3-Coder XML format detection (must come before Hermes 2 Pro) - // Detect via XML markers: , , and blocks. - // Also matches Step-3.5-Flash and Nemotron 3 Nano which use the same output format. - if (src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("# Tools") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos) { - return common_chat_params_init_xiaomi_mimo(tmpl, params); - } - - // EXAONE MoE format detection - if (src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("<|tool_declare|>") != std::string::npos) { - return common_chat_params_init_exaone_moe(tmpl, params); - } - - // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos && params.json_schema.is_null()) { - return common_chat_params_init_hermes_2_pro(tmpl, params); - } - - // GPT-OSS - if (src.find("<|channel|>") != std::string::npos) { - return common_chat_params_init_gpt_oss(tmpl, params); - } - - // Seed-OSS - if (src.find("") != std::string::npos) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_seed_oss(tmpl, params, inputs); - } - - // Nemotron v2 - if (src.find("") != std::string::npos) { - return common_chat_params_init_nemotron_v2(tmpl, params); - } - - // Apertus format detection - if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) { - return common_chat_params_init_apertus(tmpl, params); - } - - // LFM2 (w/ tools) - if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && - src.find("]<|tool_list_end|>") != std::string::npos) { - return common_chat_params_init_lfm2(tmpl, params); - } - - // MiniMax-M2 format detection - if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_minimax_m2(tmpl, params); - } - - // Kimi K2 format detection - if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos && - src.find("<|tool_calls_section_begin|>") != std::string::npos && - src.find("## Return of") != std::string::npos) { - return common_chat_params_init_kimi_k2(tmpl, params); - } - - // Apriel 1.5 format detection - if (src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("<|assistant|>") != std::string::npos && - src.find("<|tool_result|>") != std::string::npos && - src.find("[") != std::string::npos && - src.find("]") != std::string::npos) { - return common_chat_params_init_apriel_1_5(tmpl, params); - } - - // Solar Open - if (src.find("<|tool_response:begin|>") != std::string::npos && - src.find("<|tool_response:name|>") != std::string::npos && - src.find("<|tool_response:result|>") != std::string::npos) { - return common_chat_params_init_solar_open(tmpl, params); - } - - // Use generic handler when mixing tools + JSON schema. - // TODO: support that mix in handlers below. - if ((params.tools.is_array() && params.json_schema.is_object())) { - return common_chat_params_init_generic(tmpl, params); - } - - // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. - if (src.find(">>>all") != std::string::npos) { - return common_chat_params_init_functionary_v3_2(tmpl, params); - } - - // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. - if (src.find(" functools[") != std::string::npos) { - return common_chat_params_init_firefunction_v2(tmpl, params); - } - - // Functionary v3.1 (w/ tools) - if (src.find("<|start_header_id|>") != std::string::npos - && src.find("ipython<|end_header_id|>") != std::string::npos) { - auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - workaround::func_args_not_string(params.messages); - return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); - } - - // Ministral/Mistral Large 3 - if (src.find("[SYSTEM_PROMPT]") != std::string::npos && - src.find("[TOOL_CALLS]") != std::string::npos && - src.find("[ARGS]") != std::string::npos) { + // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser + // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them + if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos && + src.find("[ARGS]") != std::string::npos && src.find("[CALL_ID]") == std::string::npos) { + LOG_DBG("Using specialized template: Ministral/Magistral Large 3\n"); return common_chat_params_init_ministral_3(tmpl, params); } - if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { - return common_chat_params_init_magistral(tmpl, params); + // GPT-OSS - has unique channel-based structure that needs dedicated handler + if (src.find("<|channel|>") != std::string::npos) { + LOG_DBG("Using specialized template: GPT-OSS\n"); + return common_chat_params_init_gpt_oss(tmpl, params); } - // Solar Open - if (src.find("<|tool_response:begin|>") != std::string::npos && - src.find("<|tool_response:name|>") != std::string::npos && - src.find("<|tool_response:result|>") != std::string::npos) { - return common_chat_params_init_solar_open(tmpl, params); + // Functionary v3.2 - uses recipient-based format with >>>recipient\n{content} + // Detection: template has ">>>all" for content and ">>>" prefix for tool calls + if (src.find(">>>all") != std::string::npos && src.find(">>>${recipient}") != std::string::npos) { + LOG_DBG("Using specialized template: Functionary v3.2\n"); + return common_chat_params_init_functionary_v3_2(tmpl, params); } - // TranslateGemma - if (src.find("[source_lang_code]") != std::string::npos && - src.find("[target_lang_code]") != std::string::npos) { - return common_chat_params_init_translate_gemma(tmpl, params); + // Kimi K2 Thinking - uses unique tool call ID format: functions.: + // Detection: template has "<|tool_calls_section_begin|>" and "functions." prefix in tool call IDs + if (src.find("<|tool_calls_section_begin|>") != std::string::npos && + src.find("<|tool_call_begin|>") != std::string::npos) { + LOG_DBG("Using specialized template: Kimi K2 Thinking\n"); + return common_chat_params_init_kimi_k2(tmpl, params); } - // Plain handler (no tools) - if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return common_chat_params_init_without_tools(tmpl, params); + try { + LOG_DBG("Using differential autoparser\n"); + struct autoparser::autoparser autoparser; + autoparser.analyze_template(tmpl); + auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); + auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; + return auto_params; + } catch (const std::exception & e) { + throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); } - - // Mistral Nemo (w/ tools) - if (src.find("[TOOL_CALLS]") != std::string::npos) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_mistral_nemo(tmpl, params); - } - - // Generic fallback - workaround::func_args_not_string(params.messages); - workaround::use_generic_schema(params.messages); - workaround::move_tool_calls_to_content(params.messages); - return common_chat_params_init_generic(tmpl, params); } // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. -static common_chat_params common_chat_templates_apply_legacy( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) -{ - size_t alloc_size = 0; +static common_chat_params common_chat_templates_apply_legacy(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { + size_t alloc_size = 0; std::vector chat; - std::vector contents; + std::vector contents; for (const auto & msg : inputs.messages) { auto content = msg.content; @@ -3304,25 +1456,27 @@ static common_chat_params common_chat_templates_apply_legacy( continue; } if (!content.empty()) { - content += "\n";; + content += "\n"; + ; } content += part.text; } contents.emplace_back(std::move(content)); } for (size_t i = 0; i < contents.size(); ++i) { - const auto & msg = inputs.messages[i]; + const auto & msg = inputs.messages[i]; const auto & content = contents[i]; - chat.push_back({msg.role.c_str(), content.c_str()}); + chat.push_back({ msg.role.c_str(), content.c_str() }); size_t msg_size = msg.role.size() + content.size(); - alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops + alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops } std::vector buf(alloc_size); // run the first time to get the total output length const auto & src = tmpls->template_default->source(); - int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, + buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { @@ -3334,7 +1488,8 @@ static common_chat_params common_chat_templates_apply_legacy( // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), + buf.size()); } // for safety, we check the result again @@ -3352,14 +1507,72 @@ static common_chat_params common_chat_templates_apply_legacy( return params; } -common_chat_params common_chat_templates_apply( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) -{ +common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { GGML_ASSERT(tmpls != nullptr); - return inputs.use_jinja - ? common_chat_templates_apply_jinja(tmpls, inputs) - : common_chat_templates_apply_legacy(tmpls, inputs); + return inputs.use_jinja ? common_chat_templates_apply_jinja(tmpls, inputs) : + common_chat_templates_apply_legacy(tmpls, inputs); +} + +common_chat_msg common_chat_parse(const std::string & input, + bool is_partial, + const common_chat_parser_params & params) { + return common_chat_peg_parse(params.parser, input, is_partial, params); +} + +common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, + const std::string & input, + bool is_partial, + const common_chat_parser_params & params) { + const common_peg_arena & parser = src_parser.empty() ? + build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) : + src_parser; + + if (src_parser.empty()) { + LOG_WRN("No parser definition detected, assuming pure content parser."); + } + + LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str()); + + common_peg_parse_context ctx(input, is_partial); + ctx.debug = params.debug; + auto result = parser.parse(ctx); + + if (result.fail()) { + // During partial parsing, return partial results if any AST nodes were captured + // This allows streaming to work correctly for formats like FUNC_MARKDOWN_CODE_BLOCK + if (is_partial && result.end > 0) { + // Try to extract any partial results from what was successfully parsed + common_chat_msg msg; + msg.role = "assistant"; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (ctx.debug) { + fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str()); + fflush(stderr); + } + return msg; + } + throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " + + input.substr(result.end)); + } + + common_chat_msg msg; + msg.role = "assistant"; + + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (ctx.debug) { + fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str()); + fflush(stderr); + } + + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({ msg }).at(0).dump().c_str()); + } + return msg; } std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates) { @@ -3367,3 +1580,4 @@ std::map common_chat_templates_get_caps(const common_chat_tem GGML_ASSERT(chat_templates->template_default != nullptr); return chat_templates->template_default->caps.to_map(); } + diff --git a/common/chat.h b/common/chat.h index 6f0b9409e..005cc5c8b 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,17 +3,30 @@ #pragma once #include "common.h" +#include "jinja/parser.h" +#include "nlohmann/json_fwd.hpp" #include "peg-parser.h" -#include +#include "jinja/runtime.h" +#include "jinja/caps.h" +#include "nlohmann/json.hpp" + #include +#include +#include #include #include -#include + +using chat_template_caps = jinja::caps; +using json = nlohmann::ordered_json; #include struct common_chat_templates; +namespace autoparser { +struct templates_params; +} // namespace autoparser + struct common_chat_tool_call { std::string name; std::string arguments; @@ -38,21 +51,85 @@ struct common_chat_msg_content_part { } }; +struct common_chat_template { + jinja::program prog; + std::string bos_tok; + std::string eos_tok; + std::string src; + chat_template_caps caps; + + common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { + jinja::lexer lexer; + auto lexer_res = lexer.tokenize(src); + this->prog = jinja::parse_from_tokens(lexer_res); + + this->src = lexer_res.source; + this->bos_tok = bos_token; + this->eos_tok = eos_token; + + this->caps = jinja::caps_get(prog); + // LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str()); + } + + const std::string & source() const { return src; } + const std::string & bos_token() const { return bos_tok; } + const std::string & eos_token() const { return eos_tok; } + + // TODO: this is ugly, refactor it somehow + json add_system(const json & messages, const std::string & system_prompt) const { + GGML_ASSERT(messages.is_array()); + auto msgs_copy = messages; + if (!caps.supports_system_role) { + if (msgs_copy.empty()) { + msgs_copy.insert(msgs_copy.begin(), json{ + {"role", "user"}, + {"content", system_prompt} + }); + } else { + auto & first_msg = msgs_copy[0]; + if (!first_msg.contains("content")) { + first_msg["content"] = ""; + } + first_msg["content"] = system_prompt + "\n\n" + + first_msg["content"].get(); + } + } else { + if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") { + msgs_copy.insert(msgs_copy.begin(), json{ + {"role", "system"}, + {"content", system_prompt} + }); + } else if (msgs_copy[0].at("role") == "system") { + msgs_copy[0]["content"] = system_prompt; + } + } + return msgs_copy; + } + + chat_template_caps original_caps() const { + return caps; + } + +}; + struct common_chat_msg { - std::string role; - std::string content; + std::string role; + std::string content; std::vector content_parts; - std::vector tool_calls; - std::string reasoning_content; - std::string tool_name; - std::string tool_call_id; + std::vector tool_calls; + std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const; bool empty() const { - return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && + tool_name.empty() && tool_call_id.empty(); } - void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { + + void set_tool_call_ids(std::vector & ids_cache, + const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { auto id = tool_calls[i].id; @@ -64,32 +141,28 @@ struct common_chat_msg { tool_calls[i].id = ids_cache[i]; } } + bool operator==(const common_chat_msg & other) const { - return role == other.role - && content == other.content - && content_parts == other.content_parts - && tool_calls == other.tool_calls - && reasoning_content == other.reasoning_content - && tool_name == other.tool_name - && tool_call_id == other.tool_call_id; - } - bool operator!=(const common_chat_msg & other) const { - return !(*this == other); + return role == other.role && content == other.content && content_parts == other.content_parts && + tool_calls == other.tool_calls && reasoning_content == other.reasoning_content && + tool_name == other.tool_name && tool_call_id == other.tool_call_id; } + + bool operator!=(const common_chat_msg & other) const { return !(*this == other); } }; struct common_chat_msg_diff { - std::string reasoning_content_delta; - std::string content_delta; - size_t tool_call_index = std::string::npos; + std::string reasoning_content_delta; + std::string content_delta; + size_t tool_call_index = std::string::npos; common_chat_tool_call tool_call_delta; - static std::vector compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new); + static std::vector compute_diffs(const common_chat_msg & msg_prv, + const common_chat_msg & msg_new); bool operator==(const common_chat_msg_diff & other) const { - return content_delta == other.content_delta - && tool_call_index == other.tool_call_index - && tool_call_delta == other.tool_call_delta; + return content_delta == other.content_delta && tool_call_index == other.tool_call_index && + tool_call_delta == other.tool_call_delta; } }; @@ -107,64 +180,39 @@ enum common_chat_tool_choice { enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, - COMMON_CHAT_FORMAT_GENERIC, - COMMON_CHAT_FORMAT_MISTRAL_NEMO, - COMMON_CHAT_FORMAT_MAGISTRAL, - COMMON_CHAT_FORMAT_LLAMA_3_X, - COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - COMMON_CHAT_FORMAT_DEEPSEEK_R1, - COMMON_CHAT_FORMAT_FIREFUNCTION_V2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - COMMON_CHAT_FORMAT_HERMES_2_PRO, - COMMON_CHAT_FORMAT_COMMAND_R7B, - COMMON_CHAT_FORMAT_GRANITE, - COMMON_CHAT_FORMAT_GPT_OSS, - COMMON_CHAT_FORMAT_SEED_OSS, - COMMON_CHAT_FORMAT_NEMOTRON_V2, - COMMON_CHAT_FORMAT_APERTUS, - COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, - COMMON_CHAT_FORMAT_GLM_4_5, - COMMON_CHAT_FORMAT_MINIMAX_M2, - COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_CHAT_FORMAT_APRIEL_1_5, - COMMON_CHAT_FORMAT_XIAOMI_MIMO, - COMMON_CHAT_FORMAT_SOLAR_OPEN, - COMMON_CHAT_FORMAT_EXAONE_MOE, // These are intended to be parsed by the PEG parser COMMON_CHAT_FORMAT_PEG_SIMPLE, COMMON_CHAT_FORMAT_PEG_NATIVE, - COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, - COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; struct common_chat_templates_inputs { - std::vector messages; - std::string grammar; - std::string json_schema; - bool add_generation_prompt = true; - bool use_jinja = true; + std::vector messages; + std::string grammar; + std::string json_schema; + bool add_generation_prompt = true; + bool use_jinja = true; // Parameters below only supported when use_jinja is true - std::vector tools; - common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; - bool parallel_tool_calls = false; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking" - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - std::map chat_template_kwargs; - bool add_bos = false; - bool add_eos = false; + std::vector tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool parallel_tool_calls = false; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking" + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::map chat_template_kwargs; + bool add_bos = false; + bool add_eos = false; }; struct common_chat_params { common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; std::string prompt; std::string grammar; - bool grammar_lazy = false; + bool grammar_lazy = false; bool thinking_forced_open = false; + bool supports_thinking = false; std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; @@ -174,13 +222,14 @@ struct common_chat_params { // per-message parsing syntax // should be derived from common_chat_params struct common_chat_parser_params { - common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) - bool reasoning_in_content = false; - bool thinking_forced_open = false; - bool parse_tool_calls = true; - common_peg_arena parser = {}; + bool reasoning_in_content = false; + bool thinking_forced_open = false; + bool parse_tool_calls = true; + bool debug = false; // Enable debug output for PEG parser + common_peg_arena parser = {}; common_chat_parser_params() = default; common_chat_parser_params(const common_chat_params & chat_params) { format = chat_params.format; @@ -193,45 +242,42 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); void common_chat_templates_free(struct common_chat_templates * tmpls); -struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } }; +struct common_chat_templates_deleter { + void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } +}; typedef std::unique_ptr common_chat_templates_ptr; -common_chat_templates_ptr common_chat_templates_init( - const struct llama_model * model, - const std::string & chat_template_override, - const std::string & bos_token_override = "", - const std::string & eos_token_override = ""); +common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override = "", + const std::string & eos_token_override = ""); bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = ""); - -struct common_chat_params common_chat_templates_apply( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs); +struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs); // Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single( - const struct common_chat_templates * tmpls, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja); +std::string common_chat_format_single(const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); // Returns an example of formatted chat -std::string common_chat_format_example( - const struct common_chat_templates * tmpls, - bool use_jinja, - const std::map & chat_template_kwargs); +std::string common_chat_format_example(const struct common_chat_templates * tmpls, + bool use_jinja, + const std::map & chat_template_kwargs); -const char* common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax); -common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax); +const char * common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params); +common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params); // used by arg and server -const char * common_reasoning_format_name(common_reasoning_format format); -common_reasoning_format common_reasoning_format_from_name(const std::string & format); +const char * common_reasoning_format_name(common_reasoning_format format); +common_reasoning_format common_reasoning_format_from_name(const std::string & format); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); @@ -250,3 +296,10 @@ nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_ // get template caps, useful for reporting to server /props endpoint std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates); + +std::string common_chat_template_direct_apply( + const common_chat_template & tmpl, + const autoparser::templates_params & inputs, + const std::optional & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt); diff --git a/common/common.cpp b/common/common.cpp index 3687f6b57..3e36f9691 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -683,7 +683,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { size_t offset = 0; while (offset < filename.size()) { - utf8_parse_result result = parse_utf8_codepoint(filename, offset); + utf8_parse_result result = common_parse_utf8_codepoint(filename, offset); if (result.status != utf8_parse_result::SUCCESS) { return false; diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index dbaaed500..1158d5e5d 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -1,3 +1,4 @@ +#include "log.h" #include "value.h" #include "runtime.h" #include "caps.h" @@ -36,12 +37,16 @@ static void caps_try_execute(jinja::program & prog, auto tools = ctx.get_val("tools"); bool success = false; + std::string result; try { jinja::runtime runtime(ctx); - runtime.execute(prog); + auto results = runtime.execute(prog); + auto parts = jinja::runtime::gather_string_parts(results); + result = parts->as_string().str(); success = true; } catch (const std::exception & e) { JJ_DEBUG("Exception during execution: %s", e.what()); + result = ""; // ignore exceptions during capability analysis } @@ -90,6 +95,8 @@ caps caps_get(jinja::program & prog) { return v->stats.ops.find(op_name) != v->stats.ops.end(); }; + JJ_DEBUG("%s\n", ">>> Running capability check: typed content"); + // case: typed content support caps_try_execute( prog, @@ -120,6 +127,7 @@ caps caps_get(jinja::program & prog) { } ); + JJ_DEBUG("%s\n", ">>> Running capability check: system prompt"); // case: system prompt support caps_try_execute( @@ -150,7 +158,9 @@ caps caps_get(jinja::program & prog) { } ); - // case: tools support + JJ_DEBUG("%s\n", ">>> Running capability check: single tool support"); + + // case: tools support: single call caps_try_execute( prog, [&]() { @@ -162,10 +172,10 @@ caps caps_get(jinja::program & prog) { }, { {"role", "assistant"}, - {"content", "Assistant message"}, + {"content", ""}, // Some templates expect content to be empty with tool calls {"tool_calls", json::array({ { - {"id", "call1"}, + {"id", "call00001"}, {"type", "function"}, {"function", { {"name", "tool1"}, @@ -173,19 +183,18 @@ caps caps_get(jinja::program & prog) { {"arg", "value"} }} }} - }, - { - {"id", "call2"}, - {"type", "function"}, - {"function", { - {"name", "tool2"}, - {"arguments", { - {"arg", "value"} - }} - }} } })} }, + { + {"role", "tool"}, + {"content", "Tool response"}, + {"tool_call_id", "call00001"} + }, + { + {"role", "assistant"}, + {"content", "The tool response was 'tool response'"} + }, { {"role", "user"}, {"content", "User message"}, @@ -199,7 +208,7 @@ caps caps_get(jinja::program & prog) { {"name", "tool"}, {"type", "function"}, {"function", { - {"name", "tool"}, + {"name", "tool1"}, {"description", "Tool description"}, {"parameters", { {"type", "object"}, @@ -224,6 +233,7 @@ caps caps_get(jinja::program & prog) { auto & tool_name = tools->at(0)->at("function")->at("name"); caps_print_stats(tool_name, "tools[0].function.name"); + caps_print_stats(tools, "tools"); if (!tool_name->stats.used) { result.supports_tools = false; } @@ -233,6 +243,93 @@ caps caps_get(jinja::program & prog) { if (!tool_calls->stats.used) { result.supports_tool_calls = false; } + } + ); + + JJ_DEBUG("%s\n", ">>> Running capability check: parallel tool support"); + + // case: tools support: parallel calls + caps_try_execute( + prog, + [&]() { + // messages + return json::array({ + { + {"role", "user"}, + {"content", "User message"}, + }, + { + {"role", "assistant"}, + {"content", ""}, // Some templates expect content to be empty with tool calls + {"tool_calls", json::array({ + { + {"id", "call00001"}, + {"type", "function"}, + {"function", { + {"name", "tool1"}, + {"arguments", { + {"arg", "value"} + }} + }} + }, + { + {"id", "call00002"}, + {"type", "function"}, + {"function", { + {"name", "tool1"}, + {"arguments", { + {"arg", "value"} + }} + }} + } + })} + }, + { + {"role", "tool"}, + {"content", "Tool response"}, + {"tool_call_id", "call00001"} + }, + { + {"role", "assistant"}, + {"content", "The tool response was 'tool response'"} + }, + { + {"role", "user"}, + {"content", "User message"}, + }, + }); + }, + [&]() { + // tools + return json::array({ + { + {"name", "tool"}, + {"type", "function"}, + {"function", { + {"name", "tool1"}, + {"description", "Tool description"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"arg", { + {"type", "string"}, + {"description", "Arg description"}, + }}, + }}, + {"required", json::array({ "arg" })}, + }}, + }}, + }, + }); + }, + [&](bool success, value & messages, value & /*tools*/) { + if (!success) { + result.supports_parallel_tool_calls = false; + return; + } + + auto & tool_calls = messages->at(1)->at("tool_calls");; + caps_print_stats(tool_calls, "messages[1].tool_calls"); // check for second tool call usage auto & tool_call_1 = tool_calls->at(1)->at("function"); @@ -243,6 +340,8 @@ caps caps_get(jinja::program & prog) { } ); + JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning"); + // case: preserve reasoning content in chat history caps_try_execute( prog, diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 5757c76b7..af2282c54 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -114,8 +114,10 @@ value binary_expression::execute_impl(context & ctx) { // Logical operators if (op.value == "and") { + JJ_DEBUG("Executing logical test: %s AND %s", left->type().c_str(), right->type().c_str()); return left_val->as_bool() ? right->execute(ctx) : std::move(left_val); } else if (op.value == "or") { + JJ_DEBUG("Executing logical test: %s OR %s", left->type().c_str(), right->type().c_str()); return left_val->as_bool() ? std::move(left_val) : right->execute(ctx); } @@ -838,7 +840,7 @@ value call_expression::execute_impl(context & ctx) { for (auto & arg_stmt : this->args) { auto arg_val = arg_stmt->execute(ctx); JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); - args.push_back(std::move(arg_val)); + args.push_back(arg_val); } // execute callee value callee_val = callee->execute(ctx); diff --git a/common/jinja/value.h b/common/jinja/value.h index 07e447ff6..6cbedefd9 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -12,8 +12,8 @@ #include #include #include -#include #include +#include namespace jinja { diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 2f67c74d7..27f13f034 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -27,11 +27,11 @@ static std::string build_repetition(const std::string & item_rule, int min_items if (separator_rule.empty()) { if (min_items == 1 && !has_max) { return item_rule + "+"; - } else if (min_items == 0 && !has_max) { - return item_rule + "*"; - } else { - return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } + if (min_items == 0 && !has_max) { + return item_rule + "*"; + } + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); @@ -41,7 +41,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { +static void build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { auto has_min = min_value != std::numeric_limits::min(); auto has_max = max_value != std::numeric_limits::max(); @@ -128,14 +128,14 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string if (has_min && has_max) { if (min_value < 0 && max_value < 0) { out << "\"-\" ("; - _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); + build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); out << ")"; return; } if (min_value < 0) { out << "\"-\" ("; - _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); + build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); out << ") | "; min_value = 0; } @@ -159,7 +159,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string if (has_min) { if (min_value < 0) { out << "\"-\" ("; - _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); out << ") | [0] | [1-9] "; more_digits(0, decimals_left - 1); } else if (min_value == 0) { @@ -194,7 +194,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string } digit_range(c, c); out << " ("; - _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); out << ")"; if (c < '9') { out << " | "; @@ -213,10 +213,10 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string more_digits(0, less_decimals); out << " | "; } - _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); + build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); } else { out << "\"-\" ("; - _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); out << ")"; } return; @@ -232,7 +232,7 @@ struct BuiltinRule { std::vector deps; }; -std::unordered_map PRIMITIVE_RULES = { +static std::unordered_map PRIMITIVE_RULES = { {"boolean", {"(\"true\" | \"false\") space", {}}}, {"decimal-part", {"[0-9]{1,16}", {}}}, {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, @@ -247,7 +247,7 @@ std::unordered_map PRIMITIVE_RULES = { {"null", {"\"null\" space", {}}}, }; -std::unordered_map STRING_FORMAT_RULES = { +static std::unordered_map STRING_FORMAT_RULES = { {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, @@ -260,22 +260,26 @@ static bool is_reserved_name(const std::string & name) { static const std::unordered_set RESERVED_NAMES = [] { std::unordered_set s; s.insert("root"); - for (const auto & p : PRIMITIVE_RULES) s.insert(p.first); - for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first); + for (const auto & p : PRIMITIVE_RULES) { + s.insert(p.first); + } + for (const auto & p : STRING_FORMAT_RULES) { + s.insert(p.first); + } return s; }(); return RESERVED_NAMES.find(name) != RESERVED_NAMES.end(); } -std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); -std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); -std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); -std::unordered_map GRAMMAR_LITERAL_ESCAPES = { +static std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); +static std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); +static std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); +static std::unordered_map GRAMMAR_LITERAL_ESCAPES = { {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"} }; -std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; -std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; +static std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; +static std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; @@ -322,19 +326,19 @@ private: if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) { _rules[esc_name] = rule; return esc_name; - } else { - int i = 0; - while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { - i++; - } - std::string key = esc_name + std::to_string(i); - _rules[key] = rule; - return key; } + int i = 0; + while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { + i++; + } + std::string key = esc_name + std::to_string(i); + _rules[key] = rule; + return key; } std::string _generate_union_rule(const std::string & name, const std::vector & alt_schemas) { std::vector rules; + rules.reserve(alt_schemas.size()); for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } @@ -398,6 +402,7 @@ private: flush_literal(); std::vector results; + results.reserve(ret.size()); for (const auto & item : ret) { results.push_back(to_rule(item)); } @@ -551,7 +556,7 @@ private: TrieNode() : is_end_of_string(false) {} void insert(const std::string & string) { - auto node = this; + auto *node = this; for (char c : string) { node = &node->children[c]; } @@ -676,7 +681,7 @@ private: if (ks.empty()) { return res; } - std::string k = ks[0]; + const std::string& k = ks[0]; std::string kv_rule_name = prop_kv_rule_names[k]; std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; if (first_is_optional) { @@ -779,7 +784,7 @@ public: std::string pointer = ref.substr(ref.find('#') + 1); std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { - std::string sel = tokens[i]; + const std::string& sel = tokens[i]; if (target.is_object() && target.contains(sel)) { target = target[sel]; } else if (target.is_array()) { @@ -802,7 +807,7 @@ public: _refs[ref] = target; } } else { - for (auto & kv : n.items()) { + for (const auto & kv : n.items()) { visit_refs(kv.value()); } } @@ -812,7 +817,7 @@ public: visit_refs(schema); } - std::string _generate_constant_rule(const json & value) { + static std::string _generate_constant_rule(const json & value) { return format_literal(value.dump()); } @@ -823,10 +828,12 @@ public: if (schema.contains("$ref")) { return _add_rule(rule_name, _resolve_ref(schema["$ref"])); - } else if (schema.contains("oneOf") || schema.contains("anyOf")) { + } + if (schema.contains("oneOf") || schema.contains("anyOf")) { std::vector alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get>() : schema["anyOf"].get>(); return _add_rule(rule_name, _generate_union_rule(name, alt_schemas)); - } else if (schema_type.is_array()) { + } + if (schema_type.is_array()) { std::vector schema_types; for (const auto & t : schema_type) { json schema_copy(schema); @@ -834,15 +841,18 @@ public: schema_types.push_back(schema_copy); } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); - } else if (schema.contains("const")) { + } + if (schema.contains("const")) { return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); - } else if (schema.contains("enum")) { + } + if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); - } else if ((schema_type.is_null() || schema_type == "object") + } + if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { std::unordered_set required; @@ -863,11 +873,12 @@ public: _build_object_rule( properties, required, name, schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { + } + if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { std::unordered_set required; std::vector> properties; std::map enum_values; - std::string hybrid_name = name; + const std::string& hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { if (comp_schema.contains("$ref")) { add_component(_refs[comp_schema["$ref"]], is_required); @@ -890,9 +901,9 @@ public: // todo warning } }; - for (auto & t : schema["allOf"]) { + for (const auto & t : schema["allOf"]) { if (t.contains("anyOf")) { - for (auto & tt : t["anyOf"]) { + for (const auto & tt : t["anyOf"]) { add_component(tt, false); } } else { @@ -911,7 +922,8 @@ public: } } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); - } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { + } + if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; if (items.is_array()) { std::string rule = "\"[\" space "; @@ -923,27 +935,31 @@ public: } rule += " \"]\" space"; return _add_rule(rule_name, rule); - } else { - std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); - int min_items = schema.contains("minItems") ? schema["minItems"].get() : 0; - json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); - int max_items = max_items_json.is_number_integer() ? max_items_json.get() : std::numeric_limits::max(); - - return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); } - } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { + std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); + int min_items = schema.contains("minItems") ? schema["minItems"].get() : 0; + json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); + int max_items = max_items_json.is_number_integer() ? max_items_json.get() : std::numeric_limits::max(); + + return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); + } + if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { return _visit_pattern(schema["pattern"], rule_name); - } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) { + } + if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) { return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid")); - } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) { + } + if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) { auto prim_name = schema_format + "-string"; return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name))); - } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) { + } + if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) { std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); - } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { + } + if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { int64_t min_value = std::numeric_limits::min(); int64_t max_value = std::numeric_limits::max(); if (schema.contains("minimum")) { @@ -958,19 +974,24 @@ public: } std::stringstream out; out << "("; - _build_min_max_int(min_value, max_value, out); + build_min_max_int(min_value, max_value, out); out << ") space"; return _add_rule(rule_name, out.str()); - } else if (schema.empty() || schema_type == "object") { - return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); - } else { - if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get()) == PRIMITIVE_RULES.end()) { - _errors.push_back("Unrecognized schema: " + schema.dump()); - return ""; - } - // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero - return _add_primitive(rule_name == "root" ? "root" : schema_type.get(), PRIMITIVE_RULES.at(schema_type.get())); } + if (schema.empty() || schema_type == "object") { + return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); + } + if (schema_type.is_null() && schema.is_object()) { + // No type constraint and no recognized structural keywords (e.g. {"description": "..."}). + // Per JSON Schema semantics this is equivalent to {} and accepts any value. + return _add_rule(rule_name, _add_primitive("value", PRIMITIVE_RULES.at("value"))); + } + if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get()) == PRIMITIVE_RULES.end()) { + _errors.push_back("Unrecognized schema: " + schema.dump()); + return ""; + } + // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return _add_primitive(rule_name == "root" ? "root" : schema_type.get(), PRIMITIVE_RULES.at(schema_type.get())); } void check_errors() { @@ -985,7 +1006,7 @@ public: std::string format_grammar() { std::stringstream ss; for (const auto & kv : _rules) { - ss << kv.first << " ::= " << kv.second << std::endl; + ss << kv.first << " ::= " << kv.second << '\n'; } return ss.str(); } diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index f2fc84500..48379f1ec 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1,14 +1,15 @@ -#include "common.h" #include "peg-parser.h" -#include "json-schema-to-grammar.h" -#include "unicode.h" -#include +#include "common.h" +#include "json-schema-to-grammar.h" +#include "log.h" +#include "unicode.h" #include #include #include #include +#include #include #include #include @@ -34,8 +35,7 @@ static bool is_hex_digit(const char c) { // This is used in common_peg_until_parser and to build a GBNF exclusion grammar struct trie { struct node { - size_t depth = 0; - std::map children; + std::map children; // Use uint32_t to store Unicode codepoints bool is_word; }; @@ -55,15 +55,22 @@ struct trie { size_t current = 0; // Start at root size_t pos = start_pos; + // LOG_DBG("%s: checking at pos %zu, sv='%s'\n", __func__, start_pos, std::string(sv).c_str()); + while (pos < sv.size()) { - auto it = nodes[current].children.find(sv[pos]); + auto result = common_parse_utf8_codepoint(sv, pos); + if (result.status != utf8_parse_result::SUCCESS) { + break; + } + + auto it = nodes[current].children.find(result.codepoint); if (it == nodes[current].children.end()) { // Can't continue matching return match_result{match_result::NO_MATCH}; } current = it->second; - pos++; + pos += result.bytes_consumed; // Check if we've matched a complete word if (nodes[current].is_word) { @@ -82,22 +89,22 @@ struct trie { } struct prefix_and_next { - std::string prefix; - std::string next_chars; + std::vector prefix; + std::vector next_chars; }; std::vector collect_prefix_and_next() { - std::string prefix; + std::vector prefix; std::vector result; collect_prefix_and_next(0, prefix, result); return result; } private: - void collect_prefix_and_next(size_t index, std::string & prefix, std::vector & out) { + void collect_prefix_and_next(size_t index, std::vector & prefix, std::vector & out) { if (!nodes[index].is_word) { if (!nodes[index].children.empty()) { - std::string chars; + std::vector chars; chars.reserve(nodes[index].children.size()); for (const auto & p : nodes[index].children) { chars.push_back(p.first); @@ -107,7 +114,7 @@ struct trie { } for (const auto & p : nodes[index].children) { - unsigned char ch = p.first; + uint32_t ch = p.first; auto child = p.second; prefix.push_back(ch); collect_prefix_and_next(child, prefix, out); @@ -123,11 +130,19 @@ struct trie { void insert(const std::string & word) { size_t current = 0; - for (unsigned char ch : word) { + size_t pos = 0; + while (pos < word.length()) { + auto result = common_parse_utf8_codepoint(word, pos); + if (result.status != utf8_parse_result::SUCCESS) { + break; + } + + uint32_t ch = result.codepoint; + pos += result.bytes_consumed; + auto it = nodes[current].children.find(ch); if (it == nodes[current].children.end()) { size_t child = create_node(); - nodes[child].depth = nodes[current].depth + 1; nodes[current].children[ch] = child; current = child; } else { @@ -286,6 +301,32 @@ struct parser_executor { parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start) : arena(arena), ctx(ctx), start_pos(start) {} + std::string debug_indent() const { return std::string(ctx.parse_depth * 2, ' '); } + + std::string debug_input_snippet(size_t pos, size_t len = 60) const { + if (pos >= ctx.input.size()) { + return ""; + } + auto snippet = ctx.input.substr(pos, len); + // Escape newlines for display + std::string result; + for (char c : snippet) { + if (c == '\n') { + result += "\\n"; + } else if (c == '\r') { + result += "\\r"; + } else if (c == '\t') { + result += "\\t"; + } else { + result += c; + } + } + if (pos + len < ctx.input.size()) { + result += "..."; + } + return result; + } + common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const { return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); } @@ -323,12 +364,39 @@ struct parser_executor { } common_peg_parse_result operator()(const common_peg_sequence_parser & p) { + if (ctx.debug) { + LOG_DBG("%sSEQ start at %zu '%s' (%zu children)\n", debug_indent().c_str(), start_pos, + debug_input_snippet(start_pos).c_str(), p.children.size()); + } + ctx.parse_depth++; + auto pos = start_pos; std::vector nodes; - for (const auto & child_id : p.children) { + for (size_t i = 0; i < p.children.size(); i++) { + const auto & child_id = p.children[i]; + if (ctx.debug) { + fprintf(stderr, "%sSEQ child %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str()); + } auto result = arena.parse(child_id, ctx, pos); + + if (ctx.debug) { + fprintf(stderr, "%sSEQ child %zu: %s at %zu->%zu\n", debug_indent().c_str(), i, + common_peg_parse_result_type_name(result.type), result.start, result.end); + } + if (result.fail()) { + ctx.parse_depth--; + if (ctx.is_partial && result.end >= ctx.input.size()) { + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> NEED_MORE (child failed at end)\n", debug_indent().c_str()); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, + std::move(nodes)); + } + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> FAIL\n", debug_indent().c_str()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end); } @@ -337,28 +405,65 @@ struct parser_executor { } if (result.need_more_input()) { + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> NEED_MORE\n", debug_indent().c_str()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); } pos = result.end; } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> SUCCESS at %zu->%zu\n", debug_indent().c_str(), start_pos, pos); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); } common_peg_parse_result operator()(const common_peg_choice_parser & p) { + if (ctx.debug) { + fprintf(stderr, "%sCHOICE start at %zu '%s' (%zu options)\n", debug_indent().c_str(), start_pos, + debug_input_snippet(start_pos).c_str(), p.children.size()); + } + ctx.parse_depth++; + auto pos = start_pos; - for (const auto & child_id : p.children) { + for (size_t i = 0; i < p.children.size(); i++) { + const auto & child_id = p.children[i]; + if (ctx.debug) { + fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str()); + } auto result = arena.parse(child_id, ctx, pos); + if (ctx.debug) { + fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, + common_peg_parse_result_type_name(result.type)); + } if (!result.fail()) { + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sCHOICE -> %s (option %zu)\n", debug_indent().c_str(), + common_peg_parse_result_type_name(result.type), i); + } return result; } } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sCHOICE -> FAIL (no options matched)\n", debug_indent().c_str()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); } common_peg_parse_result operator()(const common_peg_repetition_parser & p) { + if (ctx.debug) { + fprintf(stderr, "%sREPEAT start at %zu '%s' (min=%d, max=%d)\n", debug_indent().c_str(), start_pos, + debug_input_snippet(start_pos).c_str(), p.min_count, p.max_count); + } + ctx.parse_depth++; + auto pos = start_pos; int match_count = 0; std::vector nodes; @@ -366,14 +471,26 @@ struct parser_executor { // Try to match up to max_count times (or unlimited if max_count is -1) while (p.max_count == -1 || match_count < p.max_count) { if (pos >= ctx.input.size()) { + if (ctx.debug) { + fprintf(stderr, "%sREPEAT: at end of input, count=%d\n", debug_indent().c_str(), match_count); + } break; } auto result = arena.parse(p.child, ctx, pos); + if (ctx.debug) { + fprintf(stderr, "%sREPEAT iter %d: %s at %zu->%zu, nodes=%zu\n", debug_indent().c_str(), match_count, + common_peg_parse_result_type_name(result.type), result.start, result.end, result.nodes.size()); + fprintf(stderr, "%sREPEAT CHILD: %s\n", debug_indent().c_str(), arena.dump(p.child).c_str()); + } + if (result.success()) { // Prevent infinite loop on empty matches if (result.end == pos) { + if (ctx.debug) { + fprintf(stderr, "%s REPEAT: empty match, stopping\n", debug_indent().c_str()); + } break; } @@ -391,21 +508,43 @@ struct parser_executor { nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> NEED_MORE (count=%d, nodes=%zu)\n", debug_indent().c_str(), + match_count, nodes.size()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); } // Child failed - stop trying + if (ctx.debug) { + fprintf(stderr, "%sREPEAT: child failed, stopping\n", debug_indent().c_str()); + } break; } // Check if we got enough matches if (p.min_count > 0 && match_count < p.min_count) { + ctx.parse_depth--; if (pos >= ctx.input.size() && ctx.is_partial) { + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> NEED_MORE (not enough matches: %d < %d)\n", debug_indent().c_str(), + match_count, p.min_count); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes)); } + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> FAIL (not enough matches: %d < %d)\n", debug_indent().c_str(), match_count, + p.min_count); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> SUCCESS (count=%d, nodes=%zu)\n", debug_indent().c_str(), match_count, + nodes.size()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); } @@ -434,7 +573,7 @@ struct parser_executor { common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const { // Parse a single UTF-8 codepoint (not just a single byte) - auto result = parse_utf8_codepoint(ctx.input, start_pos); + auto result = common_parse_utf8_codepoint(ctx.input, start_pos); if (result.status == utf8_parse_result::INCOMPLETE) { if (!ctx.is_partial) { @@ -468,7 +607,7 @@ struct parser_executor { // Try to match up to max_count times (or unlimited if max_count is -1) while (p.max_count == -1 || match_count < p.max_count) { - auto result = parse_utf8_codepoint(ctx.input, pos); + auto result = common_parse_utf8_codepoint(ctx.input, pos); if (result.status == utf8_parse_result::INCOMPLETE) { if (match_count >= p.min_count) { @@ -537,6 +676,7 @@ struct parser_executor { switch (ctx.input[pos]) { case '"': + case '\'': case '\\': case '/': case 'b': @@ -589,7 +729,49 @@ struct parser_executor { return result; } } else { - auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + pos += utf8_result.bytes_consumed; + } + } + + // Reached end without finding closing quote + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) { + auto pos = start_pos; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '\'') { + // Found closing quote - success (don't consume it) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (c == '\\') { + auto result = handle_escape_sequence(ctx, start_pos, pos); + if (!result.success()) { + return result; + } + } else { + auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos); if (utf8_result.status == utf8_parse_result::INCOMPLETE) { if (!ctx.is_partial) { @@ -621,7 +803,7 @@ struct parser_executor { size_t last_valid_pos = start_pos; while (pos < ctx.input.size()) { - auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos); if (utf8_result.status == utf8_parse_result::INCOMPLETE) { // Incomplete UTF-8 sequence @@ -694,6 +876,9 @@ struct parser_executor { common_peg_parse_result operator()(const common_peg_tag_parser & p) { // Parse the child + if (ctx.debug) { + fprintf(stderr, "%sTAG: %s\n", debug_indent().c_str(), p.tag.c_str()); + } auto result = arena.parse(p.child, ctx, start_pos); if (!result.fail()) { @@ -755,6 +940,31 @@ common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) { return id; } +static void bfs_node(common_peg_ast_arena &arena, std::ostringstream & oss, const common_peg_ast_node & node, int indent) { + for (int i = 0; i < indent; i++) { + oss << " "; + } + oss << "NODE " << node.id; + if (!node.rule.empty()) { + oss << " (rule " << node.rule << ")"; + } + if (!node.tag.empty()) { + oss << " (tag " << node.tag << ")"; + } + oss << " ['" << node.text << "']\n"; + for (const auto child : node.children) { + bfs_node(arena, oss, arena.get(child), indent + 1); + } +} + +std::string common_peg_ast_arena::dump() { + std::ostringstream oss; + for (auto & node : nodes_) { + bfs_node(*this, oss, node, 0); + } + return oss.str(); +} + void common_peg_arena::resolve_refs() { // Walk through all parsers and replace refs with their corresponding rule IDs for (auto & parser : parsers_) { @@ -786,6 +996,7 @@ void common_peg_arena::resolve_refs() { std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { @@ -803,9 +1014,21 @@ void common_peg_arena::resolve_refs() { } std::string common_peg_arena::dump(common_peg_parser_id id) const { + std::unordered_set visited; + return dump_impl(id, visited); +} + +std::string common_peg_arena::dump_impl(common_peg_parser_id id, + std::unordered_set & visited) const { + // Check for cycles + if (visited.count(id)) { + return "[cycle]"; + } + visited.insert(id); + const auto & parser = parsers_.at(id); - return std::visit([this](const auto & p) -> std::string { + return std::visit([this, &visited](const auto & p) -> std::string { using T = std::decay_t; if constexpr (std::is_same_v) { @@ -819,24 +1042,27 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const { } else if constexpr (std::is_same_v) { std::vector parts; for (const auto & child : p.children) { - parts.push_back(dump(child)); + parts.push_back(dump_impl(child, visited)); } return "Sequence(" + string_join(parts, ", ") + ")"; } else if constexpr (std::is_same_v) { std::vector parts; for (const auto & child : p.children) { - parts.push_back(dump(child)); + parts.push_back(dump_impl(child, visited)); } return "Choice(" + string_join(parts, ", ") + ")"; } else if constexpr (std::is_same_v) { if (p.max_count == -1) { - return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)"; + return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) + + ", unbounded)"; } - return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; } else if constexpr (std::is_same_v) { - return "And(" + dump(p.child) + ")"; + return "And(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { - return "Not(" + dump(p.child) + ")"; + return "Not(" + dump_impl(p.child, visited) + ")"; + } else if constexpr (std::is_same_v) { + return "Atomic(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Any"; } else if constexpr (std::is_same_v) { @@ -848,14 +1074,20 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const { return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; } else if constexpr (std::is_same_v) { return "JsonString()"; + } else if constexpr (std::is_same_v) { + return "PythonDictString()"; } else if constexpr (std::is_same_v) { return "Until(" + string_join(p.delimiters, " | ") + ")"; } else if constexpr (std::is_same_v) { - return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; + return "Schema(" + dump_impl(p.child, visited) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; } else if constexpr (std::is_same_v) { - return "Rule(" + p.name + ", " + dump(p.child) + ")"; + return "Rule(" + p.name + ", " + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Ref(" + p.name + ")"; + } else if constexpr (std::is_same_v) { + return "Tag(" + p.tag + ", " + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Atomic(" + dump(p.child) + ")"; } else { return "Unknown"; } @@ -1054,7 +1286,54 @@ common_peg_arena common_peg_parser_builder::build() { return std::move(arena_); } +// String primitives + +common_peg_parser common_peg_parser_builder::json_string_content() { + return wrap(arena_.add_parser(common_peg_json_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::single_quoted_string_content() { + return wrap(arena_.add_parser(common_peg_python_dict_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::double_quoted_string() { + return rule("dq-string", + [this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); }); +} + +common_peg_parser common_peg_parser_builder::single_quoted_string() { + return rule("sq-string", + [this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); }); +} + +common_peg_parser common_peg_parser_builder::flexible_string() { + return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); }); +} + +// Generic helpers for object/array structure + +common_peg_parser common_peg_parser_builder::generic_object(const std::string & name, + const common_peg_parser & string_parser, + const common_peg_parser & value_parser) { + return rule(name, [this, string_parser, value_parser]() { + auto ws = space(); + auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser }); + auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) }); + return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) }); + }); +} + +common_peg_parser common_peg_parser_builder::generic_array(const std::string & name, + const common_peg_parser & value_parser) { + return rule(name, [this, value_parser]() { + auto ws = space(); + auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) }); + return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) }); + }); +} + // JSON parsers + common_peg_parser common_peg_parser_builder::json_number() { return rule("json-number", [this]() { auto digit1_9 = chars("[1-9]", 1, 1); @@ -1062,7 +1341,11 @@ common_peg_parser common_peg_parser_builder::json_number() { auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})}); auto frac = sequence({literal("."), digits}); auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits}); - return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()}); + // Negative lookahead: only commit the number when the next character can't extend it. + // At EOF in partial mode, chars returns NEED_MORE → negate propagates NEED_MORE → number not committed. + // This prevents premature commits of partial numbers (e.g. "3" when "3.14" is incoming). + auto not_number_continuation = negate(chars("[0-9.eE+-]", 1, 1)); + return sequence({ optional(literal("-")), int_part, optional(frac), optional(exp), not_number_continuation, space() }); }); } @@ -1085,36 +1368,11 @@ common_peg_parser common_peg_parser_builder::json_null() { } common_peg_parser common_peg_parser_builder::json_object() { - return rule("json-object", [this]() { - auto ws = space(); - auto member = sequence({json_string(), ws, literal(":"), ws, json()}); - auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))}); - return sequence({ - literal("{"), - ws, - choice({ - literal("}"), - sequence({members, ws, literal("}")}) - }), - ws - }); - }); + return generic_object("json-object", json_string(), json()); } common_peg_parser common_peg_parser_builder::json_array() { - return rule("json-array", [this]() { - auto ws = space(); - auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); - return sequence({ - literal("["), - ws, - choice({ - literal("]"), - sequence({elements, ws, literal("]")}) - }), - ws - }); - }); + return generic_array("json-array", json()); } common_peg_parser common_peg_parser_builder::json() { @@ -1130,8 +1388,40 @@ common_peg_parser common_peg_parser_builder::json() { }); } -common_peg_parser common_peg_parser_builder::json_string_content() { - return wrap(arena_.add_parser(common_peg_json_string_parser{})); +common_peg_parser common_peg_parser_builder::python_string() { + return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); }); +} + +common_peg_parser common_peg_parser_builder::python_number() { + return json_number(); +} + +common_peg_parser common_peg_parser_builder::python_bool() { + return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); }); +} + +common_peg_parser common_peg_parser_builder::python_null() { + return rule("python-none", [this]() { return sequence({ literal("None"), space() }); }); +} + +common_peg_parser common_peg_parser_builder::python_dict() { + return generic_object("python-dict", python_string(), python_value()); +} + +common_peg_parser common_peg_parser_builder::python_array() { + return generic_array("python-array", python_value()); +} + +common_peg_parser common_peg_parser_builder::python_value() { + return rule("python-value", [this]() { + return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() }); + }); +} + +common_peg_parser common_peg_parser_builder::marker() { + auto sharp_bracket_parser = literal("<") + until(">") + literal(">"); + auto square_bracket_parser = literal("[") + until("]") + literal("]"); + return choice({ sharp_bracket_parser, square_bracket_parser }); } common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) { @@ -1145,17 +1435,54 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key }); } - -static std::string gbnf_escape_char_class(char c) { - switch (c) { - case '\n': return "\\n"; - case '\t': return "\\t"; - case '\r': return "\\r"; - case '\\': return "\\\\"; - case ']': return "\\]"; - case '[': return "\\["; - default: return std::string(1, c); +static std::string gbnf_escape_char_class(uint32_t c) { + if (c == '-' || c == ']' || c == '[' || c == '\\') { + return "\\" + std::string(1, (char) c); } + // Escape whitespace control characters + if (c == '\n') { + return "\\n"; + } + if (c == '\t') { + return "\\t"; + } + if (c == '\r') { + return "\\r"; + } + + // Printable ASCII + if (c >= 0x20 && c <= 0x7E) { + return std::string(1, (char) c); + } + + // Hex escape + char buf[16]; + const char * hex = "0123456789ABCDEF"; + + if (c <= 0xFF) { + buf[0] = '\\'; + buf[1] = 'x'; + buf[2] = hex[(c >> 4) & 0xF]; + buf[3] = hex[c & 0xF]; + buf[4] = '\0'; + } else if (c <= 0xFFFF) { + buf[0] = '\\'; + buf[1] = 'u'; + buf[2] = hex[(c >> 12) & 0xF]; + buf[3] = hex[(c >> 8) & 0xF]; + buf[4] = hex[(c >> 4) & 0xF]; + buf[5] = hex[c & 0xF]; + buf[6] = '\0'; + } else { + buf[0] = '\\'; + buf[1] = 'U'; + for (int i = 0; i < 8; i++) { + buf[2 + i] = hex[(c >> ((7 - i) * 4)) & 0xF]; + } + buf[10] = '\0'; + } + + return std::string(buf); } static std::string gbnf_excluding_pattern(const std::vector & strings) { @@ -1173,12 +1500,12 @@ static std::string gbnf_excluding_pattern(const std::vector & strin std::string cls; cls.reserve(chars.size()); - for (const auto & ch : chars) { + for (uint32_t ch : chars) { cls += gbnf_escape_char_class(ch); } if (!pre.empty()) { - pattern += gbnf_format_literal(pre) + " [^" + cls + "]"; + pattern += gbnf_format_literal(common_unicode_cpts_to_utf8(pre)) + " [^" + cls + "]"; } else { pattern += "[^" + cls + "]"; } @@ -1208,7 +1535,8 @@ static std::unordered_set collect_reachable_rules( std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { // These parsers do not have any children } else if constexpr (std::is_same_v) { for (auto child : p.children) { @@ -1346,6 +1674,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; } else if constexpr (std::is_same_v) { return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; + } else if constexpr (std::is_same_v) { + return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; } else if constexpr (std::is_same_v) { if (p.delimiters.empty()) { return ".*"; @@ -1477,6 +1807,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & }; } else if constexpr (std::is_same_v) { return json{{"type", "json_string"}}; + } else if constexpr (std::is_same_v) { + return json{{ "type", "python_dict_string" }}; } else if constexpr (std::is_same_v) { return json{{"type", "until"}, {"delimiters", p.delimiters}}; } else if constexpr (std::is_same_v) { @@ -1606,6 +1938,9 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json if (type == "json_string") { return common_peg_json_string_parser{}; } + if (type == "python_dict_string") { + return common_peg_python_dict_string_parser{}; + } if (type == "until") { if (!j.contains("delimiters") || !j["delimiters"].is_array()) { throw std::runtime_error("until parser missing or invalid 'delimiters' field"); diff --git a/common/peg-parser.h b/common/peg-parser.h index 1cd640365..57d4bcd8e 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -111,6 +112,8 @@ class common_peg_ast_arena { void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const; void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const; + + std::string dump(); }; struct common_peg_parse_result { @@ -139,6 +142,7 @@ struct common_peg_parse_result { struct common_peg_parse_context { std::string input; bool is_partial; + bool debug = false; // Enable debug output for parser tracing common_peg_ast_arena ast; int parse_depth; @@ -207,6 +211,7 @@ struct common_peg_chars_parser { }; struct common_peg_json_string_parser {}; +struct common_peg_python_dict_string_parser {}; struct common_peg_until_parser { std::vector delimiters; @@ -255,6 +260,7 @@ using common_peg_parser_variant = std::variant< common_peg_space_parser, common_peg_chars_parser, common_peg_json_string_parser, + common_peg_python_dict_string_parser, common_peg_until_parser, common_peg_schema_parser, common_peg_rule_parser, @@ -299,6 +305,8 @@ class common_peg_arena { friend class common_peg_parser_builder; private: + std::string dump_impl(common_peg_parser_id id, std::unordered_set & visited) const; + common_peg_parser_id add_parser(common_peg_parser_variant parser); void add_rule(const std::string & name, common_peg_parser_id id); @@ -311,6 +319,10 @@ class common_peg_parser_builder { common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); } common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); } + // Generic helpers for building object/array structures with configurable string/value parsers. + common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser); + common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser); + public: common_peg_parser_builder(); @@ -404,6 +416,21 @@ class common_peg_parser_builder { // S -> A{n} common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); } + // Matches a double-quoted string: '"' content '"' space + common_peg_parser double_quoted_string(); + + // Matches a single-quoted string: "'" content "'" space + common_peg_parser single_quoted_string(); + + // Matches a string that accepts both double-quoted and single-quoted styles. + common_peg_parser flexible_string(); + + // Matches double-quoted string content without the surrounding quotes. + common_peg_parser json_string_content(); + + // Matches single-quoted string content without the surrounding quotes. + common_peg_parser single_quoted_string_content(); + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. // value -> object | array | string | number | true | false | null common_peg_parser json(); @@ -414,14 +441,24 @@ class common_peg_parser_builder { common_peg_parser json_bool(); common_peg_parser json_null(); - // Matches JSON string content without the surrounding quotes. - // Useful for extracting content within a JSON string. - common_peg_parser json_string_content(); - // Matches a JSON object member with a key and associated parser as the // value. common_peg_parser json_member(const std::string & key, const common_peg_parser & p); + // Creates a complete Python format parser supporting dicts, arrays, strings, numbers, booleans, and None. + // Differs from JSON: uses True/False/None, accepts both single and double-quoted strings. + // value -> dict | array | string | number | True | False | None + common_peg_parser python_value(); + common_peg_parser python_dict(); + common_peg_parser python_string(); + common_peg_parser python_array(); + common_peg_parser python_number(); + common_peg_parser python_bool(); + common_peg_parser python_null(); + + // A marker, i.e. text delimited by a pair of <> or [] + common_peg_parser marker(); + // Wraps a parser with JSON schema metadata for grammar generation. // Used internally to convert JSON schemas to GBNF grammar rules. common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false); diff --git a/common/unicode.cpp b/common/unicode.cpp index 56ab0f468..c0ef6d029 100644 --- a/common/unicode.cpp +++ b/common/unicode.cpp @@ -1,14 +1,18 @@ #include "unicode.h" +#include +#include +#include +#include // implementation adopted from src/unicode.cpp -size_t utf8_sequence_length(unsigned char first_byte) { +size_t common_utf8_sequence_length(unsigned char first_byte) { const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t highbits = static_cast(first_byte) >> 4; return lookup[highbits]; } -utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { +utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset) { if (offset >= input.size()) { return utf8_parse_result(utf8_parse_result::INCOMPLETE); } @@ -62,3 +66,43 @@ utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { // Invalid first byte return utf8_parse_result(utf8_parse_result::INVALID); } + +std::string common_unicode_cpts_to_utf8(const std::vector & cps) { + std::string result; + for (size_t i = 0; i < cps.size(); ++i) { + result.append(common_unicode_cpt_to_utf8(cps[i])); + } + return result; +} + +std::string common_unicode_cpt_to_utf8(uint32_t cpt) { + std::string result; + + if (/* 0x00 <= cpt && */ cpt <= 0x7f) { + result.push_back(cpt); + return result; + } + if (0x80 <= cpt && cpt <= 0x7ff) { + result.push_back(0xc0 | ((cpt >> 6) & 0x1f)); + result.push_back(0x80 | (cpt & 0x3f)); + return result; + } + if (0x800 <= cpt && cpt <= 0xffff) { + result.push_back(0xe0 | ((cpt >> 12) & 0x0f)); + result.push_back(0x80 | ((cpt >> 6) & 0x3f)); + result.push_back(0x80 | (cpt & 0x3f)); + return result; + } + if (0x10000 <= cpt && cpt <= 0x10ffff) { + result.push_back(0xf0 | ((cpt >> 18) & 0x07)); + result.push_back(0x80 | ((cpt >> 12) & 0x3f)); + result.push_back(0x80 | ((cpt >> 6) & 0x3f)); + result.push_back(0x80 | (cpt & 0x3f)); + return result; + } + + throw std::invalid_argument("invalid codepoint"); +} + + + diff --git a/common/unicode.h b/common/unicode.h index 9d9e8e122..87bcc0ffc 100644 --- a/common/unicode.h +++ b/common/unicode.h @@ -2,6 +2,8 @@ #include #include +#include +#include // UTF-8 parsing utilities for streaming-aware unicode support @@ -16,7 +18,10 @@ struct utf8_parse_result { // Determine the expected length of a UTF-8 sequence from its first byte // Returns 0 for invalid first bytes -size_t utf8_sequence_length(unsigned char first_byte); +size_t common_utf8_sequence_length(unsigned char first_byte); // Parse a single UTF-8 codepoint from input -utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset); +utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset); + +std::string common_unicode_cpts_to_utf8(const std::vector & cps); +std::string common_unicode_cpt_to_utf8(uint32_t cpt); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index e37944c4e..7bf45a4c0 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -562,6 +562,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_GATED_DELTA_NET, GGML_OP_UNARY, @@ -2481,6 +2482,15 @@ extern "C" { bool lower, bool uni); + GGML_API struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2b5a352da..799a24619 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2793,6 +2793,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_GATED_DELTA_NET: + { + ggml_compute_forward_gated_delta_net(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -3013,6 +3017,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_COUNT_EQUAL: case GGML_OP_SOLVE_TRI: + case GGML_OP_GATED_DELTA_NET: { n_tasks = n_threads; } break; @@ -3742,6 +3747,11 @@ struct ggml_cplan ggml_graph_plan( { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; + case GGML_OP_GATED_DELTA_NET: + { + const int64_t S_v = node->src[2]->ne[0]; + cur = S_v * sizeof(float) * n_tasks; + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2c372f963..331e071a2 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10380,6 +10380,190 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s } } +// ggml_compute_forward_gated_delta_net +static void ggml_compute_forward_gated_delta_net_one_chunk( + const ggml_compute_params * params, + ggml_tensor * dst, + int64_t ir0, + int64_t ir1) { + + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; + const int64_t n_tokens = src_v->ne[2]; + const int64_t n_seqs = src_v->ne[3]; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v); + GGML_ASSERT(src_beta->ne[0] == 1); + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne); + GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const bool kda = (neg0 == S_v); + + // scratch layout per thread: [delta(S_v)] + const int64_t scratch_per_thread = S_v; + const int ith = params->ith; + + float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + + // output layout: [attn_scores | new_states] + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs floats + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_out_base = (float *)dst->data; + float * state_out_base = (float *)dst->data + attn_score_elems; + + const float * state_in_base = (const float *)src_state->data; + + const int64_t rq1 = nev1 / neq1; + const int64_t rk1 = nev1 / nek1; + const int64_t rq3 = nev3 / neq3; + const int64_t rk3 = nev3 / nek3; + + const float scale = 1.0f / sqrtf((float) S_v); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t iv1 = ir % H; // head_index + const int64_t iv3 = ir / H; // sequence + + const int64_t iq1 = iv1 / rq1; + const int64_t ik1 = iv1 / rk1; + + const int64_t iq3 = iv3 / rq3; + const int64_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + + // copy input state into output buffer and operate in-place + const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + + // attn output pointer for first token of this (head, seq) + float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; + + for (int64_t t = 0; t < n_tokens; t++) { + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); + const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); + const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); + + const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + + if (kda) { + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i])); + } + } else { + ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); + } + + // delta[j] = sum_i S[j][i] * k[i] + memset(delta, 0, S_v * sizeof(float)); + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]); + } + for (int64_t j = 0; j < S_v; ++j) { + delta[j] = (v_d[j] - delta[j]) * beta_val; + } + + // outer product: S[j][i] += k[i] * delta[j] + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]); + } + + // attn_out[j] = sum_i S[j][i] * q[i] + memset(attn_data, 0, S_v * sizeof(float)); + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]); + } + ggml_vec_scale_f32(S_v, attn_data, scale); + + attn_data += S_v * H; // advance to next token + } + + } +} + + +static void ggml_compute_forward_gated_delta_net_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + ggml_tensor * V = dst->src[2]; + int64_t nr = V->ne[1] * V->ne[3]; + + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + int nth = params->nth; + int ith = params->ith; + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1); + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } +} + +void ggml_compute_forward_gated_delta_net( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gated_delta_net_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_rwkv_wkv7 static void ggml_compute_forward_rwkv_wkv7_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee797..3fa1443ab 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu new file mode 100644 index 000000000..d8e811145 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -0,0 +1,223 @@ +#include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" + +template +__global__ void gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + int64_t rq1, + int64_t rq3, + float scale) { + const int64_t h_idx = blockIdx.x; + const int64_t sequence = blockIdx.y; + const int col = threadIdx.x; // each thread owns one column + + const int64_t iq1 = h_idx / rq1; + const int64_t iq3 = sequence / rq3; + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; + state += state_offset; + curr_state += state_offset; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + // Load state column into registers + float s[S_v]; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = curr_state[i * S_v + col]; + } + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + if constexpr (!KDA) { + const float g_val = expf(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + kv_col += s[i] * k_t[i]; + } + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = g_val * s[i] + k_t[i] * delta_col; + attn_col += s[i] * q_t[i]; + } + + attn_data[col] = attn_col * scale; + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + kv_col += expf(g_t[i]) * s[i] * k_t[i]; + } + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col; + attn_col += s[i] * q_t[i]; + } + + attn_data[col] = attn_col * scale; + } + + attn_data += S_v * H; + } + + // Write state back to global memory +#pragma unroll + for (int i = 0; i < S_v; i++) { + state[i * S_v + col] = s[i]; + } +} + +template +static void launch_gated_delta_net( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t rq1, int64_t rq3, + float scale, cudaStream_t stream) { + + dim3 grid_dims(H, n_seqs, 1); + dim3 block_dims(S_v, 1, 1); + + switch (S_v) { + case 32: + gated_delta_net_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + case 64: + gated_delta_net_cuda<64, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + case 128: + gated_delta_net_cuda<128, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + const int64_t rq1 = nev1 / neq1; + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + cudaStream_t stream = ctx.stream(); + + if (kda) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale, stream); + } +} diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh new file mode 100644 index 000000000..7375e81c0 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" +#include "ggml.h" + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c08ff6f63..6f8c2174e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -55,6 +55,7 @@ bool g_mul_mat_q = true; #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/gated_delta_net.cuh" #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" @@ -2745,6 +2746,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_cuda_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_cuda_op_gated_delta_net(ctx, dst); + break; case GGML_OP_RWKV_WKV7: ggml_cuda_op_rwkv_wkv7(ctx, dst); break; @@ -4993,6 +4997,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_GATED_DELTA_NET: case GGML_OP_RWKV_WKV7: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index de5cbd75e..e8e25633f 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3104,6 +3104,11 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML } float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); float eff_max = scale*kMaxQ; + if (eff_max <= 0) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } float best = 0; for (int is = -6; is <= 6; ++is) { float id = (2*kMaxQ-1+is*0.1f)/eff_max; @@ -3273,9 +3278,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_ } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS) { scales[ib] = 0; - memset(L, 0, 16); continue; } float best = 0; @@ -3714,9 +3719,9 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + memset(L, 0, 32); if (max < GROUP_MAX_EPS_IQ3_XXS) { scales[ib] = 0; - memset(L, 0, 32); continue; } float best = 0; @@ -3922,6 +3927,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + memset(L, 0, block_size); if (!max) { scales[ib] = 0; continue; @@ -4245,6 +4251,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_S) { scales[ib] = 0; + shifts[ib] = 1; memset(L, 1, block_size); continue; } @@ -4285,7 +4292,12 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + if (besti1 < 0 || besti2 < 0 || best_shift == 0) { + scales[ib] = 0; + shifts[ib] = 1; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4429,6 +4441,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_M) { scales[ib] = 0; + shifts[ib] = 0; memset(L, 1, block_size); continue; } @@ -4527,7 +4540,12 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + if (besti1 < 0 || besti2 < 0 || best_k < 0) { + scales[ib] = 0; + shifts[ib] = 0; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4874,6 +4892,7 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS_IQ2_S) { scales[ib] = 0; continue; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl deleted file mode 100644 index a748dc1b8..000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ /dev/null @@ -1,141 +0,0 @@ -enable f16; - -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - offset_merged_src0: u32, - offset_merged_src1: u32, - - stride_src0_0: u32, - stride_src0_1: u32, - stride_src0_2: u32, - stride_src0_3: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - -fn src0_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - return a_i0 * params.stride_src0_0 + - a_i1 * params.stride_src0_1 + - a_i2 * params.stride_src0_2 + - a_i3 * params.stride_src0_3; -} - -fn src1_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - return b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; -} - -#ifdef TYPE_F32 -#define DataType f32 -#endif -#ifdef TYPE_F16 -#define DataType f16 -#endif - -#ifdef SRC_OVERLAP -@group(0) @binding(0) -var merged_src: array; - -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; -#else -@group(0) @binding(0) -var src0: array; - -@group(0) @binding(1) -var src1 : array; -#if defined(INPLACE) || defined(OVERLAP) -@group(0) @binding(2) -var params: Params; - -#else -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#endif -#endif - -fn op(a: DataType, b: DataType) -> DataType { -#ifdef OP_ADD - return a + b; -#elif defined(OP_SUB) - return a - b; -#elif defined(OP_MUL) - return a * b; -#elif defined(OP_DIV) - return a / b; -#endif -} - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { -#ifdef SRC_OVERLAP - let result = op(merged_src[src0_i], merged_src[src1_i]); -#else - let result = op(src0[src0_i], src1[src1_i]); -#endif - -#ifdef INPLACE - src0[src0_i] = result; -#elif defined(OVERLAP) - src1[src1_i] = result; -#else - dst[dst_i] = result; -#endif -} - -@compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); - update(params.offset_dst + gid.x, src0_i, src1_i); - } -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl deleted file mode 100644 index a22d245d2..000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +++ /dev/null @@ -1,75 +0,0 @@ -struct Params { - ne: u32, - - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src0_0: u32, - stride_src0_1: u32, - stride_src0_2: u32, - stride_src0_3: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - dim: u32, - src0_nedim: u32 -}; - -#ifdef TYPE_F32 -#define DataType f32 -#endif -#ifdef TYPE_I32 -#define DataType i32 -#endif - -@group(0) @binding(0) -var src0: array; - -@group(0) @binding(1) -var src1 : array; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; - -@compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - - if (gid.x < params.ne) { - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - var ni = array(i0, i1, i2, i3); - - if (ni[params.dim] < params.src0_nedim) { - let src_i = ni[0] * params.stride_src0_0 + - ni[1] * params.stride_src0_1 + - ni[2] * params.stride_src0_2 + - ni[3] * params.stride_src0_3; - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; - } else { - ni[params.dim] -= params.src0_nedim; - let src_i = ni[0] * params.stride_src1_0 + - ni[1] * params.stride_src1_1 + - ni[2] * params.stride_src1_2 + - ni[3] * params.stride_src1_3; - dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; - } - } -} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e7eab5129..1bea7e8f7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1047,6 +1047,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "GATED_DELTA_NET", "UNARY", @@ -1064,7 +1065,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1156,6 +1157,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "gated_delta_net(q, k, v, g, beta, s)", "unary(x)", @@ -1173,7 +1175,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6140,6 +6142,57 @@ struct ggml_tensor * ggml_solve_tri( return result; } +// ggml_gated_delta_net + +struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous_rows(q)); + GGML_ASSERT(ggml_is_contiguous_rows(k)); + GGML_ASSERT(ggml_is_contiguous_rows(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA) + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT(beta->ne[0] == 1); + + GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); + + // concat output and new_state into a single tensor + // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_GATED_DELTA_NET; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = g; + result->src[4] = beta; + result->src[5] = state; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/json_to_gbnf.py b/json_to_gbnf.py index cb35a0a8c..5018f01fd 100644 --- a/json_to_gbnf.py +++ b/json_to_gbnf.py @@ -689,6 +689,11 @@ class SchemaConverter: elif (schema_type == 'object') or (len(schema) == 0): return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + elif schema_type is None and isinstance(schema, dict): + # No type constraint and no recognized structural keywords (e.g. {"description": "..."}). + # Per JSON Schema semantics this is equivalent to {} and accepts any value. + return self._add_rule(rule_name, self._add_primitive('value', PRIMITIVE_RULES['value'])) + else: assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 51c1b9a6e..34bdd26c6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -153,6 +153,9 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; + cparams.fused_gdn_ar = true; + cparams.fused_gdn_ch = false; // TODO: implement + // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -431,7 +434,7 @@ void llama_context::sched_reserve() { if (cparams.auto_fa) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); + throw std::runtime_error("failed to reserve graph for Flash Attention check"); } const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; @@ -441,8 +444,7 @@ void llama_context::sched_reserve() { if (n->op != GGML_OP_FLASH_ATTN_EXT) { continue; } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); + ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); @@ -457,6 +459,7 @@ void llama_context::sched_reserve() { break; } } + if (fa_device_mismatch) { cparams.flash_attn = false; LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); @@ -468,6 +471,39 @@ void llama_context::sched_reserve() { cparams.auto_fa = false; } + if (cparams.fused_gdn_ar) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ar = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__); + } + } + // reserve worst-case graph int n_splits_pp = -1; int n_nodes_pp = -1; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2da3bbd6f..333922468 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -31,6 +31,8 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool auto_fa; + bool fused_gdn_ar; // use fused gated delta net (autoregressive) + bool fused_gdn_ch; // use fused gated delta net (chunked) bool no_perf; bool warmup; bool op_offload; diff --git a/src/llama-impl.h b/src/llama-impl.h index dfd9fee9f..ee27ac1be 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -70,4 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); -#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__" +#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__" diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index c57abbb5b..b0be62fc6 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-impl.h" + // utility to get one slice from the third dimension // input dim: [x, y, c, b] // output dim: [x, y, 1, b] @@ -39,6 +41,13 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + if (cparams.fused_gdn_ch) { + //ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + //cb(result, LLAMA_TENSOR_NAME_FGDNCH, il); + + GGML_ABORT("not implemented yet"); + } + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); @@ -316,6 +325,26 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + if (cparams.fused_gdn_ar) { + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + cb(result, LLAMA_TENSOR_NAME_FGDNAR, il); + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; + } + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index bacf7a4c2..afc5a1aad 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -332,8 +332,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) + std::pair attn_out; if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 22d708f20..17291ec23 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -332,8 +332,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) + std::pair attn_out; if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { diff --git a/tools/parser/CMakeLists.txt b/tools/parser/CMakeLists.txt new file mode 100644 index 000000000..55e0c6343 --- /dev/null +++ b/tools/parser/CMakeLists.txt @@ -0,0 +1,20 @@ +if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) + # this tool is disabled on Windows when building with shared libraries because it uses internal functions not exported with LLAMA_API + set(TARGET llama-debug-template-parser) + add_executable(${TARGET} debug-template-parser.cpp) + target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) + target_compile_features(${TARGET} PRIVATE cxx_std_17) + + if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET} RUNTIME) + endif() +endif() + +set(TARGET llama-template-analysis) +add_executable(${TARGET} template-analysis.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET} RUNTIME) +endif() diff --git a/tools/parser/debug-template-parser.cpp b/tools/parser/debug-template-parser.cpp new file mode 100644 index 000000000..ffa3a5af7 --- /dev/null +++ b/tools/parser/debug-template-parser.cpp @@ -0,0 +1,452 @@ +#include "../src/llama-grammar.h" +#include "chat-auto-parser.h" +#include "chat.h" +#include "common.h" +#include "gguf.h" +#include "jinja/runtime.h" +#include "log.h" + +#include +#include +#include +#include + +#include "nlohmann/json.hpp" +#include "peg-parser.h" + +using json = nlohmann::ordered_json; + +enum class output_mode { + ANALYSIS, // Only output analysis results (default) + TEMPLATE, // Only output rendered template + BOTH // Output both +}; + +enum class input_message_type { + NONE, // Don't render any message scenarios (only analysis) + CONTENT_ONLY, // Simple assistant message with content + REASONING_CONTENT, // Message with reasoning_content + content + TOOL_CALL_ONLY, // Message with tool_calls only + CONTENT_TOOL_CALL, // Message with content + tool_calls + REASONING_TOOL_CALL, // Message with reasoning_content + tool_calls + CONTENT_FAKE_TOOL_CALL, // Message with content but no actual tool_calls (for testing) + ALL // Render all scenarios +}; + +struct debug_options { + std::string template_path; + bool with_tools = true; + bool generation_prompt = true; + bool enable_reasoning = true; + bool debug_jinja = false; + bool force_tool_call = false; + output_mode mode = output_mode::BOTH; + input_message_type input_message = input_message_type::NONE; +}; + +static std::string read_file(const std::string & path) { + std::ifstream fin(path, std::ios::binary); + if (!fin.is_open()) { + throw std::runtime_error("Could not open file: " + path); + } + std::ostringstream buf; + buf << fin.rdbuf(); + return buf.str(); +} + +static std::string read_gguf_chat_template(const std::string & path) { + struct gguf_init_params params = { /*no_alloc =*/true, // We only need metadata, not tensor data + /*ctx=*/nullptr }; + + struct gguf_context * ctx = gguf_init_from_file(path.c_str(), params); + if (ctx == nullptr) { + throw std::runtime_error("Could not open GGUF file: " + path); + } + + const char * key = "tokenizer.chat_template"; + int64_t key_id = gguf_find_key(ctx, key); + + if (key_id == -1) { + gguf_free(ctx); + throw std::runtime_error("GGUF file does not contain chat template key: " + std::string(key)); + } + + const char * template_str = gguf_get_val_str(ctx, key_id); + if (template_str == nullptr) { + gguf_free(ctx); + throw std::runtime_error("GGUF file contains chat template key but value is null"); + } + + std::string result = template_str; + gguf_free(ctx); + return result; +} + +static void print_usage(const char * program_name) { + LOG_ERR("Usage: %s [options]\n", program_name); + LOG_ERR("\nOptions:\n"); + LOG_ERR(" --no-tools Disable tool definitions\n"); + LOG_ERR(" --force-tool-call Set tool calls to forced\n"); + LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n"); + LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n"); + LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n"); + LOG_ERR(" --debug-jinja Enable Jinja fine-grained debug\n"); + LOG_ERR(" --input-message=TYPE Message type to render:\n"); + LOG_ERR(" content_only, reasoning_content, tool_call_only,\n"); + LOG_ERR(" content_tool_call, reasoning_tool_call,\n"); + LOG_ERR(" content_fake_tool_call, all\n"); + LOG_ERR("\nExamples:\n"); + LOG_ERR(" %s template.jinja --input-message=all --generation-prompt=1\n", program_name); + LOG_ERR(" %s template.jinja --output=template --input-message=tool_call_only\n", program_name); +} + +static bool parse_bool_option(const std::string & value) { + return value == "1" || value == "true" || value == "yes"; +} + +static bool parse_options(int argc, char ** argv, debug_options & opts) { + if (argc < 2) { + print_usage(argv[0]); + return false; + } + + opts.template_path = argv[1]; + + for (int i = 2; i < argc; ++i) { + std::string arg = argv[i]; + + if (arg == "--force-tool-call") { + opts.force_tool_call = true; + } else if (arg == "--debug-jinja") { + opts.debug_jinja = true; + } else if (arg == "--no-tools") { + opts.with_tools = false; + } else if (arg.rfind("--generation-prompt=", 0) == 0) { + opts.generation_prompt = parse_bool_option(arg.substr(20)); + } else if (arg.rfind("--enable-reasoning=", 0) == 0) { + opts.enable_reasoning = parse_bool_option(arg.substr(19)); + } else if (arg.rfind("--output=", 0) == 0) { + std::string mode = arg.substr(9); + if (mode == "analysis") { + opts.mode = output_mode::ANALYSIS; + } else if (mode == "template") { + opts.mode = output_mode::TEMPLATE; + } else if (mode == "both") { + opts.mode = output_mode::BOTH; + } else { + LOG_ERR("Unknown output mode: %s\n", mode.c_str()); + return false; + } + } else if (arg.rfind("--input-message=", 0) == 0) { + std::string type = arg.substr(16); + if (type == "content_only") { + opts.input_message = input_message_type::CONTENT_ONLY; + } else if (type == "reasoning_content") { + opts.input_message = input_message_type::REASONING_CONTENT; + } else if (type == "tool_call_only") { + opts.input_message = input_message_type::TOOL_CALL_ONLY; + } else if (type == "content_tool_call") { + opts.input_message = input_message_type::CONTENT_TOOL_CALL; + } else if (type == "reasoning_tool_call") { + opts.input_message = input_message_type::REASONING_TOOL_CALL; + } else if (type == "content_fake_tool_call") { + opts.input_message = input_message_type::CONTENT_FAKE_TOOL_CALL; + } else if (type == "all") { + opts.input_message = input_message_type::ALL; + } else { + LOG_ERR("Unknown input message type: %s\n", type.c_str()); + return false; + } + } else { + LOG_ERR("Unknown option: %s\n", arg.c_str()); + print_usage(argv[0]); + return false; + } + } + + return true; +} + +static json build_user_message() { + return json{ + { "role", "user" }, + { "content", "Hello, please help me with a task." } + }; +} + +static json build_content_only_message() { + return json{ + { "role", "assistant" }, + { "content", "Hello! I'm here to help you with your task." } + }; +} + +static json build_reasoning_content_message() { + return json{ + { "role", "assistant" }, + { "content", "Hello! I'm here to help you with your task." }, + { "reasoning_content", "The user is greeting me and asking for help. I should respond politely." } + }; +} + +static json build_tool_call_only_message() { + return json{ + { "role", "assistant" }, + { "content", nullptr }, + { "tool_calls", + json::array({ json{ + { "type", "function" }, + { "function", json{ { "name", "test_function_name" }, + { "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } }, + { "id", "123456789" } } }) } + }; +} + +static json build_content_tool_call_message() { + return json{ + { "role", "assistant" }, + { "content", "I'll help you by calling a function." }, + { "tool_calls", + json::array({ json{ + { "type", "function" }, + { "function", + json{ { "name", "test_function_name" }, + { "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } } } }) } + }; +} + +static json build_reasoning_tool_call_message() { + return json{ + { "role", "assistant" }, + { "content", nullptr }, + { "reasoning_content", "I need to call a function to help with this task." }, + { "tool_calls", + json::array({ json{ + { "type", "function" }, + { "function", + json{ { "name", "test_function_name" }, + { "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } } } }) } + }; +} + +static json build_content_fake_tool_call_message() { + // This message has content but NO tool_calls field + // It's used to test if a template renders tool definitions but not tool calls + return json{ + { "role", "assistant" }, + { "content", "I'll help you by calling a function." } + }; +} + +static json build_tools_definition() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["param1"] = json::object({ + { "type", "string" }, + { "description", "First parameter" } + }); + parameters_schema["properties"]["param2"] = json::object({ + { "type", "string" }, + { "description", "Second parameter" } + }); + parameters_schema["required"] = json::array({ "param1" }); + + return json::array({ + json{ { "type", "function" }, + { "function", json{ { "name", "test_function_name" }, + { "description", "A test function for debugging" }, + { "parameters", parameters_schema } } } } + }); +} + +static void render_scenario(const common_chat_template & tmpl, + const std::string & scenario_name, + const json & messages, + const json & tools, + bool add_generation_prompt, + bool enable_thinking) { + LOG_ERR("\n=== Scenario: %s ===\n", scenario_name.c_str()); + LOG_ERR("add_generation_prompt: %s, enable_thinking: %s\n", add_generation_prompt ? "true" : "false", + enable_thinking ? "true" : "false"); + + // When add_generation_prompt is true, add a trailing user message to trigger the prompt + json final_messages = messages; + if (add_generation_prompt && !messages.empty() && messages.back().value("role", "") == "assistant") { + final_messages.push_back(json{ + { "role", "user" }, + { "content", "Now please continue with another response." } + }); + } + + LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str()); + + try { + autoparser::templates_params inputs; + inputs.messages = final_messages; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context["enable_thinking"] = enable_thinking; + + if (!tools.is_null() && tools.is_array() && !tools.empty()) { + inputs.tools = tools; + } + + std::string output = common_chat_template_direct_apply(tmpl, inputs); + + LOG_ERR("\n--- Rendered Output ---\n"); + LOG_ERR("%s\n", output.c_str()); + LOG_ERR("--- End Output (length: %zu) ---\n", output.length()); + } catch (const std::exception & e) { + LOG_ERR("Rendering failed: %s\n", e.what()); + } +} + +static void render_all_scenarios(const common_chat_template & tmpl, + const json & tools, + bool add_generation_prompt, + bool enable_thinking, + input_message_type message_type) { + json user_msg = build_user_message(); + + auto render_if = [&](input_message_type type, const std::string & name, const json & assistant_msg) { + if (message_type == input_message_type::ALL || message_type == type) { + json messages = json::array({ user_msg, assistant_msg }); + render_scenario(tmpl, name, messages, tools, add_generation_prompt, enable_thinking); + } + }; + + render_if(input_message_type::CONTENT_ONLY, "content_only", build_content_only_message()); + render_if(input_message_type::REASONING_CONTENT, "reasoning_content", build_reasoning_content_message()); + render_if(input_message_type::TOOL_CALL_ONLY, "tool_call_only", build_tool_call_only_message()); + render_if(input_message_type::CONTENT_TOOL_CALL, "content_tool_call", build_content_tool_call_message()); + render_if(input_message_type::REASONING_TOOL_CALL, "reasoning_tool_call", build_reasoning_tool_call_message()); + render_if(input_message_type::CONTENT_FAKE_TOOL_CALL, "content_fake_tool_call", + build_content_fake_tool_call_message()); + + // Also render with add_generation_prompt=true to show the prompt ending + if (message_type == input_message_type::ALL) { + LOG_ERR("\n\n=== Generation Prompt Scenarios (add_generation_prompt=true) ===\n"); + + json prompt_messages = json::array({ user_msg }); + render_scenario(tmpl, "generation_prompt_only", prompt_messages, tools, true, enable_thinking); + + // With enable_thinking toggled + render_scenario(tmpl, "generation_prompt_thinking_disabled", prompt_messages, tools, true, false); + } +} + +int main(int argc, char ** argv) { + // Set log level to most verbose to capture all debug output + common_log_set_verbosity_thold(99); + + debug_options opts; + if (!parse_options(argc, argv, opts)) { + return 1; + } + + if (opts.debug_jinja || std::getenv("LLAMA_DEBUG_JINJA") != nullptr) { + jinja::enable_debug(true); + } + + std::string template_source; + try { + // Check if the file is a GGUF file + if (opts.template_path.size() >= 5 && + opts.template_path.compare(opts.template_path.size() - 5, 5, ".gguf") == 0) { + template_source = read_gguf_chat_template(opts.template_path); + } else { + template_source = read_file(opts.template_path); + } + } catch (const std::exception & e) { + LOG_ERR("Error reading template: %s\n", e.what()); + return 1; + } + + LOG_ERR("Analyzing template: %s\n", opts.template_path.c_str()); + LOG_ERR("Options: with_tools=%s, generation_prompt=%s, enable_reasoning=%s\n", opts.with_tools ? "true" : "false", + opts.generation_prompt ? "true" : "false", opts.enable_reasoning ? "true" : "false"); + + try { + common_chat_template chat_template(template_source, "", ""); + + // Build tools definition + json tools = opts.with_tools ? build_tools_definition() : json(); + + // Render template scenarios if requested + if (opts.input_message != input_message_type::NONE && + (opts.mode == output_mode::TEMPLATE || opts.mode == output_mode::BOTH)) { + LOG_ERR("\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE RENDERING OUTPUT\n"); + LOG_ERR("================================================================================\n"); + + render_all_scenarios(chat_template, tools, opts.generation_prompt, opts.enable_reasoning, + opts.input_message); + } + + // Output analysis if requested + if (opts.mode == output_mode::ANALYSIS || opts.mode == output_mode::BOTH) { + LOG_ERR("\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE ANALYSIS\n"); + LOG_ERR("================================================================================\n"); + + autoparser::autoparser analysis; + analysis.analyze_template(chat_template); + + // Generate Parser + autoparser::templates_params params; + params.messages = json::array({ build_user_message() }); + params.reasoning_format = + opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE; + params.enable_thinking = opts.enable_reasoning; + params.add_generation_prompt = opts.generation_prompt; + + if (opts.with_tools) { + params.tools = tools; + params.tool_choice = opts.force_tool_call ? COMMON_CHAT_TOOL_CHOICE_REQUIRED : COMMON_CHAT_TOOL_CHOICE_AUTO; + } else { + params.tools = json(); + params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; + } + params.parallel_tool_calls = false; + + auto parser_data = autoparser::peg_generator::generate_parser(chat_template, params, analysis); + + LOG_ERR("\n=== Generated Parser ===\n"); + common_peg_arena arena; + arena.load(parser_data.parser); + LOG_ERR("%s\n", arena.dump(arena.root()).c_str()); + + LOG_ERR("\n=== Generated Grammar ===\n"); + LOG_ERR("%s\n", parser_data.grammar.c_str()); + + LOG_ERR("\n=== Generated Lazy Grammar ===\n"); + LOG_ERR("%d\n", parser_data.grammar_lazy); + + LOG_ERR("\n=== Generated Grammar Triggers ===\n"); + for (const common_grammar_trigger & cgt : parser_data.grammar_triggers) { + LOG_ERR("Token: %d | Type: %d | Value: %s\n", cgt.token, cgt.type, cgt.value.c_str()); + } + + LOG_ERR("\n=== Preserved Tokens ===\n"); + for (const std::string & token : parser_data.preserved_tokens) { + LOG_ERR(" '%s'\n", token.c_str()); + } + + if (!parser_data.grammar.empty()) { + LOG_ERR("\n=== Verifying created grammar ===\n"); + auto * grammar = llama_grammar_init_impl(nullptr, parser_data.grammar.c_str(), "root", + parser_data.grammar_lazy, nullptr, 0, nullptr, 0); + if (grammar != nullptr) { + LOG_ERR("\n=== Grammar successfully created ===\n"); + } + } + } + } catch (const std::exception & e) { + LOG_ERR("Analysis failed: %s\n", e.what()); + return 1; + } + + return 0; +} diff --git a/tools/parser/template-analysis.cpp b/tools/parser/template-analysis.cpp new file mode 100644 index 000000000..a92e104ac --- /dev/null +++ b/tools/parser/template-analysis.cpp @@ -0,0 +1,611 @@ +#include "chat-auto-parser.h" +#include "chat-auto-parser-helpers.h" +#include "chat.h" +#include "log.h" +#include "jinja/caps.h" +#include "jinja/runtime.h" + +#include +#include +#include +#include +#include + +#include "nlohmann/json.hpp" + +using json = nlohmann::ordered_json; + +// ANSI color codes - using 256-color palette for brighter colors (all bold) +#define ANSI_RESET "\033[0m" +#define ANSI_PURPLE "\033[1m\x1b[38;5;126m" // Bold bright purple for main headers +#define ANSI_CYAN "\033[1m\x1b[38;5;81m" // Bold bright cyan for section headers +#define ANSI_BLUE "\033[1m\x1b[38;5;12m" // Bold bright blue for labels +#define ANSI_ORANGE "\033[1m\x1b[38;5;209m" // Bold orange for right differences +#define ANSI_GREEN "\033[1m\x1b[38;5;83m" // Bold bright green for left differences +#define ANSI_GRAY "\033[1m\x1b[38;5;240m" // Bold gray (used for "no variables" message) +#define ANSI_BOLD "\033[1m" // Standalone bold +#define ANSI_PREFIX "\033[1m\x1b[38;5;176m" // Bold color for common prefix +#define ANSI_SUFFIX "\033[1m\x1b[38;5;61m" // Bold color for common suffix + +// All template paths extracted from tests/test-chat.cpp +static const std::vector ALL_TEMPLATE_PATHS = { + "models/templates/Apertus-8B-Instruct.jinja", + "models/templates/Apriel-1.6-15b-Thinker-fixed.jinja", + "models/templates/ByteDance-Seed-OSS.jinja", + "models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", + "models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja", + "models/templates/GLM-4.6.jinja", + "models/templates/GLM-4.7-Flash.jinja", + "models/templates/Kimi-K2-Instruct.jinja", + "models/templates/Kimi-K2-Thinking.jinja", + "models/templates/MiMo-VL.jinja", + "models/templates/MiniMax-M2.jinja", + "models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", + "models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", + "models/templates/NVIDIA-Nemotron-Nano-v2.jinja", + "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", + "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", + "models/templates/Qwen-QwQ-32B.jinja", + "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja", + "models/templates/Qwen3-Coder.jinja", + "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", + "models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", + "models/templates/deepseek-ai-DeepSeek-V3.1.jinja", + "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", + "models/templates/google-gemma-2-2b-it.jinja", + "models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja", + "models/templates/llama-cpp-deepseek-r1.jinja", + "models/templates/meetkai-functionary-medium-v3.1.jinja", + "models/templates/meetkai-functionary-medium-v3.2.jinja", + "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", + "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", + "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", + "models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", + "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", + "models/templates/moonshotai-Kimi-K2.jinja", + "models/templates/openai-gpt-oss-120b.jinja", + "models/templates/unsloth-Apriel-1.5.jinja", + "models/templates/unsloth-mistral-Devstral-Small-2507.jinja", +}; + +struct analysis_options { + std::vector template_paths; + bool analyze_all = false; +}; + +static std::string read_file(const std::string & path) { + std::ifstream fin(path, std::ios::binary); + if (!fin.is_open()) { + throw std::runtime_error("Could not open file: " + path); + } + std::ostringstream buf; + buf << fin.rdbuf(); + return buf.str(); +} + +static void print_usage(const char * program_name) { + LOG_ERR("Usage: %s [options]\n", program_name); + LOG_ERR("\nOptions:\n"); + LOG_ERR(" --template Analyze specific template from test suite (e.g., 'deepseek' or 'DeepSeek-V3.1')\n"); + LOG_ERR(" --template-file Analyze custom template file\n"); + LOG_ERR(" --all Analyze all templates from test suite\n"); + LOG_ERR("\nExamples:\n"); + LOG_ERR(" %s --all\n", program_name); + LOG_ERR(" %s --template deepseek\n", program_name); + LOG_ERR(" %s --template-file my-template.jinja\n", program_name); +} + +static bool parse_options(int argc, char ** argv, analysis_options & opts) { + if (argc < 2) { + print_usage(argv[0]); + return false; + } + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + + if (arg == "--all") { + opts.analyze_all = true; + } else if (arg == "--template") { + if (i + 1 >= argc) { + LOG_ERR("--template requires an argument\n"); + return false; + } + std::string pattern = argv[++i]; + std::transform(pattern.begin(), pattern.end(), pattern.begin(), ::tolower); + + // Find matching templates + bool found = false; + for (const auto & path : ALL_TEMPLATE_PATHS) { + std::string path_lower = path; + std::transform(path_lower.begin(), path_lower.end(), path_lower.begin(), ::tolower); + if (path_lower.find(pattern) != std::string::npos) { + opts.template_paths.push_back(path); + found = true; + } + } + + if (!found) { + LOG_ERR("No templates found matching: %s\n", pattern.c_str()); + return false; + } + } else if (arg == "--template-file") { + if (i + 1 >= argc) { + LOG_ERR("--template-file requires an argument\n"); + return false; + } + opts.template_paths.push_back(argv[++i]); + } else { + LOG_ERR("Unknown option: %s\n", arg.c_str()); + print_usage(argv[0]); + return false; + } + } + + if (opts.analyze_all) { + opts.template_paths = ALL_TEMPLATE_PATHS; + } + + if (opts.template_paths.empty()) { + LOG_ERR("No templates specified\n"); + print_usage(argv[0]); + return false; + } + + return true; +} + +static json build_tools_definition() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["param1"] = json::object({ + { "type", "string" }, + { "description", "First parameter" } + }); + parameters_schema["properties"]["param2"] = json::object({ + { "type", "string" }, + { "description", "Second parameter" } + }); + parameters_schema["required"] = json::array({ "param1", "param2" }); + + return json::array({ + json{ { "type", "function" }, + { "function", json{ { "name", "test_function_name" }, + { "description", "A test function for debugging" }, + { "parameters", parameters_schema } } } } + }); +} + +// Helper to create a tool call with arguments as JSON object +static json build_tool_call(const std::string & name, const json & args_object, const std::string & id = "call_001") { + return json{ + {"id", id}, + {"type", "function"}, + {"function", json{ + {"name", name}, + {"arguments", args_object} // Pass as JSON object, not serialized string + }} + }; +} + +// Helper functions to create repeating message definitions +static json make_user_msg() { + return json{ + {"role", "user"}, + {"content", "Hello, please help me."} + }; +} + +static json make_user_msg2() { + return json{ + {"role", "user"}, + {"content", "Thank you."} + }; +} + +static json make_user_msg2_continue() { + return json{ + {"role", "user"}, + {"content", "Continue."} + }; +} + +static json make_assistant_no_tool() { + return json{ + {"role", "assistant"}, + {"content", "Let me help you."} + }; +} + +static json make_assistant_one_tool() { + return json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })} + }; +} + +static json make_assistant_two_tools() { + return json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})), + build_tool_call("test_function_name", json::object({{"param1", "value3"}, {"param2", "value4"}}), "call_002") + })} + }; +} + +static json make_assistant_no_reasoning() { + return json{ + {"role", "assistant"}, + {"content", "I can help you with that."} + }; +} + +static json make_assistant_with_reasoning() { + return json{ + {"role", "assistant"}, + {"content", "I can help you with that."}, + {"reasoning_content", "The user is asking for help. I should respond positively."} + }; +} + +static json make_assistant_one_tool_with_reasoning() { + return json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })}, + {"reasoning_content", "I need to call the tool first."} + }; +} + +static void print_diff_split(const std::string & title, const diff_split & diff) { + LOG_ERR("\n%s=== %s ===%s\n", ANSI_CYAN, title.c_str(), ANSI_RESET); + LOG_ERR("%sCommon Prefix:%s '%s'\n", ANSI_PREFIX, ANSI_RESET, diff.prefix.c_str()); + LOG_ERR("%sCommon Suffix:%s '%s'\n", ANSI_SUFFIX, ANSI_RESET, diff.suffix.c_str()); + LOG_ERR("%sLeft (difference):%s '%s'\n", ANSI_GREEN, ANSI_RESET, diff.left.c_str()); + LOG_ERR("%sRight (difference):%s '%s'\n", ANSI_ORANGE, ANSI_RESET, diff.right.c_str()); +} + +static void check_reasoning_variables(const common_chat_template & tmpl) { + LOG_ERR("\n%s=== Checking Reasoning Variables ===%s\n", ANSI_CYAN, ANSI_RESET); + + try { + // Create a list of candidate reasoning/thinking variable names to probe + std::vector candidate_vars = { + "enable_reasoning", + "use_reasoning", + "reasoning_enabled", + "has_reasoning", + "reasoning_mode", + "reasoning_format", + "reasoning_active", + "with_reasoning", + "use_thinking", + "thinking_enabled", + "has_thinking", + "thinking_mode", + "thinking_format", + "thinking_active", + "with_thinking", + "enable_reason", + "reason_enabled", + "enable_think", + "think_enabled", + }; + + jinja::context ctx; + ctx.is_get_stats = true; + + json messages = json::array({ + json{ + {"role", "user"}, + {"content", "Test message"} + }, + json{ + {"role", "assistant"}, + {"content", "Response"}, + {"reasoning_content", "Some reasoning"} + } + }); + + // Set up base context + jinja::global_from_json(ctx, json{ + {"messages", messages}, + {"tools", json::array()}, + {"bos_token", ""}, + {"eos_token", ""}, + {"add_generation_prompt", false}, + {"enable_thinking", true} // Already passed, so we'll exclude this from results + }, true); + + // Add candidate variables as undefined to probe which ones are accessed + for (const auto & var_name : candidate_vars) { + ctx.set_val(var_name, jinja::mk_val(var_name)); + } + + try { + jinja::runtime runtime(ctx); + runtime.execute(tmpl.prog); + } catch (const std::exception & e) { + // Execution may fail, that's okay - we just want to see what variables were accessed + } + + // Check which candidate variables were accessed (stats.used = true) + std::vector accessed_vars; + for (const auto & var_name : candidate_vars) { + auto val = ctx.get_val(var_name); + if (!val->is_undefined()) { + // Variable was overwritten, skip it + continue; + } + if (val->stats.used) { + accessed_vars.push_back(var_name); + } + } + + if (accessed_vars.empty()) { + LOG_ERR("%sNo reasoning/thinking-related variables were queried by the template%s\n", ANSI_GRAY, ANSI_RESET); + } else { + LOG_ERR("Template queries the following reasoning/thinking-related variables:\n"); + for (const auto & var : accessed_vars) { + LOG_ERR(" %s- %s%s\n", ANSI_ORANGE, var.c_str(), ANSI_RESET); + } + } + + } catch (const std::exception & e) { + LOG_ERR("Error checking reasoning variables: %s\n", e.what()); + } +} + +static void analyze_template(const std::string & template_path) { + LOG_ERR("\n"); + LOG_ERR("%s", ANSI_PURPLE); + LOG_ERR("================================================================================\n"); + LOG_ERR(" ANALYZING TEMPLATE: %s\n", template_path.c_str()); + LOG_ERR("================================================================================\n"); + LOG_ERR("%s", ANSI_RESET); + + std::string template_source; + try { + template_source = read_file(template_path); + } catch (const std::exception & e) { + LOG_ERR("Error reading template: %s\n", e.what()); + return; + } + + try { + common_chat_template chat_template(template_source, "", ""); + json tools = build_tools_definition(); + + // ===== CAPABILITIES ANALYSIS ===== + LOG_ERR("\n%s=== Template Capabilities (from jinja::caps) ===%s\n", ANSI_CYAN, ANSI_RESET); + auto caps = chat_template.original_caps(); + LOG_ERR("%ssupports_tools:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_tools ? "true" : "false"); + LOG_ERR("%ssupports_tool_calls:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_tool_calls ? "true" : "false"); + LOG_ERR("%ssupports_system_role:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_system_role ? "true" : "false"); + LOG_ERR("%ssupports_parallel_tool_calls:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_parallel_tool_calls ? "true" : "false"); + LOG_ERR("%ssupports_typed_content:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_typed_content ? "true" : "false"); + LOG_ERR("%ssupports_string_content:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_string_content ? "true" : "false"); + + // ===== DIFFERENTIAL ANALYSIS ===== + + // Test 1: With and without tools (single user message) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_tools; + params_no_tools.messages = json::array({ user_msg }); + params_no_tools.add_generation_prompt = false; + params_no_tools.tools = json::array(); + + autoparser::templates_params params_with_tools = params_no_tools; + params_with_tools.tools = tools; + + std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools); + std::string output_with_tools = common_chat_template_direct_apply(chat_template, params_with_tools); + + auto diff = calculate_diff_split(output_no_tools, output_with_tools); + print_diff_split("Diff: With vs Without Tools (single user message)", diff); + } + + // Test 2: With and without add_generation_prompt (single user message) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_prompt; + params_no_prompt.messages = json::array({ user_msg }); + params_no_prompt.add_generation_prompt = false; + params_no_prompt.tools = json::array(); + + autoparser::templates_params params_with_prompt = params_no_prompt; + params_with_prompt.add_generation_prompt = true; + + std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt); + std::string output_with_prompt = common_chat_template_direct_apply(chat_template, params_with_prompt); + + auto diff = calculate_diff_split(output_no_prompt, output_with_prompt); + print_diff_split("Diff: With vs Without add_generation_prompt (single user message)", diff); + } + + // Test 3: Assistant with reasoning_content (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_reasoning; + params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning() }); + params_no_reasoning.add_generation_prompt = false; + params_no_reasoning.enable_thinking = true; + + autoparser::templates_params params_with_reasoning = params_no_reasoning; + params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() }); + + std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); + std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning); + + auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning); + print_diff_split("Diff: With vs Without reasoning_content (user, assistant)", diff); + } + + // Test 4: Assistant with reasoning_content (user, assistant, user) + { + json user_msg = make_user_msg(); + json user_msg2 = make_user_msg2(); + + autoparser::templates_params params_no_reasoning; + params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 }); + params_no_reasoning.add_generation_prompt = false; + params_no_reasoning.enable_thinking = true; + + autoparser::templates_params params_with_reasoning = params_no_reasoning; + params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 }); + + std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); + std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning); + + auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning); + print_diff_split("Diff: With vs Without reasoning_content (user, assistant, user)", diff); + } + + // Test 5: Tool call in last assistant message (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_tool; + params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() }); + params_no_tool.add_generation_prompt = false; + params_no_tool.tools = tools; + + autoparser::templates_params params_with_tool = params_no_tool; + params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); + + std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); + std::string output_with_tool = common_chat_template_direct_apply(chat_template, params_with_tool); + + auto diff = calculate_diff_split(output_no_tool, output_with_tool); + print_diff_split("Diff: With vs Without tool call (user, assistant)", diff); + } + + // Test 6: Tool call in last assistant message (user, assistant, user) + { + json user_msg = make_user_msg(); + json user_msg2 = make_user_msg2_continue(); + + autoparser::templates_params params_no_tool; + params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 }); + params_no_tool.add_generation_prompt = false; + params_no_tool.tools = tools; + + autoparser::templates_params params_with_tool = params_no_tool; + params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); + + std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); + std::string output_with_tool = common_chat_template_direct_apply(chat_template, params_with_tool); + + auto diff = calculate_diff_split(output_no_tool, output_with_tool); + print_diff_split("Diff: With vs Without tool call (user, assistant, user)", diff); + } + + // Test 7: One vs two tool calls (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_one_tool; + params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); + params_one_tool.add_generation_prompt = false; + params_one_tool.tools = tools; + + autoparser::templates_params params_two_tools = params_one_tool; + params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() }); + + std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); + std::string output_two_tools = common_chat_template_direct_apply(chat_template, params_two_tools); + + auto diff = calculate_diff_split(output_one_tool, output_two_tools); + print_diff_split("Diff: One vs Two tool calls (user, assistant)", diff); + } + + // Test 8: One vs two tool calls (user, assistant, user) + { + json user_msg = make_user_msg(); + json user_msg2 = make_user_msg2_continue(); + + autoparser::templates_params params_one_tool; + params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); + params_one_tool.add_generation_prompt = false; + params_one_tool.tools = tools; + + autoparser::templates_params params_two_tools = params_one_tool; + params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 }); + + std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); + std::string output_two_tools = common_chat_template_direct_apply(chat_template, params_two_tools); + + auto diff = calculate_diff_split(output_one_tool, output_two_tools); + print_diff_split("Diff: One vs Two tool calls (user, assistant, user)", diff); + } + + // Test 9: Tool call with vs without reasoning_content (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_reasoning; + params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() }); + params_no_reasoning.add_generation_prompt = false; + params_no_reasoning.tools = tools; + params_no_reasoning.enable_thinking = true; + + autoparser::templates_params params_with_reasoning = params_no_reasoning; + params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() }); + + std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); + std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning); + + auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning); + print_diff_split("Diff: Tool call with vs without reasoning_content (user, assistant)", diff); + } + + // Check reasoning variables + check_reasoning_variables(chat_template); + + } catch (const std::exception & e) { + LOG_ERR("Analysis failed: %s\n", e.what()); + } +} + +int main(int argc, char ** argv) { + // Set log level to capture all output + common_log_set_verbosity_thold(99); + + analysis_options opts; + if (!parse_options(argc, argv, opts)) { + return 1; + } + + LOG_ERR("\n"); + LOG_ERR("%s", ANSI_PURPLE); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE ANALYSIS TOOL\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR("%s", ANSI_RESET); + LOG_ERR("Analyzing %s%zu%s template(s)\n", ANSI_CYAN, opts.template_paths.size(), ANSI_RESET); + + for (const auto & path : opts.template_paths) { + analyze_template(path); + } + + LOG_ERR("\n"); + LOG_ERR("%s", ANSI_GREEN); + LOG_ERR("================================================================================\n"); + LOG_ERR(" ANALYSIS COMPLETE\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR("%s", ANSI_RESET); + + return 0; +} diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 38576c45f..bb25887a1 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -729,6 +729,10 @@ export class SchemaConverter { return this._addRule(ruleName, out.join('')); } else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) { return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object'])); + } else if (schemaType === undefined && typeof schema === 'object' && !Array.isArray(schema) && schema !== null) { + // No type constraint and no recognized structural keywords (e.g. {"description": "..."}). + // Per JSON Schema semantics this is equivalent to {} and accepts any value. + return this._addRule(ruleName, this._addPrimitive('value', PRIMITIVE_RULES['value'])); } else { if (!(schemaType in PRIMITIVE_RULES)) { throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index d3aba1848..32c0d8f48 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1,12 +1,12 @@ -#include "server-common.h" #include "server-task.h" -#include "common.h" -#include "llama.h" #include "chat.h" +#include "common.h" +#include "json-schema-to-grammar.h" +#include "llama.h" #include "sampling.h" #include "speculative.h" -#include "json-schema-to-grammar.h" +#include "server-common.h" using json = nlohmann::ordered_json; @@ -157,7 +157,8 @@ json task_params::to_json(bool only_metrics) const { common_chat_msg task_result_state::update_chat_msg( const std::string & text_added, bool is_partial, - std::vector & diffs) { + std::vector & diffs, + bool filter_tool_calls) { generated_text += text_added; auto msg_prv_copy = chat_msg; SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); @@ -168,7 +169,64 @@ common_chat_msg task_result_state::update_chat_msg( if (!new_msg.empty()) { new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg); + auto all_diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, chat_msg); + + if (!filter_tool_calls) { + diffs = std::move(all_diffs); + } else { + for (auto & d : all_diffs) { + // If this is a new type of delta, flush all currently pending tool call names + for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) { + if (sent_tool_call_names.count(i) || chat_msg.tool_calls[i].name.empty()) { + continue; + } + if (d.tool_call_index != i || !d.tool_call_delta.arguments.empty()) { + common_chat_msg_diff header; + header.tool_call_index = i; + header.tool_call_delta.id = chat_msg.tool_calls[i].id; + header.tool_call_delta.name = chat_msg.tool_calls[i].name; + diffs.push_back(std::move(header)); + sent_tool_call_names.insert(i); + } + } + + if (d.tool_call_index == std::string::npos) { + diffs.push_back(std::move(d)); + } else { + size_t i = d.tool_call_index; + if (sent_tool_call_names.count(i)) { + if (!d.tool_call_delta.arguments.empty()) { + d.tool_call_delta.name = ""; + d.tool_call_delta.id = ""; + diffs.push_back(std::move(d)); + } + } else { + // Not sent yet. + if (!d.tool_call_delta.arguments.empty() || !is_partial) { + d.tool_call_delta.name = chat_msg.tool_calls[i].name; + d.tool_call_delta.id = chat_msg.tool_calls[i].id; + diffs.push_back(std::move(d)); + sent_tool_call_names.insert(i); + } else { + // Suppress + } + } + } + } + // Final check at EOF + if (!is_partial) { + for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) { + if (!sent_tool_call_names.count(i) && !chat_msg.tool_calls[i].name.empty()) { + common_chat_msg_diff header; + header.tool_call_index = i; + header.tool_call_delta.id = chat_msg.tool_calls[i].id; + header.tool_call_delta.name = chat_msg.tool_calls[i].name; + diffs.push_back(std::move(header)); + sent_tool_call_names.insert(i); + } + } + } + } } return chat_msg; } diff --git a/tools/server/server-task.h b/tools/server/server-task.h index e2e3e5a58..1e342531d 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -98,6 +98,7 @@ struct task_result_state { common_chat_msg chat_msg; std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; + std::unordered_set sent_tool_call_names; // for OpenAI Responses and Anthropic streaming API: // track output item / content block state across chunks @@ -120,7 +121,8 @@ struct task_result_state { common_chat_msg update_chat_msg( const std::string & text_added, bool is_partial, - std::vector & diffs); + std::vector & diffs, + bool filter_tool_calls = false); }; struct server_task { diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index b8f0f1086..ba41cd44e 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -100,18 +100,19 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] - assert expected_function_name == tool_call["function"]["name"] + assert expected_function_name == tool_call["function"]["name"], f'Expected tool name to be {tool_call["function"]["name"]} in {choice["message"]}' actual_arguments = tool_call["function"]["arguments"] - assert isinstance(actual_arguments, str) + assert isinstance(actual_arguments, dict) or isinstance(actual_arguments, str), f'Expected arguments to be a dict or str, got: {actual_arguments}' if argument_key is not None: - actual_arguments = json.loads(actual_arguments) - assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + if (isinstance(actual_arguments, str)): + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {actual_arguments}, expected: {argument_key}" @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) @pytest.mark.parametrize("template_name,tool,argument_key", [ - ("google-gemma-2-2b-it", TEST_TOOL, "success"), - ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("Qwen3-Coder", TEST_TOOL, "success"), + ("Qwen3-Coder", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),