diff --git a/common/arg.cpp b/common/arg.cpp
index fdbd4f9b7..3a26f0161 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -540,9 +540,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     //     } catch (const std::exception & e) {
     //         LOG_WRN("HF cache migration failed: %s\n", e.what());
     //     }
+    // export_graph_ops loads only metadata
+    const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
 
     // maybe handle remote preset
-    if (!params.model.hf_repo.empty()) {
+    if (!params.model.hf_repo.empty() && !skip_model_download) {
         std::string cli_hf_repo = params.model.hf_repo;
         bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
@@ -573,7 +575,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // handle model and download
-    {
+    if (!skip_model_download) {
         auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
         if (params.no_mmproj) {
             params.mmproj = {};
@@ -594,7 +596,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
 
     // model is required (except for server)
     // TODO @ngxson : maybe show a list of available models in CLI in this case
-    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
+    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
         throw std::invalid_argument("error: --model is required\n");
     }
 
diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp
index 3f036bb5b..60b269c42 100644
--- a/common/chat-auto-parser-generator.cpp
+++ b/common/chat-auto-parser-generator.cpp
@@ -7,11 +7,109 @@
 #include "log.h"
 #include "nlohmann/json.hpp"
 
+#include <algorithm>
 #include
 #include
 
 using json = nlohmann::ordered_json;
 
+namespace {
+
+// Gemma4-specific PEG builder extending the standard chat builder.
+// Adds value type parsers that use <|\"|> as string delimiters
+// instead of JSON's double quotes, and disables json-to-schema
+// conversion for these types.
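+//
+// Illustrative example (hypothetical output, shape inferred from the
+// TAG_WITH_GEMMA4_DICT comment in chat-auto-parser.h): the Gemma4 dict
+//   {city:<|"|>Paris<|"|>,days:3}
+// corresponds to the JSON object {"city": "Paris", "days": 3}: keys are
+// bare and string values use <|"|> in place of double quotes.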
+class common_peg_gemma4_builder {
+    common_chat_peg_builder & p_;
+    static constexpr const char * QUOTE = "<|\"|>";
+
+public:
+    explicit common_peg_gemma4_builder(common_chat_peg_builder & p) : p_(p) {}
+
+    common_peg_parser gemma4_string() {
+        return p_.rule("gemma4-string", [&]() {
+            return p_.literal(QUOTE) + p_.until(QUOTE) + p_.literal(QUOTE);
+        });
+    }
+
+    common_peg_parser gemma4_number() {
+        return p_.rule("gemma4-number", [&]() {
+            auto digit1_9 = p_.chars("[1-9]", 1, 1);
+            auto digits = p_.chars("[0-9]");
+            auto int_part = p_.choice({p_.literal("0"), p_.sequence({digit1_9, p_.chars("[0-9]", 0, -1)})});
+            auto frac = p_.sequence({p_.literal("."), digits});
+            auto exp = p_.sequence({p_.choice({p_.literal("e"), p_.literal("E")}),
+                                    p_.optional(p_.chars("[+-]", 1, 1)), digits});
+            auto not_number_continuation = p_.negate(p_.chars("[0-9.eE+-]", 1, 1));
+            return p_.sequence({p_.optional(p_.literal("-")), int_part, p_.optional(frac),
+                                p_.optional(exp), not_number_continuation});
+        });
+    }
+
+    common_peg_parser gemma4_bool() {
+        return p_.rule("gemma4-bool", [&]() {
+            return p_.choice({p_.literal("true"), p_.literal("false")});
+        });
+    }
+
+    common_peg_parser gemma4_null() {
+        return p_.rule("gemma4-null", [&]() {
+            return p_.literal("null");
+        });
+    }
+
+    common_peg_parser gemma4_dict() {
+        return p_.rule("gemma4-dict", [&]() {
+            auto ws = p_.space();
+            auto key = p_.until(":");
+            auto member = p_.sequence({key, p_.literal(":"), ws, gemma4_value()});
+            auto members = p_.sequence({member, p_.zero_or_more(p_.sequence({p_.literal(","), ws, member}))});
+            return p_.sequence({
+                p_.literal("{"), ws,
+                p_.choice({p_.literal("}"), p_.sequence({members, ws, p_.literal("}")})})
+            });
+        });
+    }
+
+    common_peg_parser gemma4_array() {
+        return p_.rule("gemma4-array", [&]() {
+            auto ws = p_.space();
+            auto elements = p_.sequence({gemma4_value(), p_.zero_or_more(p_.sequence({p_.literal(","), ws, gemma4_value()}))});
+            return p_.sequence({
+                p_.literal("["), ws,
+                p_.choice({p_.literal("]"), p_.sequence({elements, ws, p_.literal("]")})})
+            });
+        });
+    }
+
+    common_peg_parser gemma4_value() {
+        return p_.rule("gemma4-value", [&]() {
+            return p_.choice({gemma4_string(), gemma4_dict(), gemma4_array(),
+                              gemma4_number(), gemma4_bool(), gemma4_null()});
+        });
+    }
+
+    // Select the appropriate value parser based on JSON schema type.
+    // Does NOT use schema() - the gemma4 types are pure PEG without
+    // JSON schema metadata, so GBNF is generated directly from the
+    // PEG structure.
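+    //
+    // Example dispatch (hypothetical schemas): {"type": "boolean"} selects
+    // gemma4_bool(), {"type": "integer"} shares gemma4_number() with "number",
+    // and a missing or non-string "type" falls back to gemma4_value().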
+    common_peg_parser gemma4_value_for_type(const json & schema) {
+        if (!schema.contains("type") || !schema.at("type").is_string()) {
+            return gemma4_value();
+        }
+        std::string type = schema.at("type").get<std::string>();
+        if (type == "string") { return gemma4_string(); }
+        if (type == "number") { return gemma4_number(); }
+        if (type == "integer") { return gemma4_number(); }
+        if (type == "boolean") { return gemma4_bool(); }
+        if (type == "object") { return gemma4_dict(); }
+        if (type == "array") { return gemma4_array(); }
+        return gemma4_value();
+    }
+};
+
+} // anonymous namespace
+
 // Helper to iterate over tools/functions
 static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
     for (const auto & tool : tools) {
@@ -43,7 +141,9 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
     // Create the result structure
     common_chat_params data;
     data.prompt = common_chat_template_direct_apply(tmpl, inputs);
-    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
+    data.format = (autoparser.tools.format.mode == tool_format::TAG_WITH_GEMMA4_DICT)
+        ? COMMON_CHAT_FORMAT_PEG_GEMMA4
+        : COMMON_CHAT_FORMAT_PEG_NATIVE;
     data.preserved_tokens = autoparser.preserved_tokens;
 
     auto parser = autoparser.build_parser(inputs);
@@ -92,6 +192,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
     ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
     ctx.content = &content;
+    ctx.reasoning = &reasoning;
 
     // Build reasoning parser
     ctx.reasoning_parser = reasoning.build_parser(ctx);
@@ -100,6 +201,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
 
     bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
     bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
+    bool pure_content = reasoning.mode == reasoning_mode::NONE;
 
     if (has_response_format) {
         auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
@@ -107,12 +209,14 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
             p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
             response_format
         }) + p.end();
+        pure_content = false;
     } else if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
         parser = tools.build_parser(ctx);
+        pure_content = false;
     } else {
         parser = content.build_parser(ctx);
     }
-    return p.prefix(inputs.generation_prompt, reasoning.start) + parser;
+    return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser;
     });
 }
@@ -166,6 +270,8 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
         return build_tool_parser_tag_json(ctx);
     case tool_format::TAG_WITH_TAGGED:
         return build_tool_parser_tag_tagged(ctx);
+    case tool_format::TAG_WITH_GEMMA4_DICT:
+        return build_tool_parser_tag_gemma4_dict(ctx);
     default:
         LOG_ERR("[ERROR] Template seems to support tool calls, but failed to determine tool format. Tool calling will not work properly. "
" "Check for a fixed template for your model in the models/templates directory of your llama.cpp installation or " @@ -430,4 +536,121 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte p.end(); } +common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_gemma4_builder g4(p); + static const std::string QUOTE = "<|\"|>"; + + common_peg_parser tool_choice = p.choice(); + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & func = tool.at("function"); + std::string name = func.at("name"); + const auto & params = func.at("parameters"); + + if (!params.contains("properties") || !params.at("properties").is_object()) { + auto func_parser = p.atomic( + p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) + + p.tool_args(p.eps()) + + p.tool_close(p.literal("}"))); + tool_choice |= p.rule("tool-" + name, func_parser); + return; + } + + const auto & properties = params.at("properties"); + std::set required; + if (params.contains("required") && params.at("required").is_array()) { + params.at("required").get_to(required); + } + + // Build per-argument parsers, sorted alphabetically (matching template's dictsort) + struct arg_entry { + std::string param_name; + common_peg_parser parser; + }; + std::vector arg_entries; + + for (const auto & [param_name, param_schema] : properties.items()) { + std::string type = "object"; + auto type_v = param_schema.contains("type") ? param_schema.at("type") : json::object(); + if (type_v.is_string()) type_v.get_to(type); + + common_peg_parser value_parser = p.eps(); + if (type == "string") { + // String values are delimited by <|"|>...<|"|> + value_parser = + p.literal(QUOTE) + + p.tool_arg_string_value(p.schema(p.until(QUOTE), + "tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) + + p.literal(QUOTE); + } else if (type == "number" || type == "integer") { + value_parser = p.tool_arg_value(g4.gemma4_number()); + } else if (type == "boolean") { + value_parser = p.tool_arg_value(g4.gemma4_bool()); + } else if (type == "null") { + value_parser = p.tool_arg_value(g4.gemma4_null()); + } else if (type == "object") { + value_parser = p.tool_arg_value(g4.gemma4_dict()); + } else if (type == "array") { + value_parser = p.tool_arg_value(g4.gemma4_array()); + } else { + value_parser = p.tool_arg_value(g4.gemma4_value()); + } + + auto arg = p.tool_arg( + p.tool_arg_open(p.tool_arg_name(p.literal(param_name)) + p.literal(":")) + + value_parser + + p.tool_arg_close(p.eps())); + + arg_entries.push_back({param_name, p.rule("tool-" + name + "-arg-" + param_name, arg)}); + } + + // Sort alphabetically to match Jinja's dictsort + std::sort(arg_entries.begin(), arg_entries.end(), [](const auto & a, const auto & b) { + return a.param_name < b.param_name; + }); + + // Build arg sequence: any arg, then zero-or-more comma-separated additional args + common_peg_parser args_seq = p.eps(); + if (!arg_entries.empty()) { + common_peg_parser any_arg = p.choice(); + for (auto & entry : arg_entries) { + any_arg |= entry.parser; + } + args_seq = p.optional( + any_arg + p.repeat(p.literal(",") + any_arg, 0, (int) arg_entries.size() - 1)); + } + + // Full parser: call:name{args} + auto func_parser = p.atomic( + p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) + 
+            p.tool_args(args_seq) +
+            p.tool_close(p.literal("}")));
+
+        tool_choice |= p.rule("tool-" + name, func_parser);
+    });
+
+    // Wrap each call in <|tool_call>...
+    auto wrapped_call = p.literal(format.per_call_start) + tool_choice + p.literal(format.per_call_end);
+
+    common_peg_parser tool_calls = p.eps();
+    if (inputs.parallel_tool_calls) {
+        tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
+    } else {
+        tool_calls = p.trigger_rule("tool-call", wrapped_call);
+    }
+
+    if (!force_tools) {
+        tool_calls = p.optional(tool_calls);
+    }
+
+    auto content_before_tools = p.until_one_of({ format.per_call_start, ctx.reasoning->start });
+    return ctx.reasoning_parser +
+           (force_tools ? p.eps() : p.optional(p.content(content_before_tools) + p.optional(ctx.reasoning_parser))) +
+           tool_calls + p.end();
+}
+
 } // namespace autoparser
diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h
index 73888276f..9d7d4e69e 100644
--- a/common/chat-auto-parser.h
+++ b/common/chat-auto-parser.h
@@ -144,6 +144,7 @@ enum class tool_format {
     JSON_NATIVE,          // Pure JSON: {"name": "X", "arguments": {...}}
     TAG_WITH_JSON,        // Tag-based with JSON args: {...}
     TAG_WITH_TAGGED,      // Tag-based with tagged args: value
+    TAG_WITH_GEMMA4_DICT, // Gemma4 custom dict: <|tool_call>call:name{key:<|"|>val<|"|>}
 };
 
 inline std::ostream & operator<<(std::ostream & os, const tool_format & format) {
@@ -156,6 +157,8 @@ inline std::ostream & operator<<(std::ostream & os, const tool_format & format)
         return os << "TAG_WITH_JSON";
     case tool_format::TAG_WITH_TAGGED:
         return os << "TAG_WITH_TAGGED";
+    case tool_format::TAG_WITH_GEMMA4_DICT:
+        return os << "TAG_WITH_GEMMA4_DICT";
     default:
         return os << "UNKNOWN";
     }
@@ -212,12 +215,14 @@ struct tool_id_analysis {
 // ============================================================================
 
 struct analyze_content;
+struct analyze_reasoning;
 
 struct parser_build_context {
     common_chat_peg_builder & p;
-    const generation_params & inputs;
+    const generation_params & inputs;
     common_peg_parser reasoning_parser;
     bool extracting_reasoning = false;
+    const analyze_reasoning * reasoning = nullptr;
     const analyze_content * content = nullptr;
 
     parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
@@ -350,6 +355,7 @@ struct analyze_tools : analyze_base {
     common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
     common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
     common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
+    common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const;
 };
 
 // ============================================================================
diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp
index 414ee892f..aadade60f 100644
--- a/common/chat-diff-analyzer.cpp
+++ b/common/chat-diff-analyzer.cpp
@@ -92,6 +92,34 @@ static std::vector
+    // Gemma4: <|tool_call>call:name{key:<|"|>val<|"|>}
+    [](const common_chat_template & tmpl, autoparser & analysis) -> void {
+        if (tmpl.src.find("'<|tool_call>call:'") != std::string::npos) {
+            analysis.tools.format.mode = tool_format::TAG_WITH_GEMMA4_DICT;
+            analysis.tools.format.per_call_start = "<|tool_call>";
+            analysis.tools.format.per_call_end = "";
+            analysis.tools.format.section_start = "";
+            analysis.tools.format.section_end = "";
+            analysis.tools.function.name_prefix = "call:";
+            analysis.tools.function.name_suffix = "";
+            analysis.tools.arguments.start = "{";
+            analysis.tools.arguments.end = "}";
+            analysis.tools.arguments.name_prefix = "";
+            analysis.tools.arguments.name_suffix = ":";
+            analysis.tools.arguments.separator = ",";
+            analysis.reasoning.mode = reasoning_mode::TAG_BASED;
+            analysis.reasoning.start = "<|channel>thought";
+            analysis.reasoning.end = "";
+            analysis.preserved_tokens.clear();
+            analysis.preserved_tokens.push_back("<|tool_call>");
+            analysis.preserved_tokens.push_back("");
+            analysis.preserved_tokens.push_back("<|tool_response>");
+            analysis.preserved_tokens.push_back("");
+            analysis.preserved_tokens.push_back("<|\"|>");
+            analysis.preserved_tokens.push_back("<|turn>");
+            LOG_DBG(ANSI_ORANGE "[Patch: Gemma4]\n" ANSI_RESET);
+        }
+    },
     // DeepSeek-R1-Distill-Qwen
     [](const common_chat_template & tmpl, autoparser & analysis) -> void {
         if (tmpl.src.find(
diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp
index 07b487e15..f2ed77c44 100644
--- a/common/chat-peg-parser.cpp
+++ b/common/chat-peg-parser.cpp
@@ -75,6 +75,84 @@ static std::string escape_json_string_inner(const std::string & s) {
     return escaped;
 }
 
+static const std::string GEMMA4_QUOTE = "<|\"|>";
+
+// Convert a Gemma4-style container literal to JSON: replace <|"|> string
+// delimiters with double quotes and add quotes around bare dict keys.
+static std::string normalize_gemma4_to_json(const std::string & input) {
+    std::string result;
+    result.reserve(input.size() * 2);
+
+    enum Ctx { DICT, ARRAY };
+    std::vector<Ctx> ctx;
+
+    auto is_ws = [](char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\r'; };
+    auto skip_ws = [&](size_t & pos) {
+        while (pos < input.size() && is_ws(input[pos])) {
+            result += input[pos++];
+        }
+    };
+
+    auto quote_unquoted_key = [&](size_t & pos) {
+        if (pos < input.size() && input[pos] != '"' && input[pos] != '}') {
+            result += '"';
+            while (pos < input.size() && input[pos] != ':' && !is_ws(input[pos])) {
+                result += input[pos++];
+            }
+            result += '"';
+            skip_ws(pos);
+        }
+    };
+
+    size_t i = 0;
+    while (i < input.size()) {
+        if (i + GEMMA4_QUOTE.size() <= input.size() &&
+            input.compare(i, GEMMA4_QUOTE.size(), GEMMA4_QUOTE) == 0) {
+            result += '"';
+            i += GEMMA4_QUOTE.size();
+            continue;
+        }
+
+        char c = input[i];
+
+        if (c == '{') {
+            result += c;
+            ctx.push_back(DICT);
+            ++i;
+            skip_ws(i);
+            quote_unquoted_key(i);
+            continue;
+        }
+        if (c == '}') {
+            result += c;
+            if (!ctx.empty()) ctx.pop_back();
+            ++i;
+            continue;
+        }
+        if (c == '[') {
+            result += c;
+            ctx.push_back(ARRAY);
+            ++i;
+            continue;
+        }
+        if (c == ']') {
+            result += c;
+            if (!ctx.empty()) ctx.pop_back();
+            ++i;
+            continue;
+        }
+        if (c == ',' && !ctx.empty() && ctx.back() == DICT) {
+            result += c;
+            ++i;
+            skip_ws(i);
+            quote_unquoted_key(i);
+            continue;
+        }
+
+        result += c;
+        ++i;
+    }
+    return result;
+}
+
 // Convert Python-style single-quoted strings to JSON double-quoted strings
 // Only converts outer string delimiters, properly handling escape sequences:
 // - {'key': 'value'} -> {"key": "value"}
@@ -214,6 +292,14 @@ std::string & common_chat_peg_mapper::args_target() {
     return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer;
 }
 
+std::string common_chat_peg_mapper::normalize_container_value(const std::string & input) {
+    return normalize_quotes_to_json(input);
+}
+
+std::string common_chat_peg_gemma4_mapper::normalize_container_value(const std::string & input) {
+    return normalize_quotes_to_json(normalize_gemma4_to_json(input));
+}
+
 void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & parse_result_arg) {
     arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
 
@@ -352,7 +438,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
         // For potential containers, normalize Python-style single quotes to JSON double quotes
         bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
         if (is_potential_container) {
-            value_content = normalize_quotes_to_json(value_content);
+            value_content = normalize_container_value(value_content);
         }
 
         // Try to parse as JSON value (number, bool, null, object, array)
diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h
index 62402923c..dd1388ec1 100644
--- a/common/chat-peg-parser.h
+++ b/common/chat-peg-parser.h
@@ -17,7 +17,9 @@ class common_chat_peg_mapper {
     virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
     virtual void map(const common_peg_ast_node & node);
 
-  private:
+  protected:
+    virtual std::string normalize_container_value(const std::string & input);
+  private:
     // Tool call handling state
     std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
     common_chat_tool_call * current_tool = nullptr;
@@ -30,6 +32,13 @@ class common_chat_peg_mapper {
     std::string & args_target();
 };
 
+class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper {
+  public:
+    common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
+  protected:
+    std::string normalize_container_value(const std::string & input) override;
+};
+
 struct content_structure;
 struct tool_call_structure;
 
diff --git a/common/chat.cpp b/common/chat.cpp
index 0996dc30f..41192ce88 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -709,6 +709,8 @@ const char * common_chat_format_name(common_chat_format format) {
             return "peg-simple";
         case COMMON_CHAT_FORMAT_PEG_NATIVE:
             return "peg-native";
+        case COMMON_CHAT_FORMAT_PEG_GEMMA4:
+            return "peg-gemma4";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -995,15 +997,19 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
     auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
     auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
 
+    // Occasionally, gpt-oss-20b prefixes a channel with a stray commentary header
+    auto stray_commentary = p.optional(p.literal("<|channel|>commentary") + p.optional(p.literal(" to=assistant")));
+    auto start_analysis = stray_commentary + p.literal("<|channel|>analysis<|message|>");
+
     if (extract_reasoning) {
-        p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
+        p.rule("analysis", start_analysis + p.reasoning(content) + end);
     } else {
-        p.rule("analysis", p.content(p.literal("<|channel|>analysis<|message|>") + content + end));
+        p.rule("analysis", p.content(start_analysis + content + end));
     }
     auto analysis = p.ref("analysis");
 
     auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
p.literal("<|channel|>final<|message|>") + p.content(content)); + auto final_msg = p.rule("final", stray_commentary + p.literal("<|channel|>final<|message|>") + p.content(content)); // Consume any unsolicited tool calls, e.g. builtin functions auto unsolicited = p.rule("unsolicited", p.atomic(p.optional(channel) + p.literal(" to=") + content + end)); @@ -1011,7 +1017,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp auto any = p.rule("any", preamble | analysis); if (has_response_format) { - auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type); + auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type); auto response_format = p.rule("response-format", p.literal("<|channel|>final") + constraint + p.literal("<|message|>") + p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); @@ -1028,7 +1034,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp const auto & params = function.at("parameters"); auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name)); - auto constraint = p.optional(p.space() + p.literal("<|constrain|>") + constrain_type); + auto constraint = p.optional(p.space() + p.optional(p.literal("<|constrain|>")) + constrain_type); auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params)); // recipient in role header @@ -1069,6 +1075,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.grammar_triggers = { { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^\\s+to$" }, + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^<\\|channel\\|>(?:commentary|analysis)\\s+to=functions$" }, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(\\s+to)" }, { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "<\\|start\\|>assistant(<\\|channel\\|>(?:commentary|analysis)\\s+to)" } }; @@ -1555,6 +1562,50 @@ static void requires_non_null_content(json & messages) { } } +// Gemma4 uses a custom tool_responses field instead of role:tool messages. +// Convert consecutive role:tool messages into a single user message with tool_responses. +static void convert_tool_responses_gemma4(json & messages) { + json result = json::array(); + size_t i = 0; + while (i < messages.size()) { + if (messages[i].contains("role") && messages[i].at("role") == "tool") { + json tool_responses = json::array(); + while (i < messages.size() && + messages[i].contains("role") && + messages[i].at("role") == "tool") { + const auto & tool_msg = messages[i]; + std::string name; + if (tool_msg.contains("tool_call_id") && tool_msg.at("tool_call_id").is_string()) { + name = tool_msg.at("tool_call_id"); + } else if (tool_msg.contains("name") && tool_msg.at("name").is_string()) { + name = tool_msg.at("name"); + } + json response; + if (tool_msg.contains("content")) { + const auto & content = tool_msg.at("content"); + if (content.is_string()) { + // Try to parse the content as JSON; fall back to raw string + try { + response = json::parse(content.get()); + } catch (...) 
+                            response = content;
+                        }
+                    } else {
+                        response = content;
+                    }
+                }
+                tool_responses.push_back({{"name", name}, {"response", response}});
+                i++;
+            }
+            result.push_back({{"role", "user"}, {"tool_responses", tool_responses}});
+        } else {
+            result.push_back(messages[i]);
+            i++;
+        }
+    }
+    messages = result;
+}
+
 static void func_args_not_string(json & messages) {
     GGML_ASSERT(messages.is_array());
     for (auto & message : messages) {
@@ -1683,6 +1734,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
         workaround::func_args_not_string(params.messages);
     }
 
+    if (src.find("'<|tool_call>call:'") != std::string::npos) {
+        workaround::convert_tool_responses_gemma4(params.messages);
+    }
+
     params.add_generation_prompt = false;
     std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
     params.add_generation_prompt = true;
@@ -1724,7 +1779,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
         data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
         data.generation_prompt = params.generation_prompt;
         auto parser = build_chat_peg_parser([&params](common_chat_peg_builder & p) {
-            return p.prefix(params.generation_prompt) + p.content(p.rest());
+            return p.prefix(params.generation_prompt) << p.content(p.rest());
         });
         data.parser = parser.save();
         return data;
@@ -1867,8 +1922,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
         // Try to extract any partial results from what was successfully parsed
         common_chat_msg msg;
         msg.role = "assistant";
-        auto mapper = common_chat_peg_mapper(msg);
-        mapper.from_ast(ctx.ast, result);
+        std::unique_ptr<common_chat_peg_mapper> mapper;
+        if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
+            mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
+        } else {
+            mapper = std::make_unique<common_chat_peg_mapper>(msg);
+        }
+        mapper->from_ast(ctx.ast, result);
 
         if (ctx.is_debug()) {
             fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
@@ -1883,8 +1943,13 @@
     common_chat_msg msg;
     msg.role = "assistant";
-    auto mapper = common_chat_peg_mapper(msg);
-    mapper.from_ast(ctx.ast, result);
+    std::unique_ptr<common_chat_peg_mapper> mapper;
+    if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
+        mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
+    } else {
+        mapper = std::make_unique<common_chat_peg_mapper>(msg);
+    }
+    mapper->from_ast(ctx.ast, result);
 
     if (ctx.is_debug()) {
"partial" : "full", ctx.ast.dump().c_str()); diff --git a/common/chat.h b/common/chat.h index 6358a1893..50c73d481 100644 --- a/common/chat.h +++ b/common/chat.h @@ -184,6 +184,7 @@ enum common_chat_format { // These are intended to be parsed by the PEG parser COMMON_CHAT_FORMAT_PEG_SIMPLE, COMMON_CHAT_FORMAT_PEG_NATIVE, + COMMON_CHAT_FORMAT_PEG_GEMMA4, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/common/common.cpp b/common/common.cpp index 56e78ea42..d777f01e4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1449,6 +1449,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.progress_callback = params.load_progress_callback; mparams.progress_callback_user_data = params.load_progress_callback_user_data; + mparams.no_alloc = params.no_alloc; return mparams; } diff --git a/common/common.h b/common/common.h index 40f0ea568..572e6fad5 100644 --- a/common/common.h +++ b/common/common.h @@ -676,6 +676,7 @@ struct common_params { // return false from callback to abort model loading or true to continue llama_progress_callback load_progress_callback = NULL; void * load_progress_callback_user_data = NULL; + bool no_alloc = false; // Don't allocate model buffers }; // call once at the start of a program if it uses libcommon diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index a6d9a4c27..694f9b850 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1557,6 +1557,36 @@ static std::unordered_set collect_reachable_rules( // GBNF generation implementation void common_peg_arena::build_grammar(const common_grammar_builder & builder, bool lazy) const { + auto schema_delegates = [](const common_peg_schema_parser & s) -> bool { + if (!s.schema) { + return true; + } + if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") { + return true; + } + return false; + }; + + // Unwrap the parser so we can properly check if it's a sequence or choice + auto effective_parser = [&](common_peg_parser_id id) -> const common_peg_parser_variant & { + while (true) { + const auto & p = parsers_.at(id); + if (const auto * tag = std::get_if(&p)) { + id = tag->child; + } else if (const auto * atomic = std::get_if(&p)) { + id = atomic->child; + } else if (const auto * schema = std::get_if(&p)) { + if (schema_delegates(*schema)) { + id = schema->child; + } else { + return p; + } + } else { + return p; + } + } + }; + // Generate GBNF for a parser std::function to_gbnf = [&](common_peg_parser_id id) -> std::string { const auto & parser = parsers_.at(id); @@ -1577,7 +1607,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo s += " "; } auto child_gbnf = to_gbnf(child); - const auto & child_parser = parsers_.at(child); + const auto & child_parser = effective_parser(child); if (std::holds_alternative(child_parser) || std::holds_alternative(child_parser)) { s += "(" + child_gbnf + ")"; @@ -1593,7 +1623,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo s += " | "; } auto child_gbnf = to_gbnf(child); - const auto & child_parser = parsers_.at(child); + const auto & child_parser = effective_parser(child); if (std::holds_alternative(child_parser)) { s += "(" + child_gbnf + ")"; } else { @@ -1603,7 +1633,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return s; } else if constexpr (std::is_same_v) { auto child_gbnf = to_gbnf(p.child); - const auto & child_parser = 
-            const auto & child_parser = parsers_.at(p.child);
+            const auto & child_parser = effective_parser(p.child);
             if (std::holds_alternative(child_parser) ||
                 std::holds_alternative(child_parser)) {
                 child_gbnf = "(" + child_gbnf + ")";
@@ -1663,15 +1693,10 @@
             }
             return gbnf_excluding_pattern(p.delimiters);
         } else if constexpr (std::is_same_v) {
-            if (p.schema) {
-                if (p.raw && p.schema->contains("type") && p.schema->at("type").is_string() && p.schema->at("type") == "string") {
-                    // TODO: Implement more comprehensive grammar generation for raw strings.
-                    // For now, use the grammar emitted from the underlying parser.
-                    return to_gbnf(p.child);
-                }
-                return builder.add_schema(p.name, *p.schema);
+            if (schema_delegates(p)) {
+                return to_gbnf(p.child);
             }
-            return to_gbnf(p.child);
+            return builder.add_schema(p.name, *p.schema);
         } else if constexpr (std::is_same_v) {
             return p.name;
         } else if constexpr (std::is_same_v) {
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 51f0d1ab1..de1def320 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1164,7 +1164,7 @@ class TextModel(ModelBase):
         if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
             self.gguf_writer.add_expert_count(n_experts)
             logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)) is not None:
+        if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
             logger.info(f"gguf: experts used count = {n_experts_used}")
         if (n_expert_groups := self.hparams.get("n_group")) is not None:
@@ -6878,7 +6878,9 @@ class Gemma2Model(TextModel):
 
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
-    norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
+
+    def norm_shift(self, name: str) -> float:
+        return 1.0 if name.endswith("norm.weight") else 0.0  # Gemma3RMSNorm adds 1.0 to the norm value
 
     def set_vocab(self):
         if (self.dir_model / "tokenizer.model").is_file():
@@ -6916,17 +6918,22 @@ class Gemma3Model(TextModel):
 
         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
+            n_vocab_real = -1
             if (self.dir_model / "tokenizer.model").is_file():
                 tokens = self._create_vocab_sentencepiece()[0]
+                n_vocab_real = len(tokens)
             else:
-                tokens = self.get_vocab_base()[0]
-            data_torch = data_torch[:len(tokens)]
+                with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
+                    tokenizer_json = json.load(f)
+                n_vocab_real = len(tokenizer_json["model"]["vocab"]) + len(tokenizer_json["added_tokens"])
+            data_torch = data_torch[:n_vocab_real]
 
         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
         # note: this is not the case on gemma3n
-        if name.endswith("norm.weight"):
-            data_torch = data_torch + self.norm_shift
+        f_shift = self.norm_shift(name)
+        if f_shift != 0.0:
+            data_torch = data_torch + f_shift
 
         yield from super().modify_tensors(data_torch, name, bid)
 
@@ -7100,7 +7107,8 @@ class ConformerAudioModel(MmprojModel):
         assert data_torch.shape[2] == 1
         data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
 
-        yield from super().modify_tensors(data_torch, name, bid)
".output_max", ".output_min")) + yield (mapped_name, data_torch) @ModelBase.register("DeepseekOCRForCausalLM") @@ -7289,7 +7297,6 @@ class Gemma3nVisionAudioModel(ConformerAudioModel): @ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA3N - norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code _altup_proj: list[Tensor] = [] _altup_unembd: list[Tensor] = [] @@ -7308,6 +7315,10 @@ class Gemma3NModel(Gemma3Model): torch.Tensor(), # to be replaced ] + def norm_shift(self, name: str) -> float: + del name + return 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code + def set_vocab(self): # For Gemma3n multimodal models, we need the FULL vocab_size (262400) # which includes special tokens from 262144-262399 for vision/audio. @@ -7425,6 +7436,212 @@ class Gemma3NModel(Gemma3Model): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Gemma4ForConditionalGeneration") +class Gemma4Model(Gemma3Model): + model_arch = gguf.MODEL_ARCH.GEMMA4 + + def norm_shift(self, name: str) -> float: + del name # unused + return 0.0 + + def set_vocab(self): + vocab = gguf.LlamaHfVocab(self.dir_model) + tokens = [] + scores = [] + toktypes = [] + visible_tokens = {"<|channel>", "", "<|tool_call>", "", "<|tool_response>", "", "<|\"|>"} + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + text_str = text.decode() + if text_str in visible_tokens: + # always render these tokens, so that the chat parser can read them + toktypes.append(gguf.TokenType.USER_DEFINED) + logger.info(f"Token '{text_str}' is set to USER_DEFINED") + else: + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + # TODO @ngxson : there are some known (rare) issues with the tokenizer during development + # but I don't have time to dive into them right now; + # using a dedicated tokenizer name so that we can fix later without re-converting GGUF + self.gguf_writer.add_tokenizer_model("gemma4") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + self.gguf_writer.add_add_space_prefix(False) + self.gguf_writer.add_add_bos_token(False) # already added via the chat template + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + num_kv_shared_layers = self.hparams["num_kv_shared_layers"] + self.gguf_writer.add_shared_kv_layers(num_kv_shared_layers) + + # per-layer embedding is optional + n_pl_embd = self.hparams.get("hidden_size_per_layer_input") or 0 + self.gguf_writer.add_embedding_length_per_layer_input(n_pl_embd) + + swa_layers = [t == "sliding_attention" for t in self.hparams["layer_types"]] + self.gguf_writer.add_sliding_window_pattern(swa_layers) + + head_dim_full = self.hparams["global_head_dim"] + head_dim_swa = self.hparams["head_dim"] + # correct the head dim for global/swa layers + self.gguf_writer.add_key_length(head_dim_full) + self.gguf_writer.add_value_length(head_dim_full) + self.gguf_writer.add_key_length_swa(head_dim_swa) + self.gguf_writer.add_value_length_swa(head_dim_swa) + + expert_intermediate_size = self.find_hparam(["expert_intermediate_size", "moe_intermediate_size"]) + if expert_intermediate_size is not None: + self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size) + + # if 
+        # if use_double_wide_mlp is set, we need to adjust the value for kv shared layers
+        use_double_wide_mlp = self.hparams.get("use_double_wide_mlp", False)
+        first_kv_shared_layer_idx = self.block_count - num_kv_shared_layers
+        if use_double_wide_mlp:
+            n_ff = self.hparams["intermediate_size"]
+            n_ff_arr = [n_ff if il < first_kv_shared_layer_idx else n_ff * 2 for il in range(self.block_count)]
+            self.gguf_writer.add_feed_forward_length(n_ff_arr)
+
+        # handle num_global_key_value_heads
+        num_key_value_heads_full = self.hparams.get("num_global_key_value_heads")
+        num_key_value_heads_swa = self.hparams.get("num_key_value_heads")
+        if num_key_value_heads_full is not None and num_key_value_heads_swa is not None:
+            value_arr = [num_key_value_heads_swa if is_swa else num_key_value_heads_full for is_swa in swa_layers]
+            self.gguf_writer.add_head_count_kv(value_arr)
+
+        # handle n_rot differently for global vs swa layers
+        partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
+        n_rot_full = int(head_dim_full)  # "proportional" is used, see generate_extra_tensors
+        n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
+        self.gguf_writer.add_rope_dimension_count(n_rot_full)
+        self.gguf_writer.add_rope_dimension_count_swa(n_rot_swa)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # full-attention layers use "proportional" rope with partial_rotary_factor=0.25:
+        # the expected ordering is cc000000ss000000 (c = cos, s = sin, 0 = unrotated),
+        # but ggml neox only supports ccss000000000000, and we cannot rearrange the head
+        # because that would break use_alternative_attention; the solution is to set
+        # specific freq_factors for the unrotated dims
+
+        # IMPORTANT: this ROPE_FREQS tensor is ONLY used by the full_attention layers
+        rope_params_full = self.hparams["rope_parameters"]["full_attention"]
+        assert rope_params_full["rope_type"] == "proportional"
+        head_dim_full = self.hparams["global_head_dim"]
+        partial_rotary_factor_full = rope_params_full["partial_rotary_factor"]
+        n_rot_full = int(head_dim_full * partial_rotary_factor_full / 2)
+        n_unrot_full = int(head_dim_full / 2) - n_rot_full
+        values = [1.0] * n_rot_full + [1e30] * n_unrot_full
+        rope_freqs_full = torch.tensor(values, dtype=torch.float32)
+        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), rope_freqs_full)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("per_dim_scale") or name.endswith("layer_scalar"):
+            name = name + ".weight"
+
+        if "language_model." not in name and "rope_freqs" not in name:
+            return  # skip non-language model tensors
+
+        name = name.replace("language_model.", "")
+        if name.endswith("router.scale"):
+            name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid, ".scale")
+            yield (name, data_torch)
+            return
+        if ".per_expert_scale" in name:
+            # convert per-expert scale to FFN down scale
+            name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_EXP, bid, ".scale")
+            yield (name, data_torch)
+            return
+        if ".experts." in name and not name.endswith(".weight"):
+            name += ".weight"
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Gemma4ForConditionalGeneration")
+class Gemma4VisionAudioModel(MmprojModel):
+    has_audio_encoder = True
+    has_vision_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = 224  # unused, but set to avoid error
+
+        # remap audio hparams
+        if self.hparams_audio:
+            self.hparams_audio["feat_in"] = self.hparams_audio.get("input_feat_size", 128)
+            self.hparams_audio["intermediate_size"] = self.hparams_audio["hidden_size"] * 4
+        else:
+            self.has_audio_encoder = False
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # vision params
+        self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.GEMMA4V)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
+
+        # audio params
+        if self.hparams_audio:
+            self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
+            self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
+            self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)
+
+    def is_audio_tensor(self, name: str) -> bool:
+        return "audio_tower" in name or "embed_audio" in name
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if self.is_audio_tensor(name):
+            if (".conv" in name or "_conv" in name) and ".weight" in name:
+                return gguf.GGMLQuantizationType.F32
+            if "position_embedding_table" in name:
+                return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("model.language_model."):
+            return  # skip
+
+        if len(data_torch.shape) == 0:
+            # convert scalar tensors (input/output min/max) to 1D tensors
+            data_torch = data_torch.unsqueeze(0)
+
+        if self.is_audio_tensor(name):
+            assert self.hparams_audio is not None
+            name = name.replace("model.audio_tower.", "conformer.")
+            name = name.replace(".linear.", ".")
+            if name.endswith("per_dim_key_scale") or name.endswith("per_dim_scale"):
+                name = name + ".weight"
+                data_torch = torch.nn.functional.softplus(data_torch)
+            if "lconv1d.depthwise_conv1d" in name and name.endswith(".weight"):
+                assert data_torch.shape[1] == 1
+                data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
+            mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
+            yield (mapped_name, data_torch)
+
+        else:
+            name = name.replace("model.vision_tower.encoder.", "vision_model.model.")
+            name = name.replace(".linear.weight", ".weight")
+            if name.endswith("layer_scalar") or name.endswith("position_embedding_table"):
+                name = name + ".weight"
+            if name.endswith("patch_embedder.input_proj.weight"):
+                n_embd, ksize_sq_c = data_torch.shape
+                patch_size = int((ksize_sq_c // 3) ** 0.5)
+                data_torch = data_torch.reshape(n_embd, patch_size, patch_size, 3)
+                data_torch = data_torch.permute(0, 3, 1, 2).contiguous()
+            mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min"))
+            yield (mapped_name, data_torch)
+
+
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
diff --git a/embd_res/klite.embd b/embd_res/klite.embd
index 1698be7a7..438bb9047 100644
--- a/embd_res/klite.embd
+++ b/embd_res/klite.embd
@@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
 -->