diff --git a/common/arg.cpp b/common/arg.cpp index 2900e860f..c7440728c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2810,7 +2810,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 6021fc4ed..db3a6cc6f 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -43,11 +43,33 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & const autoparser & autoparser) { // Create the result structure common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = autoparser.preserved_tokens; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = autoparser.preserved_tokens; - auto parser = autoparser.build_parser(inputs); + std::string parser_generation_prompt = data.generation_prompt; + + if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !inputs.continue_msg.empty()) { + // Build up generation prompt manually + const auto & msg = inputs.continue_msg; + + if (!autoparser.reasoning.start.empty()) { + data.generation_prompt = data.generation_prompt.substr(0, data.generation_prompt.find(autoparser.reasoning.start)); + 
data.generation_prompt += autoparser.reasoning.start + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += autoparser.reasoning.end; + } + } + + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + + auto parser = autoparser.build_parser(inputs, parser_generation_prompt); data.parser = parser.save(); // Build grammar if tools are present @@ -87,7 +109,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & return data; } -common_peg_arena autoparser::build_parser(const generation_params & inputs) const { +common_peg_arena autoparser::build_parser(const generation_params & inputs, const std::string & generation_prompt) const { if (!analysis_complete) { throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); } @@ -121,7 +143,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons } else { parser = content.build_parser(ctx); } - return pure_content ? p.prefix(inputs.generation_prompt, reasoning.start) + parser : p.prefix(inputs.generation_prompt, reasoning.start) << parser; + return pure_content ? 
p.prefix(generation_prompt, reasoning.start) + parser : p.prefix(generation_prompt, reasoning.start) << parser; }); } diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 6c5474097..c680e6868 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -60,16 +60,21 @@ struct generation_params { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; bool stream = true; std::string grammar; - bool add_generation_prompt = false; - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - std::string generation_prompt; + bool add_generation_prompt = false; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + common_chat_msg continue_msg; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); json extra_context; bool add_bos = false; bool add_eos = false; bool is_inference = true; bool add_inference = false; bool mark_input = true; // whether to mark input strings in the jinja context + + bool has_continuation() const { + return continue_final_message != COMMON_CHAT_CONTINUATION_NONE && !continue_msg.empty(); + } }; // ============================================================================ @@ -386,7 +391,7 @@ struct autoparser { void analyze_template(const common_chat_template & tmpl); // Build the PEG parser for this template - common_peg_arena build_parser(const generation_params & inputs) const; + common_peg_arena build_parser(const generation_params & inputs, const std::string & generation_prompt) const; private: // Collect tokens from entire analysis to preserve diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index a4818859a..12e747d1c 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -358,35 +358,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) { if (is_potential_container) { value_content = 
normalize_container_value(value_content); } - - // Try to parse as JSON value (number, bool, null, object, array) - try { - ordered_json parsed = ordered_json::parse(value_content); - if (parsed.is_string()) { - // Don't add closing quote yet (added by arg_close) for monotonic streaming - std::string escaped = parsed.dump(); - if (!escaped.empty() && escaped.back() == '"') { - escaped.pop_back(); - } - value_to_add = escaped; - closing_quote_pending = true; - } else { - // Non-string values: use raw content to preserve whitespace for monotonicity - value_to_add = value_content; - } - } catch (...) { - if (node.is_partial && is_potential_container) { - // Partial container: pass through the already-normalized content - value_to_add = value_content; - } else { - // Not valid JSON - treat as string value - if (!closing_quote_pending) { - value_to_add = "\""; - closing_quote_pending = true; - } - value_to_add += escape_json_string_inner(value_content); - } - } + value_to_add += value_content; } args_target() += value_to_add; @@ -813,7 +785,7 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s if (delimiter.empty()) { return literal(s); } - return literal(s.substr(0, s.rfind(delimiter))); + return literal(s.substr(0, s.find(delimiter))); } common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) { diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index c684d7735..be92f17d9 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -90,7 +90,7 @@ class common_chat_peg_builder : public common_peg_parser_builder { // Use for schema-declared string types - won't be treated as potential JSON container common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } - common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return 
tag(TOOL_ARG_VALUE, p); } // Return a parser that parses the prefix of a string, up to a given delimiter. diff --git a/common/chat.cpp b/common/chat.cpp index c307c7bf7..1d4ded6d8 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -81,6 +81,26 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) { return !msg.content.empty() || !msg.tool_calls.empty(); } +std::string common_chat_msg::render_content(const std::string & delimiter) const { + if (!content.empty() && !content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + if (!content.empty()) { + return content; + } + + std::string text; + for (const auto & part : content_parts) { + if (part.type == "text") { + if (!text.empty()) { + text += delimiter; + } + text += part.text; + } + } + return text; +} + json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { if (!content.empty() && !content_parts.empty()) { throw std::runtime_error("Cannot specify both content and content_parts"); @@ -465,6 +485,23 @@ std::vector common_chat_tools_parse_oaicompat(const json & too #include "common/unicode.h" #include "peg-parser.cpp" #include "chat-peg-parser.cpp" + +common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value) { + if (value.is_boolean() && value.get<bool>()) { + return COMMON_CHAT_CONTINUATION_AUTO; + } + if (value.is_string()) { + auto value_str = value.get<std::string>(); + if (value_str == "reasoning_content") { + return COMMON_CHAT_CONTINUATION_REASONING; + } + if (value_str == "content") { + return COMMON_CHAT_CONTINUATION_CONTENT; + } + } + return COMMON_CHAT_CONTINUATION_NONE; +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -825,6 +862,36 @@ std::string common_chat_template_direct_apply( return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); } +static std::string 
common_chat_template_generation_prompt_impl( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs, + const std::optional & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt) { + + auto adjusted_messages = messages_override ? *messages_override : inputs.messages; + + autoparser::generation_params params = inputs; + params.add_generation_prompt = false; + params.continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context); + params.add_generation_prompt = true; + std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params, adjusted_messages, tools_override, additional_context); + + size_t prefix_len = 0; + size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); + while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { + prefix_len++; + } + return gen_prompt.substr(prefix_len); +} + +std::string common_chat_template_generation_prompt( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + return common_chat_template_generation_prompt_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); +} + static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { common_chat_params data; @@ -877,6 +944,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ data.thinking_start_tag = "[THINK]"; data.thinking_end_tag = "[/THINK]"; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); data.format = 
COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { "[THINK]", @@ -885,8 +953,19 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ "[ARGS]", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = "[THINK]" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "[/THINK]" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, "[THINK]"); + auto generation_prompt = p.eps(); auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); @@ -977,6 +1056,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp } data.prompt = prompt; + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -986,6 +1066,18 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>", }; + // Adjust prompt for continuation + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = "<|start|>assistant<|channel|>analysis<|message|>" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); auto include_grammar = has_response_format || (has_tools && 
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); @@ -1094,12 +1186,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); if (inputs.add_generation_prompt && string_ends_with(data.prompt, "\n")) { // This may happen if the model generates content + tool_call, the // template does not add the model's next turn and confuses the model // from emitting its proper reasoning token sequence. - data.prompt += "<|turn>model\n"; + data.generation_prompt = "<|turn>model\n"; + data.prompt += data.generation_prompt; } data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; @@ -1115,13 +1209,25 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ "<|turn>", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = string_ends_with(data.prompt, "\n") ? 
"<|turn>model\n" : ""; + data.generation_prompt += "<|channel>thought\n" + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += "" + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object(); auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto start = p.rule("start", p.prefix(inputs.generation_prompt, "<|channel>")); + auto start = p.rule("start", p.optional(p.literal("<|turn>model\n"))); if (extract_reasoning) { p.rule("thought", p.literal("<|channel>thought") + p.space() + p.reasoning(p.until("")) + p.literal("")); @@ -1238,15 +1344,22 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { ">>>all", }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + data.generation_prompt = "<|start_header_id|>assistant<|end_header_id|>\n\n>>>all\n" + msg.render_content(); + data.prompt += data.generation_prompt; + } + auto parser = 
build_chat_peg_parser([&](common_chat_peg_builder & p) { // Functionary v3.2 format: // - Normal content: >>>all\n{content} @@ -1258,7 +1371,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // When no tools, content goes until end auto content_until_tool = p.literal("all\n") + p.content(p.until(">>>")); auto content_until_end = p.literal("all\n") + p.content(p.rest()); - auto generation_prompt = p.literal(inputs.generation_prompt); + auto generation_prompt = p.literal("<|start_header_id|>assistant<|end_header_id|>\n\n>>>"); // If no tools or tool_choice is NONE, just parse content if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { @@ -1332,9 +1445,10 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; data.preserved_tokens = { "<|tool_calls_section_begin|>", "<|tool_calls_section_end|>", @@ -1357,10 +1471,22 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_assistant|>assistant<|im_middle|>"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += 
data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { // Kimi K2 Thinking format: // - Reasoning: {reasoning} @@ -1380,7 +1506,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning( p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + p.optional(p.literal(THINK_END))) : p.eps(); - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); // Content only parser (no tools) @@ -1456,6 +1582,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1475,12 +1602,24 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat const std::string TOOL_CALL_END = "<|tool_call_end|>"; const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_start|>assistant\n"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -1535,6 +1674,7 @@ static 
common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1550,12 +1690,24 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ const std::string THINK_START = ""; const std::string THINK_END = ""; + const std::string GEN_PROMPT = "<|im_start|>assistant\n"; data.thinking_start_tag = THINK_START; data.thinking_end_tag = THINK_END; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -1606,6 +1758,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( common_chat_params data; data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = false; data.preserved_tokens = { @@ -1613,6 +1766,12 @@ static common_chat_params common_chat_params_init_gigachat_v3( "<|role_sep|>\n", }; + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + data.generation_prompt = "assistant<|role_sep|>\n" + msg.render_content(); + data.prompt += data.generation_prompt; + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto 
include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; const auto *tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n"; @@ -1648,7 +1807,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( ret = p.content(p.rest()); } - return p.literal(inputs.generation_prompt) + ret; + return p.literal("assistant<|role_sep|>\n") + ret; }); data.parser = parser.save(); @@ -1676,12 +1835,13 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; data.thinking_start_tag = ""; data.thinking_end_tag = ""; - data.preserved_tokens = { + data.preserved_tokens = { "|DSML|", "", "", @@ -1701,9 +1861,21 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha const std::string INVOKE_END = ""; const std::string PARAM_START = "<" + DSML + "parameter"; const std::string PARAM_END = ""; + const std::string GEN_PROMPT = "<|Assistant|>"; + + if (inputs.has_continuation()) { + const auto & msg = inputs.continue_msg; + + data.generation_prompt = GEN_PROMPT + THINK_START + msg.reasoning_content; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) { + data.generation_prompt += THINK_END + msg.render_content(); + } + + data.prompt += data.generation_prompt; + } auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { - auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto generation_prompt = p.literal(GEN_PROMPT); auto end = p.end(); auto reasoning = p.eps(); @@ -2130,21 
+2302,6 @@ std::optional common_chat_try_specialized_template( return std::nullopt; } -static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { - autoparser::generation_params params = inputs; - params.add_generation_prompt = false; - std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - params.add_generation_prompt = true; - std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); - - size_t prefix_len = 0; - size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size()); - while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) { - prefix_len++; - } - return gen_prompt.substr(prefix_len); -} - static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { autoparser::generation_params params; @@ -2163,6 +2320,27 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ params.add_bos = tmpls->add_bos; params.add_eos = tmpls->add_eos; + params.continue_final_message = inputs.continue_final_message; + if (params.continue_final_message != COMMON_CHAT_CONTINUATION_NONE) { + params.add_generation_prompt = false; + + if (!inputs.messages.empty()) { + // Render messages[:-1] and store continuation message separately + params.continue_msg = inputs.messages.back(); + params.messages.erase(params.messages.size() - 1); + } + + if (params.continue_final_message == COMMON_CHAT_CONTINUATION_AUTO && !inputs.messages.empty()) { + // Resolve based on message content + params.continue_final_message = COMMON_CHAT_CONTINUATION_CONTENT; + if (!params.continue_msg.reasoning_content.empty() && + params.continue_msg.content.empty() && + params.continue_msg.content_parts.empty()) { + params.continue_final_message = COMMON_CHAT_CONTINUATION_REASONING; + } + } + } + if (src.find("<|channel|>") == 
std::string::npos) { // map developer to system for all models except for GPT-OSS workaround::map_developer_role_to_system(params.messages); @@ -2183,8 +2361,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ workaround::func_args_not_string(params.messages); } - params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params); - params.extra_context = common_chat_extra_context(); for (auto el : inputs.chat_template_kwargs) { params.extra_context[el.first] = json::parse(el.second); @@ -2214,17 +2390,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto params_copy = params; params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE; data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy); + data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, params); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.generation_prompt = params.generation_prompt; - auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) { - return p.prefix(params.generation_prompt) << p.content(p.rest()); + auto parser = build_chat_peg_parser([&data](common_chat_peg_builder &p) { + return p.literal(data.generation_prompt) << p.content(p.rest()); }); data.parser = parser.save(); return data; } if (auto result = common_chat_try_specialized_template(tmpl, src, params)) { - result->generation_prompt = params.generation_prompt; return *result; } @@ -2238,7 +2413,6 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start); auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end); } - auto_params.generation_prompt = params.generation_prompt; common_peg_arena arena; arena.load(auto_params.parser); LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str()); 
diff --git a/common/chat.h b/common/chat.h index 054f5ffe7..8ace3e6ba 100644 --- a/common/chat.h +++ b/common/chat.h @@ -89,6 +89,8 @@ struct common_chat_msg { nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const; + std::string render_content(const std::string & delimiter = "\n\n") const; + bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); @@ -164,12 +166,22 @@ enum common_chat_format { COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; + +// Continuation method provided via `continue_final_message` +enum common_chat_continuation { + COMMON_CHAT_CONTINUATION_NONE, + COMMON_CHAT_CONTINUATION_AUTO, + COMMON_CHAT_CONTINUATION_REASONING, + COMMON_CHAT_CONTINUATION_CONTENT, +}; + struct common_chat_templates_inputs { std::vector messages; std::string grammar; std::string json_schema; - bool add_generation_prompt = true; - bool use_jinja = true; + bool add_generation_prompt = true; + common_chat_continuation continue_final_message = COMMON_CHAT_CONTINUATION_NONE; + bool use_jinja = true; // Parameters below only supported when use_jinja is true std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; @@ -207,6 +219,7 @@ struct common_chat_parser_params { bool reasoning_in_content = false; std::string generation_prompt; bool parse_tool_calls = true; + bool echo = false; // Include assistant prefilled msg in output bool debug = false; // Enable debug output for PEG parser common_peg_arena parser = {}; common_chat_parser_params() = default; @@ -267,6 +280,8 @@ std::vector common_chat_msgs_parse_oaicompat(const nlohmann::or std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); +common_chat_continuation common_chat_continuation_parse(const nlohmann::ordered_json & value); + // DEPRECATED: only used in tests nlohmann::ordered_json 
common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); @@ -279,6 +294,10 @@ std::string common_chat_template_direct_apply( const common_chat_template & tmpl, const autoparser::generation_params & inputs); +std::string common_chat_template_generation_prompt( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs); + std::optional common_chat_try_specialized_template( const common_chat_template & tmpl, const std::string & src, diff --git a/common/common.cpp b/common/common.cpp index b89e8619a..cf355f66f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -379,7 +379,7 @@ void common_init() { llama_log_set(common_log_default_callback, NULL); } -void common_params_print_info(const common_params & params) { +void common_params_print_info(const common_params & params, bool print_devices) { #ifdef NDEBUG const char * build_type = ""; #else @@ -388,12 +388,16 @@ void common_params_print_info(const common_params & params) { LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type); LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold()); - LOG_INF("device_info:\n"); - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - auto * dev = ggml_backend_dev_get(i); - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + + // device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device + if (print_devices) { + LOG_INF("device_info:\n"); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", 
ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } } LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } diff --git a/common/common.h b/common/common.h index 867d66c30..e0f6b6780 100644 --- a/common/common.h +++ b/common/common.h @@ -618,8 +618,6 @@ struct common_params { // UI configs #ifdef LLAMA_UI_DEFAULT_ENABLED bool ui = LLAMA_UI_DEFAULT_ENABLED != 0; -#elif defined(LLAMA_WEBUI_DEFAULT_ENABLED) - bool ui = LLAMA_WEBUI_DEFAULT_ENABLED != 0; #else bool ui = true; // default to enabled when not set #endif @@ -709,7 +707,7 @@ struct common_params { // initializes the logging system and prints info about the build void common_init(); -void common_params_print_info(const common_params & params); +void common_params_print_info(const common_params & params, bool print_devices = true); std::string common_params_get_system_info(const common_params & params); bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]); diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index 8e3978f7e..02bc482fe 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -471,7 +471,7 @@ void common_ngram_map_draft(common_ngram_map & map, sum_occur += curr_occur; } - LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__, + LOG_DBG("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__, key_offset, max_occur, sum_occur, slot_max, curr_key.values[0].value_idx, curr_key.values[0].value_num, @@ -482,7 +482,7 @@ void common_ngram_map_draft(common_ngram_map & map, // Print the tokens of the four values (if idx != 0), use LOG_INF for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { if (curr_key.values[v].value_idx != 0) { - LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str()); + LOG_DBG("%s: value[%d] = 
%s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str()); } } diff --git a/common/speculative.cpp b/common/speculative.cpp index 23a500141..37b58d8af 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -146,8 +146,11 @@ struct common_speculative_impl { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; - // true if this implementation requires the target context to extract embeddings + // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; + + // true if this implementation requires the target context to extract pre-norm embeddings + virtual bool need_embd_pre_norm() const { return false; } }; struct common_speculative_impl_draft_simple : public common_speculative_impl { @@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } - llama_set_embeddings_pre_norm(ctx_tgt, true); - llama_set_embeddings_pre_norm(ctx_dft, true); + llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false); + llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true); pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); @@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { } bool need_embd() const override { + return false; + } + + bool need_embd_pre_norm() const override { return true; } }; @@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) { return false; } +bool common_speculative_need_embd_pre_norm(common_speculative * spec) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->need_embd_pre_norm()) { + return true; + } + } + + return false; +} + void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 614db9b1b..f24bac79e 100644 
--- a/common/speculative.h +++ b/common/speculative.h @@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co // process the batch and update the internal state of the speculative context bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// true if any implementation requires target embeddings to be extracted +// true if any implementation requires target post-norm embeddings to be extracted bool common_speculative_need_embd(common_speculative * spec); +// true if any implementation requires target pre-norm embeddings to be extracted +bool common_speculative_need_embd_pre_norm(common_speculative * spec); + // generate drafts for the sequences specified with `common_speculative_get_draft_params` void common_speculative_draft(common_speculative * spec); diff --git a/conversion/qwen.py b/conversion/qwen.py index 4b8640426..45d1f98c2 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -600,6 +600,7 @@ class _Qwen35MtpMixin: if name.find("layers.") != -1: assert bid is not None name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}") + bid = bid + n_layer else: remapper = { "mtp.fc": "model.layers.{bid}.eh_proj", diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fb..4c4daf85d 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: 
launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1..db1d39e2d 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f71ef7667..348b52270 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -50,7 +50,6 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include -#include #include #include #include @@ -765,8 +764,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; 
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; @@ -860,6 +859,8 @@ struct vk_device_struct { vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; + vk_pipeline pipeline_ssm_conv_silu_f32; + vk_pipeline pipeline_ssm_conv_bias_silu_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -1358,6 +1359,8 @@ struct vk_op_rope_push_constants { uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t a_offset; + uint32_t d_offset; }; static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); @@ -4574,6 +4577,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4582,6 +4586,7 @@ static void 
ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4906,7 +4911,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 
1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -7566,6 +7573,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_bf16_f32; + } else { + return ctx->device->pipeline_cpy_bf16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_i32; @@ -9964,7 +9978,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SSM_CONV: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_ssm_conv_f32; + switch (ctx->num_additional_fused_ops) { + case 0: return ctx->device->pipeline_ssm_conv_f32; + case 1: return ctx->device->pipeline_ssm_conv_silu_f32; + case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32; + default: return nullptr; + } } return nullptr; case GGML_OP_OPT_STEP_ADAMW: @@ -10145,6 +10164,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor 
* dst) { + p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + template static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; @@ -10905,11 +10933,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, pc, elements); } -static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { + ggml_tensor * conv = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = conv->src[0]; + const ggml_tensor * src1 = conv->src[1]; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + // Pick the destination tensor (last node in the fused chain) and the optional bias. + // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu. + ggml_tensor * dst = conv; + const ggml_tensor * bias = nullptr; + + if (ctx->num_additional_fused_ops == 1) { + dst = cgraph->nodes[node_idx + 1]; // silu + } else if (ctx->num_additional_fused_ops == 2) { + ggml_tensor * add = cgraph->nodes[node_idx + 1]; + bias = (add->src[0] == conv) ? 
add->src[1] : add->src[0]; + dst = cgraph->nodes[node_idx + 2]; // silu + } + + // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused. + const ggml_tensor * src2 = bias ? bias : src0; + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -11272,6 +11317,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * (uint32_t)src0->ne[2], nb01, nb02, nb03, nb11, nb12, nb13, + 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets }; return rope; @@ -11367,6 +11413,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(buf[i] != nullptr); } + // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned. + // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset. + pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type); + offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1); + std::array elements; elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; @@ -13584,7 +13635,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SSM_CONV: - ggml_vk_ssm_conv(ctx, compute_ctx, node); + ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx); break; @@ -14481,6 +14532,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return true; } +// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2. 
+static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, int num_extra) { + const ggml_tensor * conv = cgraph->nodes[node_idx]; + if (conv->op != GGML_OP_SSM_CONV) { + return false; + } + + const ggml_tensor * silu = nullptr; + const ggml_tensor * bias = nullptr; + + if (num_extra == 1) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { + return false; + } + silu = cgraph->nodes[node_idx + 1]; + } else if (num_extra == 2) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { + return false; + } + const ggml_tensor * add = cgraph->nodes[node_idx + 1]; + silu = cgraph->nodes[node_idx + 2]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + // bias must be channel-wise (one element per channel of the conv output) + if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) { + return false; + } + if (add->type != GGML_TYPE_F32) { + return false; + } + // The shader doesn't apply per-tensor offsets, so reject misaligned bias. + if (get_misalign_bytes(ctx, bias) != 0) { + return false; + } + } else { + return false; + } + + if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { + return false; + } + if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + // The shader writes to the fused dst using its own strides, but the push constants don't + // carry a per-tensor offset, so the binding must be naturally aligned. + if (get_misalign_bytes(ctx, silu) != 0) { + return false; + } + return true; +} + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode) { @@ -14897,6 +15004,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // they are overwritten, and one workgroup per row. 
So close enough. op_srcs_fused_elementwise[0] = true; op_srcs_fused_elementwise[1] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) { + ctx->num_additional_fused_ops = 2; + fusion_string = "SSM_CONV_BIAS_SILU"; + // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs. + // The downstream add and silu are elementwise on the conv output. + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) { + ctx->num_additional_fused_ops = 1; + fusion_string = "SSM_CONV_SILU"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -15228,7 +15348,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) { ok = false; break; } @@ -15311,6 +15433,19 @@ static void 
ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } } + // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward + if (j > 0 && + graph->nodes[j]->op == GGML_OP_ADD && + graph->nodes[j-1]->op == GGML_OP_SSM_CONV) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_UNARY && + graph->nodes[k]->src[0] == graph->nodes[j]) { + current_set.push_back(k); + used[k] = true; + break; + } + } + } } } // Second pass grabs view nodes. @@ -15875,6 +16010,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index ca1a3ac25..b3b182fb0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -19,7 +19,9 @@ void main() { if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) @@ -35,7 +37,9 @@ void main() { continue; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 9f8bfd3c1..d55e13253 100644 --- 
a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,9 @@ void main() { return; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index 2e5345990..033587931 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); @@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint 
i1, const uint i2, const uint i3, rop return; } - const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 2e2a7e14c..3602485b9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -26,6 +26,9 @@ struct rope_params { uint nb11; uint nb12; uint nb13; + + uint a_offset; + uint d_offset; }; #endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index 6802b1fc9..4cd9b8da3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -6,12 +6,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint TOKENS_PER_WG = 16; +layout(constant_id = 2) const bool APPLY_BIAS = false; +layout(constant_id = 3) const bool APPLY_SILU = false; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; -layout(binding = 2) buffer Dst { float dst[]; }; +layout(binding = 2) readonly buffer Bias { float bias[]; }; +layout(binding = 3) buffer Dst { float dst[]; }; layout(push_constant) uniform PushConstants { uint nb01; uint nb02; @@ -45,6 +48,13 @@ void main() { } } + if (APPLY_BIAS) { + sum += bias[i1]; + } + if (APPLY_SILU) { + sum = sum / (1.0f + exp(-sum)); + } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp 
b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8210e221f..cf610730a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -748,6 +748,7 @@ void process_shaders() { string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -755,6 +756,7 @@ void process_shaders() { string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); diff --git a/koboldcpp.py b/koboldcpp.py index 3ff430a35..a2d021041 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -74,7 +74,7 @@ dry_seq_break_max = 128 extra_images_max = 4 # 
for kontext/qwen img # global vars -KcppVersion = "1.113.1" +KcppVersion = "1.114" showdebug = True kcpp_instance = None #global running instance global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_base_config":"", "last_active_timestamp":datetime.now(), "triggered_sleeping":False, "current_model":"initial_model", "base_config":"", "swapReqType": None, "autoswapmode": False} diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 32ed65b49..14834781e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -67,8 +67,9 @@ llama_context::llama_context( cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; - cparams.embeddings_pre_norm = false; + cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; + cparams.embeddings_pre_norm_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -905,8 +906,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { throw std::runtime_error("no pre-norm embeddings"); } - const int64_t j = output_resolve_row(i); const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. 
+ if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); return embd_pre_norm.data + j*n_embd; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); @@ -1098,10 +1108,11 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } -void llama_context::set_embeddings_pre_norm(bool value) { - LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); - cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; } void llama_context::set_causal_attn(bool value) { @@ -1747,6 +1758,7 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; + int64_t n_tokens_prev = 0; do { const auto & ubatch = mctx->get_ubatch(); @@ -1892,16 +1904,21 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract pre-norm embeddings (hidden state before the final output norm) // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. - if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); - GGML_ASSERT(backend_h != nullptr); + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? 
n_outputs_prev : n_tokens_prev; - const uint32_t n_embd = hparams.n_embd; - float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd; + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float)); + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } } // Copy backend sampling output if this ubatch produced any sampling tensors. @@ -1918,6 +1935,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -2009,6 +2027,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd.size = has_embd ? n_embd_out*n_outputs_max : 0; embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } + // Allocate backend sampling output buffers if there are backend samplers configured. 
const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { @@ -3557,8 +3581,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) { - ctx->set_embeddings_pre_norm(value); +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); } float * llama_get_embeddings_pre_norm(llama_context * ctx) { diff --git a/src/llama-context.h b/src/llama-context.h index e16ac4c61..d03f681d4 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -110,7 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); - void set_embeddings_pre_norm(bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 5898a1c38..20ec59fe3 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -28,7 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; - bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index 11f198667..edfa71c20 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // pre-norm embeddings (hidden state before the final output norm) // -// mirrors: -// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); -LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, 
bool value); +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); -// mirrors: // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e94507a13..a9a5a8e0c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 2b4d5b14c..361d7538a 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "h_pre_norm", -1); res->t_h_pre_norm = cur; + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = 
ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 22e3e1107..4f63c410d 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "h_pre_norm", -1); res->t_h_pre_norm = cur; + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 73de0d3bb..dc00edfa8 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1032,23 +1032,33 @@ json oaicompat_chat_params_parse( auto caps = common_chat_templates_get_caps(opt.tmpls.get()); common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); - inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); - inputs.grammar = grammar; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - const bool continue_final_message = json_value(body, "continue_final_message", false); - if (continue_final_message && inputs.add_generation_prompt) { + inputs.messages = common_chat_msgs_parse_oaicompat(messages); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; + inputs.use_jinja = opt.use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]); + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); + inputs.continue_final_message = body.contains("continue_final_message") ? 
+ common_chat_continuation_parse(body.at("continue_final_message")) : + COMMON_CHAT_CONTINUATION_NONE; + if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_NONE && opt.prefill_assistant + && !inputs.messages.empty() && inputs.messages.back().role == "assistant") { + if (inputs.messages.size() >= 2 && inputs.messages[inputs.messages.size() - 2].role == "assistant") { + throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); + } + inputs.continue_final_message = COMMON_CHAT_CONTINUATION_AUTO; + inputs.add_generation_prompt = false; + } + if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && inputs.add_generation_prompt) { throw std::invalid_argument("Cannot set both add_generation_prompt and continue_final_message to true."); } - inputs.reasoning_format = opt.reasoning_format; + inputs.reasoning_format = opt.reasoning_format; if (body.contains("reasoning_format")) { inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get()); } - inputs.enable_thinking = opt.enable_thinking; + inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { throw std::invalid_argument("Cannot use custom grammar constraints with tools."); @@ -1073,84 +1083,11 @@ json oaicompat_chat_params_parse( throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); } - // if the assistant message appears at the end of list, we do not add end-of-turn token - // for ex. 
this can be useful to modify the reasoning process in reasoning models - // continue_final_message is the explicit opt in alias from the vLLM/transformers API, - // equivalent to the prefill_assistant heuristic - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" - && (continue_final_message || opt.prefill_assistant); - common_chat_msg last_message; - if (prefill_assistant_message) { - last_message = inputs.messages.back(); - inputs.messages.pop_back(); - - /* sanity check, max one assistant message at the end of the list */ - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); - } - - // reject reasoning prefill on channel based templates that do not expose explicit thinking tags - if (!last_message.reasoning_content.empty() && inputs.enable_thinking) { - auto probe_params = common_chat_templates_apply(opt.tmpls.get(), inputs); - if (probe_params.supports_thinking && probe_params.thinking_end_tag.empty()) { - throw std::invalid_argument("Assistant prefill with reasoning_content is not supported yet for this template."); - } - } - - inputs.add_generation_prompt = true; - } inputs.force_pure_content = opt.force_pure_content; // Apply chat template to the list of messages auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs); - /* Append assistant prefilled message */ - if (prefill_assistant_message) { - const bool thinking_active = chat_params.supports_thinking && !chat_params.thinking_end_tag.empty(); - const bool has_reasoning = !last_message.reasoning_content.empty(); - const bool has_content = !last_message.content.empty() || !last_message.content_parts.empty(); - const bool mid_reasoning = has_reasoning && !has_content; - - // some templates inject thinking_start in generation_prompt, others let the model emit it - const bool gp_has_think = thinking_active - && 
chat_params.generation_prompt.find(chat_params.thinking_start_tag) != std::string::npos; - - // open the thinking block when reasoning is present and the template did not inject it - if (has_reasoning) { - if (thinking_active && !gp_has_think) { - chat_params.prompt += chat_params.thinking_start_tag; - } - chat_params.prompt += last_message.reasoning_content; - } - - if (thinking_active) { - if (mid_reasoning) { - // model continues inside the thinking block, keep generation_prompt open on think - if (!gp_has_think) { - chat_params.generation_prompt += chat_params.thinking_start_tag; - } - } else { - // close thinking block when reasoning is followed by content, or when the template forced it open - if (has_reasoning || gp_has_think) { - chat_params.prompt += chat_params.thinking_end_tag; - } - // strip thinking_start from generation_prompt so the parser routes model output as content - auto pos = chat_params.generation_prompt.rfind(chat_params.thinking_start_tag); - if (pos != std::string::npos) { - chat_params.generation_prompt = chat_params.generation_prompt.substr(0, pos); - } - } - } - - if (!last_message.content_parts.empty()) { - for (auto & p : last_message.content_parts) { - chat_params.prompt += p.text; - } - } else { - chat_params.prompt += last_message.content; - } - } - llama_params["chat_format"] = static_cast(chat_params.format); llama_params["prompt"] = chat_params.prompt; if (!chat_params.grammar.empty()) { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 4d162f81d..0f3fb9efa 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -243,6 +243,11 @@ struct server_slot { return task->need_embd() || (spec && common_speculative_need_embd(spec)); } + bool need_embd_pre_norm() const { + GGML_ASSERT(task); + return spec && common_speculative_need_embd_pre_norm(spec); + } + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also 
we cannot split if the pooling would require any past tokens // (MTP supports splitting — uses task->need_embd() not need_embd()) @@ -4527,7 +4532,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons } } - int embd_normalize = 2; // default to Euclidean/L2 norm + int embd_normalize = params.embd_normalize; if (body.count("embd_normalize") != 0) { embd_normalize = body.at("embd_normalize"); if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) { diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 39a21f4ec..9d008fc94 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -231,11 +231,10 @@ bool server_http_context::init(const common_params & params) { }; auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { - (void)req; // suppress unused parameter warning when LLAMA_BUILD_UI / LLAMA_BUILD_WEBUI is not defined + (void)req; // suppress unused parameter warning when LLAMA_BUILD_UI is not defined bool ready = is_ready.load(); if (!ready) { -// Support both old and new preprocessor defines -#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI) +#if defined(LLAMA_BUILD_UI) auto tmp = string_split(req.path, '.'); if (req.path == "/" || (tmp.size() > 0 && tmp.back() == "html")) { res.status = 503; @@ -313,8 +312,7 @@ bool server_http_context::init(const common_params & params) { return 1; } } else { -// Support both old and new preprocessor defines -#if defined(LLAMA_BUILD_UI) || defined(LLAMA_BUILD_WEBUI) +#if defined(LLAMA_BUILD_UI) // using embedded static index.html srv->Get(params.api_prefix + "/", [](const httplib::Request & /*req*/, httplib::Response & res) { // COEP and COOP headers, required by pyodide (python interpreter) diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 433d2d8f0..6c6fed52d 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -798,9 +798,10 @@ void server_models::load(const 
std::string & name) { std::thread log_thread([&]() { // read stdout/stderr and forward to main server log // also handle status report from child process + std::vector vec_buf(128 * 1024); // large buffer for storing info + char * buffer = vec_buf.data(); if (stdout_file) { - char buffer[128 * 1024]; // large buffer for storing info - while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) { + while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) { LOG("[%5d] %s", port, buffer); std::string str(buffer); if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) { diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index cbc40a35f..d45513dbe 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -144,6 +144,17 @@ json task_params::to_json(bool only_metrics) const { // // task_result_state // +task_result_state::task_result_state(const common_chat_parser_params & chat_parser_params) + : chat_parser_params(chat_parser_params) + , oai_resp_id("resp_" + random_string()) + , oai_resp_reasoning_id("rs_" + random_string()) + , oai_resp_message_id("msg_" + random_string()) { + if (!chat_parser_params.echo) { + // initialize chat_msg to avoid emitting a delta containing the assistant prefill + chat_msg = common_chat_parse("", true, chat_parser_params); + } +} + common_chat_msg task_result_state::update_chat_msg( const std::string & text_added, bool is_partial, @@ -421,6 +432,7 @@ task_params server_task::params_from_json_cmpl( if (data.contains("chat_parser")) { params.chat_parser_params.parser.load(data.at("chat_parser").get()); } + params.chat_parser_params.echo = json_value(data, "echo", false); } { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 64bdecd79..0978bb6ff 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -112,11 +112,7 @@ struct task_result_state { const std::string oai_resp_message_id; std::string oai_resp_fc_id; // function call ID for current 
args delta - task_result_state(const common_chat_parser_params & chat_parser_params) - : chat_parser_params(chat_parser_params) - , oai_resp_id("resp_" + random_string()) - , oai_resp_reasoning_id("rs_" + random_string()) - , oai_resp_message_id("msg_" + random_string()) {} + task_result_state(const common_chat_parser_params & chat_parser_params); // parse partial tool calls and update the internal state common_chat_msg update_chat_msg( diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a23255078..c82f11794 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -86,7 +86,10 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - common_params_print_info(params); + // router server never loads a model and must not touch the GPU + // skip device enumeration so the CUDA primary context stays uncreated + const bool is_router_server = params.model.path.empty(); + common_params_print_info(params, !is_router_server); // validate batch size for embeddings // embeddings require all tokens to be processed in a single ubatch @@ -126,7 +129,6 @@ int main(int argc, char ** argv) { server_routes routes(params, ctx_server); server_tools tools; - bool is_router_server = params.model.path.empty(); std::optional models_routes{}; if (is_router_server) { // setup server instances manager diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 243e41605..f80e46133 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -158,11 +158,12 @@ def test_chat_template(): @pytest.mark.parametrize("prefill,re_prefill", [ ("Whill", "Whill"), - ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"), + ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Wh\n\nill"), ]) def test_chat_template_assistant_prefill(prefill, re_prefill): global server - server.chat_template = 
"llama3" + server.jinja = True + server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" server.debug = True # to get the "__verbose" object in the response server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -175,14 +176,15 @@ def test_chat_template_assistant_prefill(prefill, re_prefill): }) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == f" <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}" + assert res.body["__verbose"]["prompt"].endswith(f"<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}") def test_chat_template_continue_final_message_vllm_compat(): """continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic. 
Both must produce the same prompt.""" global server - server.chat_template = "llama3" + server.jinja = True + server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" server.debug = True server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -197,7 +199,7 @@ def test_chat_template_continue_final_message_vllm_compat(): }) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill" + assert res.body["__verbose"]["prompt"].endswith("<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill") def test_chat_template_continue_final_message_mutual_exclusion(): diff --git a/tools/ui/CMakeLists.txt b/tools/ui/CMakeLists.txt index 9687ca92e..383940cb6 100644 --- a/tools/ui/CMakeLists.txt +++ b/tools/ui/CMakeLists.txt @@ -14,12 +14,7 @@ endif() set(TARGET_SRCS "") set(UI_COMPILE_DEFS "") -# Support both old (LLAMA_BUILD_WEBUI) and new (LLAMA_BUILD_UI) option names -if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI) - if(LLAMA_BUILD_WEBUI AND NOT LLAMA_BUILD_UI) - message(DEPRECATION "LLAMA_BUILD_WEBUI is deprecated, use LLAMA_BUILD_UI instead") - endif() - +if(LLAMA_BUILD_UI) set(PUBLIC_ASSETS index.html bundle.js @@ -125,19 +120,17 @@ if(LLAMA_BUILD_WEBUI OR LLAMA_BUILD_UI) endforeach() list(APPEND UI_COMPILE_DEFS - LLAMA_BUILD_WEBUI # Deprecated: use LLAMA_BUILD_UI LLAMA_BUILD_UI - LLAMA_WEBUI_DEFAULT_ENABLED=1 # Deprecated: use LLAMA_UI_DEFAULT_ENABLED LLAMA_UI_DEFAULT_ENABLED=1 ) message(STATUS "UI: embedded with source: ${UI_SOURCE}") else() message(WARNING "UI: no source available. 
Neither local build (build/tools/ui/dist/) nor HF Bucket download succeeded.") message(WARNING "UI: building server without embedded UI. Set LLAMA_BUILD_UI=OFF to suppress this warning.") - list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0) + list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0) endif() else() - list(APPEND UI_COMPILE_DEFS LLAMA_WEBUI_DEFAULT_ENABLED=0 LLAMA_UI_DEFAULT_ENABLED=0) + list(APPEND UI_COMPILE_DEFS LLAMA_UI_DEFAULT_ENABLED=0) endif() # Build the static library diff --git a/tools/ui/src/app.css b/tools/ui/src/app.css index 6e29b70a3..d6dc6670c 100644 --- a/tools/ui/src/app.css +++ b/tools/ui/src/app.css @@ -1,4 +1,5 @@ @import 'tailwindcss'; +@source "."; @import 'tw-animate-css'; @@ -39,6 +40,9 @@ --sidebar-ring: oklch(0.708 0 0); --code-background: oklch(0.985 0 0); --code-foreground: oklch(0.145 0 0); + --font-mono: + ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, + 'Liberation Mono', Menlo, monospace; --layer-popover: 1000000; --chat-form-area-height: 8rem; @@ -171,6 +175,10 @@ *::-webkit-scrollbar-thumb:hover { background: hsl(var(--muted-foreground) / 0.5); } + + :where(code, pre, kbd, samp) { + font-family: var(--font-mono); + } } @layer utilities { diff --git a/tools/ui/src/app.d.ts b/tools/ui/src/app.d.ts index f5af7323c..ec65952e9 100644 --- a/tools/ui/src/app.d.ts +++ b/tools/ui/src/app.d.ts @@ -39,6 +39,7 @@ import type { DatabaseMessage, DatabaseMessageExtra, DatabaseMessageExtraAudioFile, + DatabaseMessageExtraVideoFile, DatabaseMessageExtraImageFile, DatabaseMessageExtraTextFile, DatabaseMessageExtraPdfFile, @@ -102,6 +103,7 @@ declare global { DatabaseMessage, DatabaseMessageExtra, DatabaseMessageExtraAudioFile, + DatabaseMessageExtraVideoFile, DatabaseMessageExtraImageFile, DatabaseMessageExtraTextFile, DatabaseMessageExtraPdfFile, diff --git a/tools/ui/src/lib/components/app/badges/BadgesModality.svelte 
b/tools/ui/src/lib/components/app/badges/BadgesModality.svelte index 841f1dd9f..d87184ea9 100644 --- a/tools/ui/src/lib/components/app/badges/BadgesModality.svelte +++ b/tools/ui/src/lib/components/app/badges/BadgesModality.svelte @@ -1,5 +1,5 @@ {#each modalities as modality (modality)} - {#if modality === ModelModality.VISION || modality === ModelModality.AUDIO} + {#if modality === ModelModality.VISION || modality === ModelModality.AUDIO || modality === ModelModality.VIDEO} - Vision + Vision (Image) + {:else if modality === ModelModality.VIDEO} +