diff --git a/common/arg.cpp b/common/arg.cpp index 4c837dfff..210b3e8c0 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1833,23 +1833,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_sparam()); add_opt(common_arg( {"--grammar"}, "GRAMMAR", - string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), + "BNF-like grammar to constrain generations (see samples in grammars/ dir)", [](common_params & params, const std::string & value) { - params.sampling.grammar = value; + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, value}; } ).set_sparam()); add_opt(common_arg( {"--grammar-file"}, "FNAME", "file to read grammar from", [](common_params & params, const std::string & value) { - params.sampling.grammar = read_file(value); + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, read_file(value)}; } ).set_sparam()); add_opt(common_arg( {"-j", "--json-schema"}, "SCHEMA", "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", [](common_params & params, const std::string & value) { - params.sampling.grammar = json_schema_to_grammar(json::parse(value)); + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(value))}; } ).set_sparam()); add_opt(common_arg( @@ -1866,7 +1866,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::istreambuf_iterator(), std::back_inserter(schema) ); - params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, json_schema_to_grammar(json::parse(schema))}; } ).set_sparam()); add_opt(common_arg( diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index f19819494..aa03aea5a 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -1,3 +1,4 @@ +#include "chat-auto-parser-helpers.h" #include "chat-auto-parser.h" #include "chat-peg-parser.h" #include "chat.h" @@ -23,13 +24,13 @@ static void foreach_function(const json & tools, const std::functionreasoning) + return p.optional(start + p.reasoning(p.until(end)) + end + p.space()); + } + // Delimiter-style (empty start) + return p.optional(p.reasoning(p.until(end)) + end + p.space()); } - } else if (mode == reasoning_mode::DELIMITER) { - return p.optional(p.reasoning(p.until(end)) + end); } return p.eps(); @@ -335,7 +320,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte "tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) : p.tool_arg_json_value(p.schema( - p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) + + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) + p.space()) + p.tool_arg_close(p.literal(arguments.value_suffix))); @@ -384,7 +369,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + call_id_section) + p.space() + args_seq; matched_atomic = true; - } else if (!arguments.name_prefix.empty() && properties.size() > 0) { + } else if (!arguments.name_prefix.empty() && !required_parsers.empty()) { + // Only peek for an arg tag when there are required 
args that must follow. + // When all args are optional, the model may emit no arg tags at all (#20650). func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq; matched_atomic = true; diff --git a/common/chat-auto-parser-helpers.cpp b/common/chat-auto-parser-helpers.cpp index 787d7bab9..3a7a5c13a 100644 --- a/common/chat-auto-parser-helpers.cpp +++ b/common/chat-auto-parser-helpers.cpp @@ -1,9 +1,11 @@ #include "chat-auto-parser-helpers.h" #include "chat-auto-parser.h" +#include "chat-peg-parser.h" #include "chat.h" #include "log.h" #include "nlohmann/json.hpp" +#include "peg-parser.h" #include #include @@ -186,6 +188,21 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri result.suffix = ""; // pick prefix = all as representation } + + // When left has no unique content (result.left is empty), left is entirely + // shared with right. The simultaneous prefix/suffix segment matching can + // incorrectly consume trailing segments of left as suffix when those same + // segments also appear at the end of right (e.g. "\n" at the end of both + // the shared content and the generation prompt). This rotates the diff. + // Fix: if left is a prefix of right, enforce that directly. + if (result.left.empty() && !result.right.empty() && + left.size() <= right.size() && + right.substr(0, left.size()) == left) { + result.prefix = left; + result.suffix = ""; + result.right = right.substr(left.size()); + } + return result; } @@ -291,10 +308,26 @@ std::vector prune_whitespace_segments(const std::vector & segm return result; } +common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p, + const common_peg_parser & prs, + const autoparser::generation_params & inputs, + const std::string & reasoning_start) { + auto parser = prs; + if (!inputs.generation_prompt.empty()) { + size_t end_pos = inputs.generation_prompt.size(); + if (!reasoning_start.empty() && inputs.generation_prompt.find(reasoning_start) != std::string::npos) { + end_pos = inputs.generation_prompt.find(reasoning_start); + } + std::string cut_genprompt = inputs.generation_prompt.substr(0, end_pos); + parser = p.literal(cut_genprompt) + parser; + } + return parser; +} + namespace autoparser { std::string apply_template(const common_chat_template & tmpl, const template_params & params) { - templates_params tmpl_params; + generation_params tmpl_params; tmpl_params.messages = params.messages; tmpl_params.tools = params.tools; tmpl_params.add_generation_prompt = params.add_generation_prompt; diff --git a/common/chat-auto-parser-helpers.h b/common/chat-auto-parser-helpers.h index 6e3df79db..e13581e58 100644 --- a/common/chat-auto-parser-helpers.h +++ b/common/chat-auto-parser-helpers.h @@ -1,6 +1,7 @@ #pragma once #include "chat-auto-parser.h" +#include "peg-parser.h" #include #include #include @@ -57,6 +58,11 @@ std::vector segmentize_markers(const std::string & text); // (MARKER, ""), (MARKER, "") ] std::vector prune_whitespace_segments(const std::vector & segments); +// Wrap parser with generation prompt parser +common_peg_parser wrap_for_generation_prompt(common_chat_peg_builder & p, + const common_peg_parser & prs, + const autoparser::generation_params & inputs, + const std::string & reasoning_start = {}); namespace autoparser { // Apply a template with the given parameters, returning the rendered string (empty on failure) diff --git a/common/chat-auto-parser.h 
b/common/chat-auto-parser.h index 52c6488f4..73888276f 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -50,7 +50,7 @@ namespace autoparser { // High-level params for parser generation // ============================================================================ -struct templates_params { +struct generation_params { json messages; json tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; @@ -62,6 +62,7 @@ struct templates_params { bool add_generation_prompt = false; bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::string generation_prompt; json extra_context; bool add_bos = false; bool add_eos = false; @@ -77,11 +78,7 @@ struct templates_params { // Reasoning handling mode (derived from R1-R3 comparisons) enum class reasoning_mode { NONE, // No reasoning markers detected - TAG_BASED, // Standard tag-based: ... - DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter) - FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end) - FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but - // with both opened and closed tag for disabled thinking + TAG_BASED, // Tag-based: ... (start can be empty for delimiter-style) TOOLS_ONLY // Only reason on tool calls, not on normal content }; @@ -91,12 +88,6 @@ inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode) return os << "NONE"; case reasoning_mode::TAG_BASED: return os << "TAG_BASED"; - case reasoning_mode::DELIMITER: - return os << "DELIMITER"; - case reasoning_mode::FORCED_OPEN: - return os << "FORCED_OPEN"; - case reasoning_mode::FORCED_CLOSED: - return os << "FORCED_CLOSED"; case reasoning_mode::TOOLS_ONLY: return os << "TOOLS_ONLY"; default: @@ -184,7 +175,6 @@ struct tool_format_analysis { bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "": { ... arguments ... } } bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...] 
- bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings) std::string function_field = "function"; std::string name_field = "name"; @@ -225,12 +215,12 @@ struct analyze_content; struct parser_build_context { common_chat_peg_builder & p; - const templates_params & inputs; + const generation_params & inputs; common_peg_parser reasoning_parser; bool extracting_reasoning = false; const analyze_content * content = nullptr; - parser_build_context(common_chat_peg_builder & p, const templates_params & inputs); + parser_build_context(common_chat_peg_builder & p, const generation_params & inputs); }; // ============================================================================ @@ -260,6 +250,7 @@ struct analyze_reasoning : analyze_base { analyze_reasoning() = default; analyze_reasoning(const common_chat_template & tmpl, bool supports_tools); + analyze_reasoning(std::string start_, std::string end_) : start(std::move(start_)), end(std::move(end_)) {} common_peg_parser build_parser(parser_build_context & ctx) const override; @@ -381,7 +372,7 @@ struct autoparser { void analyze_template(const common_chat_template & tmpl); // Build the PEG parser for this template - common_peg_arena build_parser(const templates_params & inputs) const; + common_peg_arena build_parser(const generation_params & inputs) const; private: // Collect tokens from entire analysis to preserve @@ -395,10 +386,10 @@ struct autoparser { class peg_generator { public: static common_chat_params generate_parser(const common_chat_template & tmpl, - const struct templates_params & inputs); + const struct generation_params & inputs); static common_chat_params generate_parser(const common_chat_template & tmpl, - const struct templates_params & inputs, + const struct generation_params & inputs, const autoparser & autoparser); }; diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index 05b3b6b6a..4b827c9ae 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -2,6 +2,7 @@ #include "chat-auto-parser-helpers.h" #include "chat-peg-parser.h" #include "chat.h" +#include "common.h" #include "log.h" #include "nlohmann/json.hpp" #include "peg-parser.h" @@ -31,8 +32,9 @@ static std::vector void { if (tmpl.src.find("content.split('')") != std::string::npos && tmpl.src.find("reasoning_content") == std::string::npos && + tmpl.src.find("") == std::string::npos && analysis.reasoning.mode == reasoning_mode::NONE) { - analysis.reasoning.mode = reasoning_mode::FORCED_OPEN; + analysis.reasoning.mode = reasoning_mode::TAG_BASED; analysis.reasoning.start = ""; analysis.reasoning.end = ""; analysis.preserved_tokens.push_back(""); @@ -185,7 +187,6 @@ void autoparser::analyze_template(const common_chat_template & tmpl) { LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str()); LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str()); LOG_DBG("func_close: '%s'\n", tools.function.close.c_str()); - LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? 
"true" : "false"); LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str()); LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str()); LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str()); @@ -295,16 +296,12 @@ void analyze_reasoning::compare_reasoning_presence() { } if (result.result.success()) { if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { - if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close - mode = reasoning_mode::TAG_BASED; - } else { - mode = reasoning_mode::FORCED_CLOSED; - } + mode = reasoning_mode::TAG_BASED; start = trim_whitespace(result.tags["pre"]); - end = result.tags["post"]; + end = trim_trailing_whitespace(result.tags["post"]); } else if (!result.tags["post"].empty()) { - mode = reasoning_mode::DELIMITER; - end = result.tags["post"]; + mode = reasoning_mode::TAG_BASED; + end = trim_trailing_whitespace(result.tags["post"]); } } } @@ -331,53 +328,30 @@ void analyze_reasoning::compare_thinking_enabled() { const auto & diff = comparison->diff; std::string left_trimmed = trim_whitespace(diff.left); + std::string right_trimmed = trim_whitespace(diff.right); if (left_trimmed.empty() && !diff.right.empty()) { - std::string right_trimmed = trim_whitespace(diff.right); - if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { if (start.empty()) { start = right_trimmed; - mode = reasoning_mode::FORCED_OPEN; + mode = reasoning_mode::TAG_BASED; + } + } + } else if (right_trimmed.empty() && !diff.left.empty()) { + if (!left_trimmed.empty() && string_ends_with(comparison->output_A, left_trimmed)) { + if (end.empty()) { + auto seg = prune_whitespace_segments(segmentize_markers(comparison->output_A)); + if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) { + start = seg[seg.size() - 2].value; + } + end = left_trimmed; + mode = reasoning_mode::TAG_BASED; } } } - if (start.empty() && !end.empty()) { - mode = reasoning_mode::DELIMITER; - } - - // Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers, - // but enable_thinking=true produces only the start marker - if (!comparison->output_A.empty() && !comparison->output_B.empty()) { - auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) { - return p.literal(start) + p.space() + p.literal(end) + p.rest(); - }); - auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) { - return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest(); - }); - if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() && - parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) { - mode = reasoning_mode::FORCED_CLOSED; - } else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier - auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A); - if (result.result.success()) { - start = result.tags["pre"]; - mode = reasoning_mode::FORCED_CLOSED; - } - } - } - - if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close" - if (!diff.left.empty() && !diff.right.empty()) { - auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left)); - auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right)); - if (seg_A.size() == 1 && seg_B.size() == 1) { 
- mode = reasoning_mode::FORCED_CLOSED; - start = seg_B[0].value; - end = seg_A[0].value; - } - } + if (mode == reasoning_mode::NONE && start.empty() && !end.empty()) { + mode = reasoning_mode::TAG_BASED; } } @@ -426,16 +400,16 @@ void analyze_reasoning::compare_reasoning_scope() { auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); if (result.result.success()) { start = result.tags["pre"]; - end = result.tags["post"]; + end = trim_trailing_whitespace(result.tags["post"]); } else { auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space()))); }); result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); if (result.result.success()) { - end = result.tags["post"]; + end = trim_trailing_whitespace(result.tags["post"]); } else { - LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); + LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); mode = reasoning_mode::NONE; } } @@ -600,33 +574,23 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack, return; } - enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES }; - - auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style { + auto in_json_haystack = [&haystack](const std::string & needle) -> bool { auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({ - p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")), p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) }); }); auto result = parser.parse_anywhere_and_extract(haystack); - if (!result.result.success()) { - return json_quote_style::NONE; - } - return result.tags.count("sq") && !result.tags["sq"].empty() - ? json_quote_style::SINGLE_QUOTES - : json_quote_style::DOUBLE_QUOTES; + return result.result.success(); }; auto fun_quote = in_json_haystack(fun_name_needle); auto arg_quote = in_json_haystack(arg_name_needle); - if (fun_quote != json_quote_style::NONE) { + if (fun_quote) { // no need to check further, we're in JSON land format.mode = tool_format::JSON_NATIVE; - format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES); - } else if (arg_quote != json_quote_style::NONE) { + } else if (arg_quote) { format.mode = tool_format::TAG_WITH_JSON; - format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES); } else { format.mode = tool_format::TAG_WITH_TAGGED; } diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 4c5bb6218..5f7d422b4 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -229,6 +229,20 @@ void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, result.tool_calls.push_back(pending_tool_call.value()); pending_tool_call.reset(); } + + // Discard whitespace-only reasoning content (e.g. 
from prefill) + if (!result.reasoning_content.empty()) { + bool all_whitespace = true; + for (char c : result.reasoning_content) { + if (c != ' ' && c != '\n' && c != '\r' && c != '\t') { + all_whitespace = false; + break; + } + } + if (all_whitespace) { + result.reasoning_content.clear(); + } + } } void common_chat_peg_mapper::map(const common_peg_ast_node & node) { diff --git a/common/chat.cpp b/common/chat.cpp index 0900d854c..cab5b4c44 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,5 +1,6 @@ #include "chat.h" +#include "chat-auto-parser-helpers.h" #include "chat-auto-parser.h" #include "chat-peg-parser.h" #include "common.h" @@ -33,6 +34,7 @@ #include #include #include +#include #include using json = nlohmann::ordered_json; @@ -775,7 +777,7 @@ static void foreach_parameter(const json & std::string common_chat_template_direct_apply( const common_chat_template & tmpl, - const autoparser::templates_params & inputs, + const autoparser::generation_params & inputs, const std::optional & messages_override, const std::optional & tools_override, const std::optional & additional_context) { @@ -826,7 +828,7 @@ std::string common_chat_template_direct_apply( } static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, - const autoparser::templates_params & inputs) { + const autoparser::generation_params & inputs) { common_chat_params data; // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja @@ -891,8 +893,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ // Response format parser if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { // Ministral wants to emit json surrounded by code fences - return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) - << "```"; + return wrap_for_generation_prompt(p, reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```", + inputs, "[THINK]"); } // Tool call parser @@ -912,12 +914,13 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ auto max_calls = inputs.parallel_tool_calls ? 
-1 : 1; auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); - return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; + return wrap_for_generation_prompt(p, reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls, + inputs, "[THINK]"); } // Content only parser include_grammar = false; - return reasoning << p.content(p.rest()); + return wrap_for_generation_prompt(p, reasoning << p.content(p.rest()), inputs, "[THINK]"); }); data.parser = parser.save(); @@ -943,7 +946,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ } static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, - const autoparser::templates_params & inputs) { + const autoparser::generation_params & inputs) { common_chat_params data; // Copy reasoning to the "thinking" field as expected by the gpt-oss template @@ -1003,7 +1006,8 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp p.literal("<|channel|>final") + constraint + p.literal("<|message|>") + p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema))); - return response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format); + return wrap_for_generation_prompt(p, response_format | (analysis + p.zero_or_more(start + analysis) + start + response_format), + inputs, "<|channel|>"); } if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { @@ -1035,10 +1039,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return tool_call | ( any + p.zero_or_more(start + any) + start + tool_call); } - return tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)); + return wrap_for_generation_prompt(p, tool_call | final_msg | (any + p.zero_or_more(start + any) + start + (tool_call | final_msg)), + inputs, "<|channel|>"); } - return final_msg | (any + p.zero_or_more(start + any) + start + final_msg); + return wrap_for_generation_prompt(p, final_msg | (any + p.zero_or_more(start + any) + start + final_msg), + inputs, "<|channel|>"); }); data.parser = parser.save(); @@ -1066,7 +1072,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp // Functionary v3.2 - uses recipient-based format: >>>recipient\n{content} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, - const autoparser::templates_params & inputs) { + const autoparser::generation_params & inputs) { common_chat_params data; data.prompt = common_chat_template_direct_apply(tmpl, inputs); @@ -1087,13 +1093,13 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // Build content parser for >>>all\n{content} // When tools are present, content stops before the next ">>>" (tool call) // When no tools, content goes until end - auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>")); - auto content_until_end = p.literal(">>>all\n") + p.content(p.rest()); + auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>")); + auto content_until_end = p.literal("all\n") + p.content(p.rest()); // If no tools or tool_choice is NONE, just parse content if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { // When no tools, just match the prefix and capture everything after - return content_until_end + p.end(); + return wrap_for_generation_prompt(p, content_until_end + p.end(), inputs); } 
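// --- Illustrative sketch (not part of the patch) -------------------------------
// wrap_for_generation_prompt() (defined in chat-auto-parser-helpers.cpp above)
// prepends the rendered generation prompt to the parser as a literal, but cuts it
// off at the reasoning-start marker so a forced-open thinking tag is not consumed
// twice. The helper below is a simplified stand-in used only to demonstrate that
// cutting behaviour; the "<|assistant|>" header is a made-up example, while
// "[THINK]" mirrors the Ministral case above.
#include <cassert>
#include <string>

static std::string cut_generation_prompt(const std::string & gen_prompt,
                                          const std::string & reasoning_start) {
    size_t end_pos = gen_prompt.size();
    if (!reasoning_start.empty()) {
        const size_t pos = gen_prompt.find(reasoning_start);
        if (pos != std::string::npos) {
            end_pos = pos; // keep only the part before the forced-open reasoning tag
        }
    }
    return gen_prompt.substr(0, end_pos);
}

int main() {
    // With a forced-open reasoning tag, only the role header is prepended as a literal.
    assert(cut_generation_prompt("<|assistant|>\n[THINK]", "[THINK]") == "<|assistant|>\n");
    // Without a reasoning marker, the whole generation prompt is prepended.
    assert(cut_generation_prompt(">>>", "") == ">>>");
    return 0;
}
// --------------------------------------------------------------------------------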
// Build tool call parsers for each available function @@ -1105,7 +1111,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // Tool format: >>>function_name\n{json_args} auto tool_parser = p.tool( - p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) + + p.tool_open(p.tool_name(p.literal(name)) + p.literal("\n")) + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) ); @@ -1116,17 +1122,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice)); auto content_and_tools = content_until_tool + tools_only; + auto ret = p.eps(); if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) { if (inputs.parallel_tool_calls) { - return p.choice({ content_and_tools, tools_only }) + p.end(); + ret = p.choice({ content_and_tools, tools_only }) + p.end(); + } else { + ret = p.choice({ content_until_tool + tool_choice, tools_only }) + p.end(); } - return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end(); + } else if (inputs.parallel_tool_calls) { + ret = p.choice({ content_and_tools, content_only, tools_only }) + p.end(); + } else { + auto content_and_tool = content_until_tool + tool_choice; + ret = p.choice({ content_and_tool, content_only, tool_choice }) + p.end(); } - if (inputs.parallel_tool_calls) { - return p.choice({ content_and_tools, content_only, tools_only }) + p.end(); - } - auto content_and_tool = content_until_tool + tool_choice; - return p.choice({ content_and_tool, content_only, tool_choice }) + p.end(); + return wrap_for_generation_prompt(p, ret, inputs); }); data.parser = parser.save(); @@ -1156,14 +1165,12 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // Kimi K2 Thinking - uses unique tool call ID format: functions.: // The ID contains both the function name and an incrementing counter static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, - const autoparser::templates_params & inputs) { + const autoparser::generation_params & inputs) { common_chat_params data; data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; - data.thinking_start_tag = ""; - data.thinking_end_tag = ""; data.preserved_tokens = { "<|tool_calls_section_begin|>", "<|tool_calls_section_end|>", @@ -1178,6 +1185,18 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>"; + const std::string SECTION_END = "<|tool_calls_section_end|>"; + const std::string CALL_BEGIN = "<|tool_call_begin|>"; + const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>"; + const std::string CALL_END = "<|tool_call_end|>"; + + const std::string THINK_START = ""; + const std::string THINK_END = ""; + + data.thinking_start_tag = THINK_START; + data.thinking_end_tag = THINK_END; + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { // Kimi K2 Thinking format: // - Reasoning: {reasoning} @@ -1189,16 +1208,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp // <|tool_calls_section_end|> // The ID format is: functions.: where counter is 0, 1, 2, ... 
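// --- Illustrative sketch (not part of the patch) -------------------------------
// The Kimi K2 parser below relies on the tool-call ID carrying both the function
// name and a 0-based counter, as described in the comment above. This tiny helper
// is a hypothetical reconstruction of that ID format for clarity; the function
// name in the example comment is made up.
#include <string>

static std::string kimi_k2_tool_call_id(const std::string & func_name, int counter) {
    // e.g. kimi_k2_tool_call_id("get_weather", 0) -> "functions.get_weather:0"
    return "functions." + func_name + ":" + std::to_string(counter);
}
// --------------------------------------------------------------------------------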
- // Tool call markers - const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>"; - const std::string SECTION_END = "<|tool_calls_section_end|>"; - const std::string CALL_BEGIN = "<|tool_call_begin|>"; - const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>"; - const std::string CALL_END = "<|tool_call_end|>"; - - const std::string THINK_START = ""; - const std::string THINK_END = ""; - + // Tool call markers auto end = p.end(); // Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny. @@ -1210,7 +1220,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp // Content only parser (no tools) if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return reasoning + p.content(p.rest()) + end; + return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, + inputs, THINK_START); } // Build tool call parsers for each available function @@ -1246,7 +1257,8 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN })); - return reasoning + content_before_tools + tool_calls + end; + return wrap_for_generation_prompt(p, reasoning + content_before_tools + tool_calls + end, + inputs, THINK_START); }); data.parser = parser.save(); @@ -1276,7 +1288,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp // - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|> // Tool calls can appear multiple times (parallel tool calls) static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, - const autoparser::templates_params & inputs) { + const autoparser::generation_params & inputs) { common_chat_params data; data.prompt = common_chat_template_direct_apply(tmpl, inputs); @@ -1295,13 +1307,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; - const std::string TOOL_CALL_START = "<|tool_call_start|>"; const std::string TOOL_CALL_END = "<|tool_call_end|>"; const std::string THINK_START = ""; const std::string THINK_END = ""; - auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + data.thinking_start_tag = THINK_START; + data.thinking_end_tag = THINK_END; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto end = p.end(); auto reasoning = p.eps(); @@ -1310,7 +1324,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat } if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return reasoning + p.content(p.rest()) + end; + return wrap_for_generation_prompt(p, reasoning + p.content(p.rest()) + end, inputs, + THINK_START); } auto tool_calls = p.rule("tool-calls", @@ -1322,7 +1337,8 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat auto content = p.content(p.until(TOOL_CALL_START)); - return reasoning + content + tool_calls + end; + return wrap_for_generation_prompt(p, reasoning + content + tool_calls + end, inputs, + THINK_START); }); data.parser = parser.save(); @@ -1348,7 +1364,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat static common_chat_params common_chat_params_init_gigachat_v3( const 
common_chat_template & tmpl, - const autoparser::templates_params & inputs) { + const autoparser::generation_params & inputs) { common_chat_params data; @@ -1362,9 +1378,10 @@ static common_chat_params common_chat_params_init_gigachat_v3( auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; - auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; + const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto ret = p.eps(); if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { // Build a choice of all available tools auto tool_choice = p.choice(); @@ -1387,13 +1404,14 @@ static common_chat_params common_chat_params_init_gigachat_v3( auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice); auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls)); - return p.content(p.until("<|message_sep|>\n\n")) << tool_calls; + ret = p.content(p.until("<|message_sep|>\n\n")) << tool_calls; + } else { + // Content only parser + include_grammar = false; + ret = p.content(p.rest()); } - // Content only parser - include_grammar = false; - return p.content(p.rest()); - + return wrap_for_generation_prompt(p, ret, inputs); }); data.parser = parser.save(); @@ -1488,87 +1506,10 @@ static json common_chat_extra_context() { return ctx; } -static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) { - autoparser::templates_params params; - params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); - const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use - ? 
*tmpls->template_tool_use - : *tmpls->template_default; - const auto & src = tmpl.source(); - const auto & caps = tmpl.original_caps(); - params.messages = render_message_to_json(inputs.messages, tmpl.original_caps()); - params.add_generation_prompt = inputs.add_generation_prompt; - params.tool_choice = inputs.tool_choice; - params.reasoning_format = inputs.reasoning_format; - params.enable_thinking = inputs.enable_thinking; - params.grammar = inputs.grammar; - params.now = inputs.now; - params.add_bos = tmpls->add_bos; - params.add_eos = tmpls->add_eos; - - if (src.find("<|channel|>") == std::string::npos) { - // map developer to system for all models except for GPT-OSS - workaround::map_developer_role_to_system(params.messages); - } - - if (!tmpl.original_caps().supports_system_role) { - workaround::system_message_not_supported(params.messages); - } - - if (tmpl.original_caps().supports_tool_calls) { - // some templates will require the content field in tool call messages - // to still be non-null, this puts an empty string everywhere where the - // content field is null - workaround::requires_non_null_content(params.messages); - } - - if (tmpl.original_caps().supports_object_arguments) { - workaround::func_args_not_string(params.messages); - } - - params.extra_context = common_chat_extra_context(); - for (auto el : inputs.chat_template_kwargs) { - params.extra_context[el.first] = json::parse(el.second); - } - - if (!inputs.json_schema.empty()) { - params.json_schema = json::parse(inputs.json_schema); - } - - // if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - // LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - // params.parallel_tool_calls = false; - // } else { - params.parallel_tool_calls = inputs.parallel_tool_calls; - //} - - if (params.tools.is_array()) { - if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { - throw std::runtime_error("Cannot specify grammar with tools"); - } - if (caps.supports_tool_calls && !caps.supports_tools) { - LOG_WRN( - "Template supports tool calls but does not natively describe tools. 
The fallback behaviour used may " - "produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); - } - } - - if (inputs.force_pure_content) { - LOG_WRN("Forcing pure content template, will not render reasoning or tools separately."); - // Create the result structure - common_chat_params data; - auto params_copy = params; - params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE; - data.prompt = common_chat_template_direct_apply(tmpl, params_copy); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - auto parser = build_chat_peg_parser([](common_chat_peg_builder &p) { - return p.content(p.rest()); - }); - data.parser = parser.save(); - return data; - } - +static std::optional try_specialized_template( + const common_chat_template & tmpl, + const std::string & src, + const autoparser::generation_params & params) { // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos && @@ -1609,14 +1550,105 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ // GigaChatV3 format detection if (src.find("<|role_sep|>") != std::string::npos && src.find("<|message_sep|>") != std::string::npos && - src.find("<|function_call|>") == std::string::npos - ) { + src.find("<|function_call|>") == std::string::npos) { LOG_DBG("Using specialized template: GigaChatV3\n"); return common_chat_params_init_gigachat_v3(tmpl, params); } + return std::nullopt; +} + +static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { + autoparser::generation_params params; + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + const auto & tmpl = + params.tools.is_array() && tmpls->template_tool_use ? 
*tmpls->template_tool_use : *tmpls->template_default; + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); + params.messages = render_message_to_json(inputs.messages, tmpl.original_caps()); + params.tool_choice = inputs.tool_choice; + params.reasoning_format = inputs.reasoning_format; + params.enable_thinking = inputs.enable_thinking; + params.grammar = inputs.grammar; + params.now = inputs.now; + params.add_bos = tmpls->add_bos; + params.add_eos = tmpls->add_eos; + + if (src.find("<|channel|>") == std::string::npos) { + // map developer to system for all models except for GPT-OSS + workaround::map_developer_role_to_system(params.messages); + } + + if (!tmpl.original_caps().supports_system_role) { + workaround::system_message_not_supported(params.messages); + } + + if (tmpl.original_caps().supports_tool_calls) { + // some templates will require the content field in tool call messages + // to still be non-null, this puts an empty string everywhere where the + // content field is null + workaround::requires_non_null_content(params.messages); + } + + if (tmpl.original_caps().supports_object_arguments) { + workaround::func_args_not_string(params.messages); + } + + params.add_generation_prompt = false; + std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params); + params.add_generation_prompt = true; + std::string gen_prompt = common_chat_template_direct_apply(tmpl, params); + auto diff = calculate_diff_split(no_gen_prompt, gen_prompt); + params.generation_prompt = diff.right; + + params.add_generation_prompt = inputs.add_generation_prompt; + + params.extra_context = common_chat_extra_context(); + for (auto el : inputs.chat_template_kwargs) { + params.extra_context[el.first] = json::parse(el.second); + } + + if (!inputs.json_schema.empty()) { + params.json_schema = json::parse(inputs.json_schema); + } + + params.parallel_tool_calls = inputs.parallel_tool_calls; + + if (params.tools.is_array()) { + if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + if (caps.supports_tool_calls && !caps.supports_tools) { + LOG_WRN( + "Template supports tool calls but does not natively describe tools. 
The fallback behaviour used may " + "produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); + } + } + + if (inputs.force_pure_content) { + LOG_WRN("Forcing pure content template, will not render reasoning or tools separately."); + // Create the result structure + common_chat_params data; + auto params_copy = params; + params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE; + data.prompt = common_chat_template_direct_apply(tmpl, params_copy); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.generation_prompt = params.generation_prompt; + auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) { + return wrap_for_generation_prompt(p, p.content(p.rest()), params); + }); + data.parser = parser.save(); + return data; + } + + if (auto result = try_specialized_template(tmpl, src, params)) { + result->generation_prompt = params.generation_prompt; + return *result; + } + try { - LOG_DBG("Using differential autoparser\n"); + LOG_DBG("%s: using differential autoparser\n", __func__); struct autoparser::autoparser autoparser; autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); @@ -1624,13 +1656,11 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ if (auto_params.supports_thinking) { auto_params.thinking_start_tag = autoparser.reasoning.start; auto_params.thinking_end_tag = autoparser.reasoning.end; - // FORCED_OPEN and FORCED_CLOSED both put in the generation prompt - // (FORCED_CLOSED forces empty when thinking is disabled, - // but forces open when thinking is enabled) - auto_params.thinking_forced_open = - autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN || - autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED; } + auto_params.generation_prompt = params.generation_prompt; + common_peg_arena arena; + arena.load(auto_params.parser); + LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str()); return auto_params; } catch (const std::exception & e) { throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); @@ -1728,14 +1758,18 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars LOG_DBG("No parser definition detected, assuming pure content parser."); } - LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str()); + const std::string effective_input = params.generation_prompt.empty() + ? 
input + : params.generation_prompt + input; + + LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str()); common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT; if (params.debug) { flags |= COMMON_PEG_PARSE_FLAG_DEBUG; } - common_peg_parse_context ctx(input, flags); + common_peg_parse_context ctx(effective_input, flags); auto result = parser.parse(ctx); if (result.fail()) { @@ -1755,7 +1789,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars return msg; } throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " + - input.substr(result.end)); + effective_input.substr(result.end)); } common_chat_msg msg; diff --git a/common/chat.h b/common/chat.h index 23e80baf6..6358a1893 100644 --- a/common/chat.h +++ b/common/chat.h @@ -24,7 +24,7 @@ using json = nlohmann::ordered_json; struct common_chat_templates; namespace autoparser { -struct templates_params; +struct generation_params; } // namespace autoparser struct common_chat_tool_call { @@ -212,7 +212,7 @@ struct common_chat_params { std::string prompt; std::string grammar; bool grammar_lazy = false; - bool thinking_forced_open = false; + std::string generation_prompt; bool supports_thinking = false; std::string thinking_start_tag; // e.g., "" std::string thinking_end_tag; // e.g., "" @@ -229,14 +229,14 @@ struct common_chat_parser_params { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; - bool thinking_forced_open = false; + std::string generation_prompt; bool parse_tool_calls = true; bool debug = false; // Enable debug output for PEG parser common_peg_arena parser = {}; common_chat_parser_params() = default; common_chat_parser_params(const common_chat_params & chat_params) { - format = chat_params.format; - thinking_forced_open = chat_params.thinking_forced_open; + format = chat_params.format; + generation_prompt = chat_params.generation_prompt; } }; @@ -302,7 +302,7 @@ std::map common_chat_templates_get_caps(const common_chat_tem std::string common_chat_template_direct_apply( const common_chat_template & tmpl, - const autoparser::templates_params & inputs, + const autoparser::generation_params & inputs, const std::optional & messages_override = std::nullopt, const std::optional & tools_override = std::nullopt, const std::optional & additional_context = std::nullopt); diff --git a/common/common.h b/common/common.h index f38bb332b..3e889e739 100644 --- a/common/common.h +++ b/common/common.h @@ -3,6 +3,7 @@ #pragma once #include "ggml-opt.h" +#include "ggml.h" #include "llama-cpp.h" #include "build-info.h" @@ -10,6 +11,7 @@ #include #include #include +#include #include #include @@ -175,6 +177,43 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type }; +// Grammar type enumeration +enum common_grammar_type { + COMMON_GRAMMAR_TYPE_NONE, // no grammar set + COMMON_GRAMMAR_TYPE_USER, // user-provided GBNF (--grammar / "grammar" API field) + COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, // auto-generated from JSON schema (--json-schema / "json_schema" API field) + COMMON_GRAMMAR_TYPE_TOOL_CALLS, // auto-generated by chat template parser for function calling +}; + +// Grammar variant struct with type and grammar string +struct common_grammar { + 
common_grammar_type type = COMMON_GRAMMAR_TYPE_NONE; + std::string grammar; + + // Default constructor - no grammar + common_grammar() = default; + + // Constructor with type and grammar string + common_grammar(common_grammar_type t, std::string g) : type(t), grammar(std::move(g)) { + GGML_ASSERT(type != COMMON_GRAMMAR_TYPE_NONE || !grammar.empty()); + } + + // Check if a grammar is set + bool empty() const { return type == COMMON_GRAMMAR_TYPE_NONE || grammar.empty(); } +}; + +// Returns the raw grammar string, or empty string if no grammar is set. +inline const std::string & common_grammar_value(const common_grammar & g) { + return g.grammar; +} + +// Returns true when the generation_prompt should be prefilled into the grammar sampler. +// Only output-format and tool-call grammars need prefill; user-supplied grammars must not be prefilled. +inline bool common_grammar_needs_prefill(const common_grammar & g) { + return g.type == COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT + || g.type == COMMON_GRAMMAR_TYPE_TOOL_CALLS; +} + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -225,7 +264,7 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_TEMPERATURE, }; - std::string grammar; // optional BNF-like grammar to constrain sampling + common_grammar grammar; // optional grammar constraint (user / output-format / tool-calls) bool grammar_lazy = false; std::vector grammar_triggers; // optional triggers (for lazy grammars) std::set preserved_tokens; @@ -233,10 +272,15 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + // The assistant generation prompt already prefilled into the prompt. + // Fed to the grammar sampler (to advance past pre-existing tokens) and used + // to determine the reasoning budget sampler's initial state. + // Only applied when the grammar is of output-format or tool-calls type. 
+ std::string generation_prompt; + // reasoning budget sampler parameters // these are populated by the server/CLI based on chat template params int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget - bool reasoning_budget_activate_immediately = false; std::vector reasoning_budget_start; // start tag token sequence std::vector reasoning_budget_end; // end tag token sequence std::vector reasoning_budget_forced; // forced sequence (message + end tag) diff --git a/common/jinja/value.h b/common/jinja/value.h index 6cbedefd9..7d164588a 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -451,7 +451,7 @@ struct value_array_t : public value_t { } protected: virtual bool equivalent(const value_t & other) const override { - return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence()); + return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), other.val_arr.end(), value_equivalence()); } }; using value_array = std::shared_ptr; @@ -587,7 +587,7 @@ struct value_object_t : public value_t { } protected: virtual bool equivalent(const value_t & other) const override { - return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence()); + return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), other.val_obj.end(), value_equivalence()); } }; using value_object = std::shared_ptr; diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp index a55e4f509..2ef744278 100644 --- a/common/reasoning-budget.cpp +++ b/common/reasoning-budget.cpp @@ -163,9 +163,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) { ctx->force_pos = 0; } +// forward declaration for use in clone +static struct llama_sampler * common_reasoning_budget_init_state( + const struct llama_vocab * vocab, const std::vector & start_tokens, + const std::vector & end_tokens, const std::vector & forced_tokens, + int32_t budget, common_reasoning_budget_state initial_state); + static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; - return common_reasoning_budget_init( + return common_reasoning_budget_init_state( ctx->vocab, ctx->start_matcher.tokens, ctx->end_matcher.tokens, @@ -191,13 +197,13 @@ static struct llama_sampler_i common_reasoning_budget_i = { /* .backend_set_input = */ nullptr, }; -struct llama_sampler * common_reasoning_budget_init( - const struct llama_vocab * vocab, - const std::vector & start_tokens, - const std::vector & end_tokens, - const std::vector & forced_tokens, - int32_t budget, - common_reasoning_budget_state initial_state) { +static struct llama_sampler * common_reasoning_budget_init_state( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state) { // promote COUNTING with budget <= 0 to FORCING if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { initial_state = REASONING_BUDGET_FORCING; @@ -217,3 +223,41 @@ struct llama_sampler * common_reasoning_budget_init( } ); } + +struct llama_sampler * 
common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + const std::vector & prefill_tokens) { + // Determine initial state from prefill: COUNTING if the prefill begins with + // the start sequence but does not also contain the end sequence after it. + common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE; + if (!prefill_tokens.empty() && !start_tokens.empty() && + prefill_tokens.size() >= start_tokens.size() && + std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) { + initial_state = REASONING_BUDGET_COUNTING; + // If the end sequence also follows the start in the prefill, reasoning + // was opened and immediately closed — stay IDLE. + if (!end_tokens.empty() && + prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) { + auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size(); + if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() && + std::equal(end_tokens.begin(), end_tokens.end(), end_start)) { + initial_state = REASONING_BUDGET_IDLE; + } + } + } + return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); +} + +struct llama_sampler * common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state) { + return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); +} diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h index 08ad28248..130afdea4 100644 --- a/common/reasoning-budget.h +++ b/common/reasoning-budget.h @@ -24,14 +24,26 @@ enum common_reasoning_budget_state { // DONE: passthrough forever // // Parameters: -// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) -// start_tokens - token sequence that activates counting -// end_tokens - token sequence for natural deactivation -// forced_tokens - token sequence forced when budget expires -// budget - max tokens allowed in the reasoning block -// initial_state - initial state of the sampler (e.g. IDLE or COUNTING) -// note: COUNTING with budget <= 0 is promoted to FORCING +// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) +// start_tokens - token sequence that activates counting +// end_tokens - token sequence for natural deactivation +// forced_tokens - token sequence forced when budget expires +// budget - max tokens allowed in the reasoning block +// prefill_tokens - tokens already present in the prompt (generation prompt); +// used to determine the initial state: COUNTING if they begin +// with start_tokens (but don't also end with end_tokens), +// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING. // +struct llama_sampler * common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + const std::vector & prefill_tokens = {}); + +// Variant that takes an explicit initial state (used by tests and clone). +// COUNTING with budget <= 0 is promoted to FORCING. 
struct llama_sampler * common_reasoning_budget_init( const struct llama_vocab * vocab, const std::vector & start_tokens, diff --git a/common/sampling.cpp b/common/sampling.cpp index f849d4f61..012e21266 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,13 +1,16 @@ #include "sampling.h" #include "common.h" +#include "ggml.h" #include "log.h" #include "reasoning-budget.h" #include +#include #include #include #include +#include // the ring buffer works similarly to std::deque, but with a fixed capacity // TODO: deduplicate with llama-impl.h @@ -189,9 +192,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st std::vector samplers; - if (params.grammar.compare(0, 11, "%llguidance") == 0) { + const std::string & grammar_str = common_grammar_value(params.grammar); + if (grammar_str.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE - grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); + grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str()); #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE @@ -240,17 +244,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st trigger_patterns_c.push_back(regex.c_str()); } - if (!params.grammar.empty()) { + if (!grammar_str.empty()) { if (params.grammar_lazy) { - grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + grmr = llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root", trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_tokens.data(), trigger_tokens.size()); } else { - grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + grmr = llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root"); } } } + // Feed generation prompt tokens to the grammar sampler so it advances past + // tokens the template already placed in the prompt. + // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled. 
+ std::vector<llama_token> prefill_tokens; + if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) { + GGML_ASSERT(vocab != nullptr); + prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true); + if (!prefill_tokens.empty()) { + std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true); + if (!first_token.empty() && std::isspace((unsigned char) first_token[0]) && !std::isspace((unsigned char) params.generation_prompt[0])) { + // Some tokenizers add a space before the first special token; drop it + prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end()); + } + } + + if (grmr) { + try { + for (const auto & token : prefill_tokens) { + llama_sampler_accept(grmr, token); + LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token); + } + } catch (const std::exception &) { + LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__, + common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str()); + throw; + } + } + } + // reasoning budget sampler — added first so it can force tokens before other samplers if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) { samplers.push_back(common_reasoning_budget_init( @@ -259,7 +292,7 @@ params.reasoning_budget_end, params.reasoning_budget_forced, params.reasoning_budget_tokens, - params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE)); + prefill_tokens)); } if (params.has_logit_bias()) { diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c89e5076f..63ceb635d 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -3194,6 +3194,7 @@ class tinyBLAS_PPC { private: + __attribute__((always_inline)) inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -3204,6 +3205,7 @@ class tinyBLAS_PPC { } } + __attribute__((always_inline)) inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 482a73bc9..74abcc02d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4620,12 +4620,42 @@ static void ggml_vk_load_shaders(vk_device& device) { {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, }; + const bool use_subgroup_reduce = device->subgroup_arithmetic; for (uint32_t si = 0; si < 3; si++) { + const uint32_t S_V = gdn_sizes[si]; + GGML_ASSERT(is_pow2(S_V)); + + uint32_t lanes_per_column; + if (S_V >= 128u && device->subgroup_clustered) { + lanes_per_column = 8u; + } else { + // Use largest power-of-two that divides both S_V and subgroup_size so that + // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0. + // This means we don't need extra bounds checking logic in the shader.
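+ // For example: with S_V = 64 and subgroup_size = 32 this picks 32 lanes per column (one column per subgroup); with S_V = 32 and subgroup_size = 64 it picks 32 lanes per column and 2 columns per subgroup. Both cases satisfy the divisibility requirements above.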
+ lanes_per_column = std::min(S_V, device->subgroup_size); + } + + const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size); + size_t gdn_len; + const void * gdn_data; + if (use_subgroup_reduce && need_clustered_shader) { + gdn_len = gated_delta_net_f32_len; + gdn_data = (const void *)gated_delta_net_f32_data; + } else if (use_subgroup_reduce) { + gdn_len = gated_delta_net_f32_nocluster_len; + gdn_data = (const void *)gated_delta_net_f32_nocluster_data; + } else { + gdn_len = gated_delta_net_f32_shmem_len; + gdn_data = (const void *)gated_delta_net_f32_shmem_data; + } + + const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column; + const std::array wg_denoms = {1u, 1u, cols_per_wg}; + for (uint32_t kda = 0; kda < 2; kda++) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], - gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data, - "main", 7, sizeof(vk_op_gated_delta_net_push_constants), - {1, 1, 1}, {gdn_sizes[si], kda}, 1); + gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), + wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size); } } } @@ -10476,7 +10506,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, - pc, { H, n_seqs, 1u }); + pc, { H, n_seqs, S_v }); } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index f008859b9..5e9f8308c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -1,11 +1,25 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#if USE_SUBGROUP_CLUSTERED +#extension GL_KHR_shader_subgroup_clustered : enable +#endif +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif +// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0, +// so no bounds checking is needed. 
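+// Work layout: each state column is handled by LANES_PER_COLUMN lanes of a single subgroup, each lane owning ROWS_PER_LANE = S_V / LANES_PER_COLUMN rows of that column, so one workgroup covers COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN columns and gl_WorkGroupID.z selects which group of columns to process.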
layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint KDA = 0; +layout(constant_id = 2) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 3) const uint LANES_PER_COLUMN = 32; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN; +const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN; + +layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Parameters { uint H; @@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; }; layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; }; -shared FLOAT_TYPE s_k[S_V]; -shared FLOAT_TYPE s_q[S_V]; -shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i]) +#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED +shared FLOAT_TYPE temp[SUBGROUP_SIZE]; + +// This does a reduction across groups of LANES_PER_COLUMN +FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) { + const uint lane = gl_SubgroupInvocationID; + temp[lane] = partial; + barrier(); + [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) { + FLOAT_TYPE other = temp[lane ^ s]; + barrier(); + temp[lane] += other; + barrier(); + } + const FLOAT_TYPE result = temp[lane]; + barrier(); + return result; +} +#endif + +// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant +FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) { + switch (LANES_PER_COLUMN) { + case 1u: + return partial; +#if USE_SUBGROUP_CLUSTERED + // Workaround for GLSL requiring a literal constant for the cluster size. + // The branches should all fold away. + case 2u: + return subgroupClusteredAdd(partial, 2u); + case 4u: + return subgroupClusteredAdd(partial, 4u); + case 8u: + return subgroupClusteredAdd(partial, 8u); + case 16u: + return subgroupClusteredAdd(partial, 16u); + case 32u: + return subgroupClusteredAdd(partial, 32u); + case 64u: + return subgroupClusteredAdd(partial, 64u); +#endif + default: +#if USE_SUBGROUP_ADD + return subgroupAdd(partial); +#else + return reduce_add_shmem(partial); +#endif + } +} void main() { const uint head_id = gl_WorkGroupID.x; - const uint seq_id = gl_WorkGroupID.y; - const uint col = gl_LocalInvocationID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN; + const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN); const uint iq1 = head_id % neq1; const uint iq3 = seq_id / rq3; @@ -42,9 +103,9 @@ void main() { const uint state_size = S_V * S_V; const uint state_base = (seq_id * H + head_id) * state_size; - FLOAT_TYPE state[S_V]; - [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]); + FLOAT_TYPE s_shard[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); } uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -53,76 +114,56 @@ void main() { const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; const uint k_off = q_off; const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; - - s_q[col] = FLOAT_TYPE(data_q[q_off + col]); - s_k[col] = FLOAT_TYPE(data_k[k_off + col]); - const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; - - if (KDA != 0) { - const uint g_base = gb_off * S_V; - s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col])); - } - - 
barrier(); - - const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); + FLOAT_TYPE k_reg[ROWS_PER_LANE]; + FLOAT_TYPE q_reg[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = FLOAT_TYPE(data_k[k_off + i]); + q_reg[r] = FLOAT_TYPE(data_q[q_off + i]); + } + + FLOAT_TYPE g_exp[ROWS_PER_LANE]; if (KDA == 0) { const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off])); - - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - kv_col += dot( - vec4(state[i], state[i+1], state[i+2], state[i+3]), - vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]) - ); + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + g_exp[r] = g_val; } - - FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val; - - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = g_val * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; - - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); - } - - data_dst[attn_off + col] = attn_col * scale; } else { - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - kv_col += dot(gv * sv, kv); + const uint g_base = gb_off * S_V; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i])); } + } - FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = gv * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; + FLOAT_TYPE kv_shard = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + kv_shard += g_exp[r] * s_shard[r] * k_reg[r]; + } + FLOAT_TYPE kv_col = reduce_partial(kv_shard); - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); - } + FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + FLOAT_TYPE attn_partial = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + FLOAT_TYPE attn_col = reduce_partial(attn_partial); + + if (lane == 0) { data_dst[attn_off + col] = attn_col * scale; } attn_off += S_V * H; - barrier(); } - [[unroll]] for (uint i = 0; i < S_V; i++) { - data_dst[s_off + state_base + col * S_V + i] = state[i]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 9c2ea2e59..4f011c4ab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -1004,7 +1004,9 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", 
merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}})); + string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); + string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 560fbb66b..16d7e945b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1948,6 +1948,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); return 0; } + ggml_backend_buffer_clear(buf_output.get(), 0); } float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 70a2c01fb..5dc7794c3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1787,6 +1787,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters (GLM-OCR) ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1820,6 +1821,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1866,6 +1868,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -2040,6 +2043,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); switch (hparams.n_layer) { case 32: type = LLM_TYPE_30B_A3B; break; @@ -2168,7 +2172,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_embd) { case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_embd == 2048 ? 
LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; case 2048: case 2560: type = LLM_TYPE_3B; break; case 4096: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -2222,6 +2226,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; diff --git a/tools/parser/debug-template-parser.cpp b/tools/parser/debug-template-parser.cpp index ffa3a5af7..a83797157 100644 --- a/tools/parser/debug-template-parser.cpp +++ b/tools/parser/debug-template-parser.cpp @@ -282,7 +282,7 @@ static void render_scenario(const common_chat_template & tmpl, LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str()); try { - autoparser::templates_params inputs; + autoparser::generation_params inputs; inputs.messages = final_messages; inputs.add_generation_prompt = add_generation_prompt; inputs.extra_context["enable_thinking"] = enable_thinking; @@ -395,7 +395,7 @@ int main(int argc, char ** argv) { analysis.analyze_template(chat_template); // Generate Parser - autoparser::templates_params params; + autoparser::generation_params params; params.messages = json::array({ build_user_message() }); params.reasoning_format = opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE; diff --git a/tools/parser/template-analysis.cpp b/tools/parser/template-analysis.cpp index a92e104ac..bf898a229 100644 --- a/tools/parser/template-analysis.cpp +++ b/tools/parser/template-analysis.cpp @@ -400,12 +400,12 @@ static void analyze_template(const std::string & template_path) { { json user_msg = make_user_msg(); - autoparser::templates_params params_no_tools; + autoparser::generation_params params_no_tools; params_no_tools.messages = json::array({ user_msg }); params_no_tools.add_generation_prompt = false; params_no_tools.tools = json::array(); - autoparser::templates_params params_with_tools = params_no_tools; + autoparser::generation_params params_with_tools = params_no_tools; params_with_tools.tools = tools; std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools); @@ -419,12 +419,12 @@ static void analyze_template(const std::string & template_path) { { json user_msg = make_user_msg(); - autoparser::templates_params params_no_prompt; + autoparser::generation_params params_no_prompt; params_no_prompt.messages = json::array({ user_msg }); params_no_prompt.add_generation_prompt = false; params_no_prompt.tools = json::array(); - autoparser::templates_params params_with_prompt = params_no_prompt; + autoparser::generation_params params_with_prompt = params_no_prompt; params_with_prompt.add_generation_prompt = true; std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt); @@ -438,12 +438,12 @@ static void analyze_template(const std::string & template_path) { { json user_msg = make_user_msg(); - autoparser::templates_params params_no_reasoning; + autoparser::generation_params params_no_reasoning; params_no_reasoning.messages = json::array({ user_msg, 
make_assistant_no_reasoning() }); params_no_reasoning.add_generation_prompt = false; params_no_reasoning.enable_thinking = true; - autoparser::templates_params params_with_reasoning = params_no_reasoning; + autoparser::generation_params params_with_reasoning = params_no_reasoning; params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() }); std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); @@ -458,12 +458,12 @@ static void analyze_template(const std::string & template_path) { json user_msg = make_user_msg(); json user_msg2 = make_user_msg2(); - autoparser::templates_params params_no_reasoning; + autoparser::generation_params params_no_reasoning; params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 }); params_no_reasoning.add_generation_prompt = false; params_no_reasoning.enable_thinking = true; - autoparser::templates_params params_with_reasoning = params_no_reasoning; + autoparser::generation_params params_with_reasoning = params_no_reasoning; params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 }); std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); @@ -477,12 +477,12 @@ static void analyze_template(const std::string & template_path) { { json user_msg = make_user_msg(); - autoparser::templates_params params_no_tool; + autoparser::generation_params params_no_tool; params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() }); params_no_tool.add_generation_prompt = false; params_no_tool.tools = tools; - autoparser::templates_params params_with_tool = params_no_tool; + autoparser::generation_params params_with_tool = params_no_tool; params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); @@ -497,12 +497,12 @@ static void analyze_template(const std::string & template_path) { json user_msg = make_user_msg(); json user_msg2 = make_user_msg2_continue(); - autoparser::templates_params params_no_tool; + autoparser::generation_params params_no_tool; params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 }); params_no_tool.add_generation_prompt = false; params_no_tool.tools = tools; - autoparser::templates_params params_with_tool = params_no_tool; + autoparser::generation_params params_with_tool = params_no_tool; params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); @@ -516,12 +516,12 @@ static void analyze_template(const std::string & template_path) { { json user_msg = make_user_msg(); - autoparser::templates_params params_one_tool; + autoparser::generation_params params_one_tool; params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); params_one_tool.add_generation_prompt = false; params_one_tool.tools = tools; - autoparser::templates_params params_two_tools = params_one_tool; + autoparser::generation_params params_two_tools = params_one_tool; params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() }); std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); @@ -536,12 +536,12 @@ static void analyze_template(const std::string & template_path) { json user_msg = make_user_msg(); json 
user_msg2 = make_user_msg2_continue(); - autoparser::templates_params params_one_tool; + autoparser::generation_params params_one_tool; params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); params_one_tool.add_generation_prompt = false; params_one_tool.tools = tools; - autoparser::templates_params params_two_tools = params_one_tool; + autoparser::generation_params params_two_tools = params_one_tool; params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 }); std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); @@ -555,13 +555,13 @@ { json user_msg = make_user_msg(); - autoparser::templates_params params_no_reasoning; + autoparser::generation_params params_no_reasoning; params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() }); params_no_reasoning.add_generation_prompt = false; params_no_reasoning.tools = tools; params_no_reasoning.enable_thinking = true; - autoparser::templates_params params_with_reasoning = params_no_reasoning; + autoparser::generation_params params_with_reasoning = params_no_reasoning; params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() }); std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 3fea3042f..5f82e35d6 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -4,6 +4,36 @@ This document provides an in-depth technical overview of `llama-server`, intende If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead. +## Scope of features + +In-scope features: + +- Backend: + - Basic inference features: text completion, embeddings output + - Chat-oriented features: chat completion, tool calling + - Third-party API compatibility, e.g. OAI-compat, Anthropic-compat + - Multimodal input/output + - Memory management: save/load state, context checkpoints + - Model management + - Features that are required by the Web UI +- Frontend: + - Chat-oriented features, example: basic chat, image upload, message editing + - Agentic features, example: MCP + - Model management + +Note: For security reasons, features that require reading or writing external files must be **disabled by default**. This covers features such as MCP and model save/load. + +Out-of-scope features: + +- Backend: + - Features that require a loop of external API calls, e.g. a server-side agentic loop. This is because external API calls in C++ are costly to maintain. Any complex third-party logic should be implemented outside of server code. + - Features that expose the internal state of the model via the API, example: getting intermediate activations from the API. This is because llama.cpp doesn't support a stable API for doing this, and relying on `eval_callback` is complicated to maintain, as that API is not intended for multi-sequence setups. + - Model-specific features. All API calls and features must remain model-agnostic. +- Frontend: + - Third-party plugins: it is costly to maintain a public plugin API for such features. Instead, users can make their own MCP server for their needs. + - Customizable themes: these are also costly to maintain. While we do focus on aesthetics, we try to achieve this by perfecting a small set of themes.
+ - Browser-specific features, example: [Chrome's built-in AI API](https://developer.chrome.com/docs/ai/built-in-apis). + ## Backend ### Overview diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index f1ccf5a75..11d31b0f6 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 59ea11fc4..e01c8c53d 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1081,20 +1081,21 @@ json oaicompat_chat_params_parse( } } - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; if (!chat_params.grammar.empty()) { - llama_params["grammar"] = chat_params.grammar; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_type"] = std::string("tool_calls"); } - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { server_grammar_trigger ct(trigger); grammar_triggers.push_back(ct.to_json()); } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + llama_params["generation_prompt"] = chat_params.generation_prompt; for (const auto & stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } @@ -1114,7 +1115,6 @@ json oaicompat_chat_params_parse( llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag; llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag; llama_params["reasoning_budget_message"] = opt.reasoning_budget_message; - llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open; } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 1e5ff101c..9de554e90 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -1152,11 +1153,11 @@ private: // initialize samplers if (task.need_sampling()) { - slot.smpl.reset(common_sampler_init(model, task.params.sampling)); - - if (slot.smpl == nullptr) { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + try { + slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + } catch (std::exception & e) { + std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); + send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); return false; } @@ -1431,9 +1432,10 @@ private: res->tokens = { tkn.tok }; } - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->post_sampling_probs = slot.task->params.post_sampling_probs; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; + res->post_sampling_probs = slot.task->params.post_sampling_probs; res->verbose = slot.task->params.verbose; 
res->res_type = slot.task->params.res_type; @@ -1478,14 +1480,15 @@ private: res->prompt = slot.task->tokens.detokenize(ctx, true); res->response_fields = std::move(slot.task->params.response_fields); - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->n_tokens_cached = slot.prompt.n_tokens(); - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.task->params.post_sampling_probs; + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_prompt_tokens_cache = slot.n_prompt_tokens_cache; + res->n_tokens_cached = slot.prompt.n_tokens(); + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.task->params.post_sampling_probs; res->verbose = slot.task->params.verbose; res->stream = slot.task->params.stream; @@ -2304,8 +2307,8 @@ private: llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); - // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 - const auto n_swa = std::max(1, llama_model_n_swa(model)); + // note: when n_swa == 0, the model does not use SWA + const auto n_swa = std::max(0, llama_model_n_swa(model)); // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, pos_next - n_swa); @@ -2360,7 +2363,7 @@ private: SLT_WRN(slot, "%s\n", st1.str().c_str()); } - if (pos_min > pos_min_thold) { + if (pos_min >= pos_min_thold) { SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint @@ -2456,8 +2459,39 @@ private: slot.n_prompt_tokens_cache = 0; } + // If using an alora, there may be uncached tokens that come + // before the invocation sequence. When this happens, the + // tokens before the invocation sequence need to be + // processed without the adapter in a separate batch, then + // the adapter needs to be enabled for the remaining tokens. + if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { + SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + const auto & enabled_loras = lora_get_enabled_ids(slot.lora); + GGML_ASSERT(enabled_loras.size() == 1); + alora_scale = slot.lora[enabled_loras[0]].scale; + slot.lora[enabled_loras[0]].scale = 0.0f; + alora_disabled_id = enabled_loras[0]; + } + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + // make a checkpoint of the parts of the memory that cannot be rolled back. 
+ // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + + bool has_mtmd = false; + // check if we should process the image if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image @@ -2478,38 +2512,9 @@ private: slot.prompt.tokens.push_back(chunk.get()); // copy } - do_checkpoint = false; // do not checkpoint right after an image chunk + has_mtmd = true; } - // If using an alora, there may be uncached tokens that come - // before the invocation sequence. When this happens, the - // tokens before the invocation sequence need to be - // processed without the adapter in a separate batch, then - // the adapter needs to be enabled for the remaining tokens. - if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { - SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - const auto & enabled_loras = lora_get_enabled_ids(slot.lora); - GGML_ASSERT(enabled_loras.size() == 1); - alora_scale = slot.lora[enabled_loras[0]].scale; - slot.lora[enabled_loras[0]].scale = 0.0f; - alora_disabled_id = enabled_loras[0]; - } - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. 
- // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); - // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { // get next token to process @@ -2541,13 +2546,13 @@ private: // - 4 + n_ubatch // - 4 // ref: https://github.com/ggml-org/llama.cpp/pull/20288 - { + if (do_checkpoint) { static const int checkpoint_offsets[] = {4 + n_ubatch, 4}; bool should_break = false; for (int offset : checkpoint_offsets) { const int n_last = std::min(n_batch, offset); - if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) { + if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) { should_break = true; break; } @@ -2604,10 +2609,13 @@ private: const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); + + // do not checkpoint after mtmd chunks + do_checkpoint = do_checkpoint && !has_mtmd; // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64); // note: we create the checkpoint before calling llama_decode(), so the current batch is not // yet processed and therefore it is not part of the checkpoint. diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index c13d48a36..4ac55cd15 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -539,6 +539,22 @@ void server_models::load(const std::string & name) { return; } + // Re-check capacity under the lock to prevent concurrent loads from + // exceeding models_max. Without this, the window between unload_lru() + // releasing its lock and this lock_guard acquiring allows multiple + // threads to each observe capacity and all proceed to load. + if (base_params.models_max > 0) { + size_t count_active = 0; + for (const auto & m : mapping) { + if (m.second.meta.is_active()) { + count_active++; + } + } + if (count_active >= (size_t)base_params.models_max) { + throw std::runtime_error("model limit reached, try again later"); + } + } + // prepare new instance info instance_t inst; inst.meta = meta; @@ -606,13 +622,20 @@ void server_models::load(const std::string & name) { }); std::thread stopping_thread([&]() { - // thread to monitor stopping signal + // thread to monitor stopping signal OR child crash auto is_stopping = [this, &name]() { return this->stopping_models.find(name) != this->stopping_models.end(); }; + auto should_wake = [&]() { + return is_stopping() || !subprocess_alive(child_proc.get()); + }; { std::unique_lock lk(this->mutex); - this->cv_stop.wait(lk, is_stopping); + this->cv_stop.wait(lk, should_wake); + } + // child may have already exited (e.g. 
crashed) — skip shutdown sequence + if (!subprocess_alive(child_proc.get())) { + return; } SRV_INF("stopping model instance name=%s\n", name.c_str()); // send interrupt to child process diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index b3d510977..7d543b929 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -72,7 +72,7 @@ json task_params::to_json(bool only_metrics) const { {"chat_format", common_chat_format_name(chat_parser_params.format)}, {"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)}, {"reasoning_in_content", chat_parser_params.reasoning_in_content}, - {"thinking_forced_open", chat_parser_params.thinking_forced_open}, + {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -128,14 +128,14 @@ json task_params::to_json(bool only_metrics) const { {"logit_bias", format_logit_bias(sampling.logit_bias)}, {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, - {"grammar", sampling.grammar}, + {"grammar", common_grammar_value(sampling.grammar)}, {"grammar_lazy", sampling.grammar_lazy}, {"grammar_triggers", grammar_triggers}, {"preserved_tokens", sampling.preserved_tokens}, {"chat_format", common_chat_format_name(chat_parser_params.format)}, {"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)}, {"reasoning_in_content", chat_parser_params.reasoning_in_content}, - {"thinking_forced_open", chat_parser_params.thinking_forced_open}, + {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -376,14 +376,25 @@ task_params server_task::params_from_json_cmpl( try { auto schema = json_value(data, "json_schema", json::object()); SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - params.sampling.grammar = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + std::string grammar_str = json_schema_to_grammar(schema); + SRV_DBG("Converted grammar: %s\n", grammar_str.c_str()); + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, std::move(grammar_str)}; } catch (const std::exception & e) { throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); } } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + std::string grammar_str = json_value(data, "grammar", std::string()); + if (!grammar_str.empty()) { + // grammar_type key is set by the server when converting chat template grammars + std::string grammar_type = json_value(data, "grammar_type", std::string()); + if (grammar_type == "tool_calls") { + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_TOOL_CALLS, std::move(grammar_str)}; + } else { + // explicit grammar from the user (API field "grammar") + params.sampling.grammar = {COMMON_GRAMMAR_TYPE_USER, std::move(grammar_str)}; + } + SRV_DBG("Grammar (%s): %s\n", grammar_type.c_str(), common_grammar_value(params.sampling.grammar).c_str()); + } params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? 
"true" : "false"); } @@ -402,7 +413,9 @@ task_params server_task::params_from_json_cmpl( } params.chat_parser_params.reasoning_format = reasoning_format; params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.chat_parser_params.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.chat_parser_params.generation_prompt = json_value(data, "generation_prompt", std::string()); + params.sampling.generation_prompt = params.chat_parser_params.generation_prompt; + SRV_DBG("Generation prompt: '%s'\n", params.chat_parser_params.generation_prompt.c_str()); params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false); if (data.contains("chat_parser")) { params.chat_parser_params.parser.load(data.at("chat_parser").get()); @@ -469,10 +482,7 @@ task_params server_task::params_from_json_cmpl( const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string()); const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string()); const auto message = json_value(data, "reasoning_budget_message", std::string()); - const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false); - params.sampling.reasoning_budget_tokens = budget; - params.sampling.reasoning_budget_activate_immediately = activate_imm; if (!start_tag.empty()) { params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true); @@ -482,8 +492,8 @@ task_params server_task::params_from_json_cmpl( params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true); } - SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n", - budget, activate_imm ? "true" : "false", + SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n", + budget, params.sampling.generation_prompt.c_str(), params.sampling.reasoning_budget_start.size(), params.sampling.reasoning_budget_end.size(), params.sampling.reasoning_budget_forced.size()); @@ -746,6 +756,15 @@ json server_task_result_cmpl_final::to_json_non_oaicompat() { return response_fields.empty() ? 
res : json_get_nested_values(response_fields, res); } +json server_task_result_cmpl_final::usage_json_oaicompat() { + return json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + {"prompt_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, + }; +} + json server_task_result_cmpl_final::to_json_oaicompat() { std::time_t t = std::time(0); json logprobs = json(nullptr); // OAI default to null @@ -771,11 +790,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() { {"model", oaicompat_model}, {"system_fingerprint", build_info}, {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id} }; @@ -823,11 +838,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() { {"model", oaicompat_model}, {"system_fingerprint", build_info}, {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, + {"usage", usage_json_oaicompat()}, {"id", oaicompat_cmpl_id} }; @@ -892,11 +903,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { {"model", oaicompat_model}, {"system_fingerprint", build_info}, {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, + {"usage", usage_json_oaicompat()}, }); } @@ -975,6 +982,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() { {"input_tokens", n_prompt_tokens}, {"output_tokens", n_decoded}, {"total_tokens", n_decoded + n_prompt_tokens}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, }}, }; @@ -1083,7 +1091,8 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { {"usage", json { {"input_tokens", n_prompt_tokens}, {"output_tokens", n_decoded}, - {"total_tokens", n_decoded + n_prompt_tokens} + {"total_tokens", n_decoded + n_prompt_tokens}, + {"input_tokens_details", json { {"cached_tokens", n_prompt_tokens_cache} }}, }} }}, }} @@ -1149,7 +1158,8 @@ json server_task_result_cmpl_final::to_json_anthropic() { {"stop_reason", stop_reason}, {"stop_sequence", stopping_word.empty() ? 
nullptr : json(stopping_word)}, {"usage", { - {"input_tokens", n_prompt_tokens}, + {"cache_read_input_tokens", n_prompt_tokens_cache}, + {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache}, {"output_tokens", n_decoded} }} }; @@ -1659,7 +1669,8 @@ json server_task_result_cmpl_partial::to_json_anthropic() { {"stop_reason", nullptr}, {"stop_sequence", nullptr}, {"usage", { - {"input_tokens", n_prompt_tokens}, + {"cache_read_input_tokens", n_prompt_tokens_cache}, + {"input_tokens", n_prompt_tokens - n_prompt_tokens_cache}, {"output_tokens", 0} }} }} diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 1e342531d..a49ddb594 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -344,6 +344,7 @@ struct server_task_result_cmpl_final : server_task_result { bool truncated; int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache; int32_t n_tokens_cached; bool has_new_line; std::string stopping_word; @@ -387,6 +388,8 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_non_oaicompat(); + json usage_json_oaicompat(); + json to_json_oaicompat(); json to_json_oaicompat_chat(); @@ -408,6 +411,7 @@ struct server_task_result_cmpl_partial : server_task_result { int32_t n_decoded; int32_t n_prompt_tokens; + int32_t n_prompt_tokens_cache; bool post_sampling_probs; bool is_progress = false; diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index d56a930f7..edef0a93b 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -51,6 +51,27 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert choice["finish_reason"] == finish_reason +def test_chat_completion_cached_tokens(): + global server + server.n_slots = 1 + server.start() + seq = [ + ("1 2 3 4 5 6", 77, 0), + ("1 2 3 4 5 6", 77, 76), + ("1 2 3 4 5 9", 77, 51), + ("1 2 3 9 9 9", 77, 47), + ] + for user_prompt, n_prompt, n_cache in seq: + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Test"}, + {"role": "user", "content": user_prompt}, + ], + }) + assert res.body["usage"]["prompt_tokens"] == n_prompt + assert res.body["usage"]["prompt_tokens_details"]["cached_tokens"] == n_cache + @pytest.mark.parametrize( "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ @@ -210,6 +231,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): global server server.jinja = jinja + server.debug = True server.start() res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predicted, diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index 93ff03be6..ef1948d4a 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -63,8 +63,10 @@ def test_anthropic_messages_basic(): assert "text" in res.body["content"][0], "Text content block missing 'text' field" assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}" assert "usage" in res.body, "Missing 'usage' field" + assert "cache_read_input_tokens" in res.body["usage"], "Missing usage.cache_read_input_tokens" assert "input_tokens" in 
res.body["usage"], "Missing usage.input_tokens" assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens" + assert isinstance(res.body["usage"]["cache_read_input_tokens"], int), "cache_read_input_tokens should be integer" assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer" assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer" assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens" diff --git a/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts b/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts index ce91de741..cbb2605e1 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts @@ -51,7 +51,7 @@ describe('ParameterSyncService', () => { chat_format: '', reasoning_format: '', reasoning_in_content: false, - thinking_forced_open: false, + generation_prompt: '', 'speculative.n_max': 0, 'speculative.n_min': 0, 'speculative.p_min': 0.0, @@ -116,7 +116,7 @@ describe('ParameterSyncService', () => { chat_format: '', reasoning_format: '', reasoning_in_content: false, - thinking_forced_open: false, + generation_prompt: '', 'speculative.n_max': 0, 'speculative.n_min': 0, 'speculative.p_min': 0.0, diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index c90825842..7cbd6db97 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -164,7 +164,7 @@ export interface ApiLlamaCppServerProps { chat_format: string; reasoning_format: string; reasoning_in_content: boolean; - thinking_forced_open: boolean; + generation_prompt: string; samplers: string[]; backend_sampling: boolean; 'speculative.n_max': number; @@ -332,7 +332,7 @@ export interface ApiSlotData { chat_format: string; reasoning_format: string; reasoning_in_content: boolean; - thinking_forced_open: boolean; + generation_prompt: string; samplers: string[]; backend_sampling: boolean; 'speculative.n_max': number;