From e2ef8fe42ccef597bfeab901dd6e39589613b71e Mon Sep 17 00:00:00 2001 From: jacekpoplawski <67507230+jacekpoplawski@users.noreply.github.com> Date: Mon, 25 May 2026 07:56:18 +0200 Subject: [PATCH] server: fix checkpoints creation (#22929) * common : add common_chat_split_by_role * cont : fix spans to reach end of message * server: fix checkpoints creation - extract message_spans from chat templates - find the prompt token position before the latest user message - split prompt batching at that position - create a context checkpoint before the latest user input - avoid periodic mid-prompt checkpoints when that position is known - handle multimodal prompts when mapping text/template positions to server prompt tokens - add --checkpoint-min-step to control minimum spacing between checkpoints * cont : clean-up * Support autoparser detection for message barriers * server: fix message span delimiter and update docs --------- Co-authored-by: Alde Rojas Co-authored-by: Georgi Gerganov Co-authored-by: Piotr Wilkin --- common/arg.cpp | 11 +- common/chat-auto-parser-helpers.cpp | 6 +- common/chat-auto-parser.h | 6 + common/chat-diff-analyzer.cpp | 177 +++++++++++++++++++++++++++- common/chat.cpp | 65 ++++++++++ common/chat.h | 16 +++ common/common.cpp | 21 ++++ common/common.h | 3 +- tests/test-chat-auto-parser.cpp | 127 +++++++++++++++++++- tests/test-chat.cpp | 40 ++++++- tools/cli/README.md | 1 - tools/server/README.md | 2 +- tools/server/server-common.cpp | 10 ++ tools/server/server-context.cpp | 135 +++++++++++++++++---- tools/server/server-task.h | 3 + 15 files changed, 586 insertions(+), 37 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 24d9734b9..3df8010a2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1334,12 +1334,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( - {"-cpent", 
"--checkpoint-every-n-tokens"}, "N", - string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt), + {"-cms", "--checkpoint-min-step"}, "N", + string_format("minimum spacing between context checkpoints in tokens (default: %d, 0 = no minimum)", params.checkpoint_min_step), [](common_params & params, int value) { - params.checkpoint_every_nt = value; + if (value < 0) { + throw std::invalid_argument("checkpoint-min-step must be non-negative"); + } + params.checkpoint_min_step = value; } - ).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + ).set_env("LLAMA_ARG_CHECKPOINT_MIN_SPACING_NT").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-cram", "--cache-ram"}, "N", string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)" diff --git a/common/chat-auto-parser-helpers.cpp b/common/chat-auto-parser-helpers.cpp index 2499464cd..81b17e5e1 100644 --- a/common/chat-auto-parser-helpers.cpp +++ b/common/chat-auto-parser-helpers.cpp @@ -310,6 +310,8 @@ std::vector prune_whitespace_segments(const std::vector & segm namespace autoparser { +static const std::string ERR_TMPL = "#**ERROR**#"; + std::string apply_template(const common_chat_template & tmpl, const template_params & params) { generation_params tmpl_params; tmpl_params.messages = params.messages; @@ -326,7 +328,7 @@ std::string apply_template(const common_chat_template & tmpl, const template_par return common_chat_template_direct_apply(tmpl, tmpl_params); } catch (const std::exception & e) { LOG_DBG("Template application failed: %s\n", e.what()); - return ""; + return ERR_TMPL; } } @@ -347,7 +349,7 @@ std::optional compare_variants( std::string output_B = apply_template(tmpl, params_B); // Check for template application failures - if (output_A.empty() || output_B.empty()) { + if (output_A == ERR_TMPL || output_B == ERR_TMPL) { return std::nullopt; } 
diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index c680e6868..7858f6572 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -377,6 +377,8 @@ struct analyze_tools : analyze_base { struct autoparser { jinja::caps jinja_caps; + std::string user_start; + std::string assistant_start; analyze_reasoning reasoning; analyze_content content; analyze_tools tools; @@ -387,6 +389,10 @@ struct autoparser { autoparser() = default; + // Find the starting marker for the user message and assistant message + std::string detect_user_start_marker(const common_chat_template & tmpl); + std::string detect_assistant_start_marker(const common_chat_template & tmpl); + // Run full differential analysis on a template void analyze_template(const common_chat_template & tmpl); diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index 9c7c9678a..0875c5347 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -8,6 +8,9 @@ #include "peg-parser.h" #include +#include +#include +#include #define ANSI_RESET "\033[0m" #define ANSI_PURPLE "\033[1m\x1b[38;5;126m" @@ -23,6 +26,7 @@ static const std::string FUN_SECOND = "SSS_SECOND_FUN_S"; static const std::string ARG_FIRST = "AA_ARG_FST_AA"; static const std::string ARG_SECOND = "BB_ARG_SND_BB"; static const std::string USER_MSG = "U_USER_MSG Hello END_U"; +static const std::string USER_MSG_TWO = "V_USER_MSG Hello END_V"; static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A"; static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R"; static const std::string CALL_ID_001 = "call00001"; @@ -71,6 +75,7 @@ static std::vector"); analysis.preserved_tokens.push_back("<|END_OF_TURN_TOKEN|>"); + analysis.user_start = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>"; LOG_DBG(ANSI_ORANGE "[Patch: Cohere Command R+]\n" ANSI_RESET); } }, @@ -108,7 +113,59 @@ static std::vector void { + if (tmpl.src.find("") != std::string::npos && tmpl.src.find("") 
!= std::string::npos && + tmpl.src.find("") != std::string::npos && tmpl.src.find("") != std::string::npos) { + + analysis.tools.format.mode = tool_format::JSON_NATIVE; + analysis.tools.format.section_start = ""; + analysis.tools.format.section_end = ""; + analysis.tools.format.per_call_start = ""; + analysis.tools.format.per_call_end = ""; + analysis.content.mode = content_mode::PLAIN; + analysis.content.start = ""; + analysis.content.end = ""; + analysis.reasoning.mode = reasoning_mode::TAG_BASED; + analysis.reasoning.start = "\n\n"; + analysis.reasoning.end = ""; + analysis.assistant_start = "Assistant"; + analysis.user_start = "User"; + analysis.preserved_tokens.clear(); + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: Nemotron Nano v2]\n" ANSI_RESET); + } + }, + // Fireworks + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("{%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\\n\\n'" + " + message['content'] | trim + '\\n' + system_prompt_suffix + '<|eot_id|>' -%}") != std::string::npos) { + analysis.assistant_start = "<|start_header_id|>assistant<|end_header_id|>"; + analysis.user_start = "<|start_header_id|>user<|end_header_id|>"; + LOG_DBG(ANSI_ORANGE "[Patch: Fireworks v2]\n" ANSI_RESET); + } + }, + // Solar Open + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("<|begin|>assistant<|think|><|end|>") != std::string::npos) { + analysis.assistant_start = "<|begin|>assistant"; + LOG_DBG(ANSI_ORANGE "[Patch: Solar Open]\n" ANSI_RESET); + } + }, + // Apriel 1.6 + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("if not loop.last and '[BEGIN FINAL RESPONSE]' in asst_text") != std::string::npos) { + analysis.user_start 
= "<|begin_user|>"; + analysis.assistant_start = "<|begin_assistant|>"; + LOG_DBG(ANSI_ORANGE "[Patch: Apriel 1.6]\n" ANSI_RESET); + } + }, + }); // Common JSON structures @@ -166,6 +223,8 @@ void autoparser::analyze_template(const common_chat_template & tmpl) { reasoning = analyze_reasoning(tmpl, jinja_caps.supports_tool_calls); content = analyze_content(tmpl, reasoning); tools = analyze_tools(jinja_caps.supports_tool_calls ? analyze_tools(tmpl, jinja_caps, reasoning) : analyze_tools()); + assistant_start = detect_assistant_start_marker(tmpl); + user_start = detect_user_start_marker(tmpl); collect_preserved_tokens(); for (auto & workaround : workarounds) { @@ -173,6 +232,8 @@ void autoparser::analyze_template(const common_chat_template & tmpl) { } LOG_DBG("\n--- Reasoning & Content Structure ---\n"); + LOG_DBG("user_msg_start: %s\n", user_start.c_str()); + LOG_DBG("assistant_msg_start: %s\n", assistant_start.c_str()); LOG_DBG("reasoning_mode: %s\n", mode_to_str(reasoning.mode).c_str()); LOG_DBG("reasoning_start: '%s'\n", reasoning.start.c_str()); LOG_DBG("reasoning_end: '%s'\n", reasoning.end.c_str()); @@ -245,6 +306,120 @@ void autoparser::collect_preserved_tokens() { add_token(tools.call_id.suffix); } +std::string autoparser::detect_assistant_start_marker(const common_chat_template & tmpl) { + json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } + }; + + json assistant_no_reasoning = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + template_params params; + params.messages = json::array({ user_msg }); + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + tmpl, params, [&](template_params & p) { + p.messages = json::array({ user_msg, assistant_no_reasoning }); + } + ); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed, skipping assistant start detection\n" ANSI_RESET, __func__); + return ""; + } + + auto usermsg = 
comparison->diff.right; + if (usermsg.find(ASSISTANT_MSG) == std::string::npos) { + LOG_DBG(ANSI_ORANGE "%s: Did not find assistant message in assistant message block, skipping detection\n" ANSI_RESET, __func__); + } + + auto ast_prefix = usermsg.substr(0, usermsg.find(ASSISTANT_MSG)); + if (!reasoning.start.empty() && ast_prefix.find(trim_whitespace(reasoning.start)) != std::string::npos) { + ast_prefix = ast_prefix.substr(0, ast_prefix.find(trim_whitespace(reasoning.start))); + } + if (!reasoning.end.empty() && ast_prefix.find(trim_whitespace(reasoning.end)) != std::string::npos) { + ast_prefix = ast_prefix.substr(0, ast_prefix.find(trim_whitespace(reasoning.end))); + } + return trim_whitespace(ast_prefix); +} + +std::string autoparser::detect_user_start_marker(const common_chat_template & tmpl) { + json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } + }; + + json assistant = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json user_msg_two = json{ + { "role", "user" }, + { "content", USER_MSG_TWO } + }; + + template_params params; + params.messages = json::array({}); + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + tmpl, params, [&](template_params & p) { + p.messages = json::array({ user_msg }); + } + ); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed, unsupported empty messages? 
trying complex variant\n" ANSI_RESET, __func__); + params.messages = json::array({ user_msg_two, assistant }); + comparison = compare_variants( + tmpl, params, [&](template_params & p) { + p.messages = json::array({ user_msg_two, assistant, user_msg }); + } + ); + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed for reserve variant, aborting\n" ANSI_RESET, __func__); + return ""; + } + } + + auto usermsg = comparison->diff.right; + if (usermsg.find(USER_MSG) == std::string::npos) { + LOG_DBG(ANSI_ORANGE "%s: Did not find user message in user message block, aborting detection\n" ANSI_RESET, __func__); + } + + if (usermsg.find(ASSISTANT_MSG) != std::string::npos) { + usermsg = usermsg.substr(usermsg.find(ASSISTANT_MSG) + ASSISTANT_MSG.size()); + } + + auto candidate = usermsg.substr(0, usermsg.find(USER_MSG)); + auto candidate_split = segmentize_markers(candidate); + std::stringstream result; + bool encountered_marker = false; + for (const auto & mrk : candidate_split) { + std::string lower_mrk = std::string(mrk.value); + std::transform(lower_mrk.begin(), lower_mrk.end(), lower_mrk.begin(), + [](unsigned char c) { return std::tolower(c); }); + // heuristic to weed out potential end markers, but only at the start + if (mrk.type == segment_type::MARKER && !encountered_marker && + (lower_mrk.find("end") != std::string::npos || lower_mrk.find("close") != std::string::npos)) { + continue; + } + if (mrk.type == segment_type::TEXT && !encountered_marker && trim_whitespace(mrk.value).empty()) { + continue; + } + encountered_marker |= mrk.type == segment_type::MARKER; + result << mrk.value; + } + return trim_whitespace(result.str()); +} + analyze_reasoning::analyze_reasoning(const common_chat_template & tmpl, bool supports_tools) : analyze_base(tmpl) { LOG_DBG(ANSI_PURPLE "=== Starting differential analysis ===\n" ANSI_RESET); diff --git a/common/chat.cpp b/common/chat.cpp index 56873e3a1..ef151691c 100644 --- a/common/chat.cpp +++ b/common/chat.cpp 
@@ -90,6 +90,45 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const return text; } +std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims) { + if (delims.empty() || prompt.empty()) { + return {}; + } + + auto parser = build_peg_parser([&](common_peg_parser_builder & p) { + std::vector all_delims; + std::vector tagged_messages; + + all_delims.reserve(delims.size()); + tagged_messages.reserve(delims.size()); + for (const auto & d : delims) { + all_delims.push_back(d.delimiter); + } + + auto any_delim = p.until_one_of(all_delims); + for (const auto & d : delims) { + tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim)); + } + + return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end(); + }); + + common_peg_parse_context ctx(prompt); + const auto result = parser.parse(ctx); + if (!result.success()) { + return {}; + } + + std::vector spans; + ctx.ast.visit(result, [&](const common_peg_ast_node & node) { + if (!node.tag.empty()) { + spans.push_back({ node.tag, node.start, node.end - node.start }); + } + }); + + return spans; +} + json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { if (!content.empty() && !content_parts.empty()) { throw std::runtime_error("Cannot specify both content and content_parts"); @@ -1042,6 +1081,14 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.prompt = prompt; data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); + data.message_spans = common_chat_split_by_role(prompt, { + { "assistant", "<|start|>assistant" }, + { "user", "<|start|>user" }, + { "system", "<|start|>developer" }, + { "system", "<|start|>system" }, + { "tool", "<|start|>functions" }, + }); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -1181,6 +1228,11 @@ static common_chat_params 
common_chat_params_init_gemma4(const common_chat_templ data.prompt += data.generation_prompt; } + data.message_spans = common_chat_split_by_role(data.prompt, { + { "user", "<|turn>user\n" }, + { "assistant", "<|turn>model\n" }, + }); + data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; data.thinking_start_tag = "<|channel>thought"; @@ -2393,6 +2445,19 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ struct autoparser::autoparser autoparser; autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); + + std::vector delimiters; + if (!autoparser.assistant_start.empty()) { + delimiters.push_back({ "assistant", autoparser.assistant_start }); + } + if (!autoparser.user_start.empty()) { + delimiters.push_back({ "user", autoparser.user_start }); + } + + if (!delimiters.empty()) { + auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters); + } + auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; if (auto_params.supports_thinking) { auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start); diff --git a/common/chat.h b/common/chat.h index b29c627e6..5659cd42a 100644 --- a/common/chat.h +++ b/common/chat.h @@ -143,6 +143,17 @@ struct common_chat_msg_diff { } }; +struct common_chat_msg_span { + std::string role; + std::size_t pos = 0; + std::size_t len = 0; +}; + +struct common_chat_msg_delimiter { + std::string role; + std::string delimiter; +}; + struct common_chat_tool { std::string name; std::string description; @@ -208,6 +219,7 @@ struct common_chat_params { std::vector preserved_tokens; std::vector additional_stops; std::string parser; + std::vector message_spans; }; // per-message parsing syntax @@ -304,6 +316,7 @@ std::optional common_chat_try_specialized_template( const std::string & src, autoparser::generation_params & params); + // specialized per-task 
preset struct common_chat_prompt_preset { std::string system; @@ -311,3 +324,6 @@ struct common_chat_prompt_preset { }; common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); + +std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims); + diff --git a/common/common.cpp b/common/common.cpp index d77ddeda1..97daf2817 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -445,6 +445,27 @@ std::string string_strip(const std::string & str) { return str.substr(start, end - start); } +std::string string_lcs(std::string_view a, std::string_view b) { + if (a.empty() || b.empty()) return {}; + + std::vector> dp(a.size() + 1, std::vector(b.size() + 1, 0)); + size_t best_len = 0; + size_t best_end_a = 0; + + for (size_t i = 1; i <= a.size(); ++i) { + for (size_t j = 1; j <= b.size(); ++j) { + if (a[i - 1] == b[j - 1]) { + dp[i][j] = dp[i - 1][j - 1] + 1; + if (dp[i][j] > best_len) { + best_len = dp[i][j]; + best_end_a = i; + } + } + } + } + return std::string(a.substr(best_end_a - best_len, best_len)); +} + std::string string_get_sortable_timestamp() { using clock = std::chrono::system_clock; diff --git a/common/common.h b/common/common.h index b0ad7b2ea..8a0e5eed5 100644 --- a/common/common.h +++ b/common/common.h @@ -594,7 +594,7 @@ struct common_params { bool cache_prompt = true; // whether to enable prompt caching bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot - int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill + int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
std::string hostname = "127.0.0.1"; @@ -731,6 +731,7 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_lcs(std::string_view a, std::string_view b); std::string string_join(const std::vector & values, const std::string & separator); std::vector string_split(const std::string & str, const std::string & delimiter); diff --git a/tests/test-chat-auto-parser.cpp b/tests/test-chat-auto-parser.cpp index 1d96de718..6f8e95748 100644 --- a/tests/test-chat-auto-parser.cpp +++ b/tests/test-chat-auto-parser.cpp @@ -81,6 +81,8 @@ static void test_normalize_quotes_with_embedded_quotes(testing & t); // TAG_WITH_TAGGED argument parsing tests static void test_tagged_args_with_embedded_quotes(testing & t); +static void test_role_markers_all_templates(testing & t); + int main(int argc, char * argv[]) { testing t(std::cout); t.verbose = true; @@ -103,6 +105,7 @@ int main(int argc, char * argv[]) { t.test("standard_json_tools", test_standard_json_tools_formats); t.test("normalize_quotes_to_json", test_normalize_quotes_to_json); t.test("tagged_args_embedded_quotes", test_tagged_args_with_embedded_quotes); + t.test("role_markers_all_templates", test_role_markers_all_templates); return t.summary(); } @@ -714,7 +717,7 @@ static void test_compare_variants_both_modifiers(testing & t) { static void test_compare_variants_template_failure(testing & t) { // Test with template that causes failure during application (not construction) // We use a valid template syntax but one that will fail during application - common_chat_template tmpl("{{ messages[0]['nonexistent_field'] }}", "", ""); + common_chat_template tmpl("{{ messages.cahoot()[0]['nonexistent_field'] }}", "", ""); template_params params; params.messages = json::array({ @@ -1848,6 +1851,128 @@ static json build_edit_tool() { }); } +// ============================================================================ +// Role 
marker detection tests for all autoparser-handled templates +// +// Verifies that detect_user_start_marker / detect_assistant_start_marker +// return the correct boundary text between turns for every template that +// falls through to the differential autoparser (i.e. is not handled by a +// dedicated specialized template in common_chat_try_specialized_template). +// +// Markers were deduced manually from the jinja sources in models/templates/. +// ============================================================================ +struct role_marker_case { + std::string template_file; + std::string expected_user_start; + std::string expected_assistant_start; +}; + +static void test_role_markers_all_templates(testing & t) { + // Each entry is { template filename, user_start, assistant_start } as + // produced when rendering the standard chatml-like sequences. The values + // come from reading each jinja template and tracing what text precedes + // a user/assistant message body once the autoparser strips any reasoning + // markers it detected first. + const std::vector cases = { + // ChatML family: <|im_start|>{role} ... 
<|im_end|> + { "Bielik-11B-v3.0-Instruct.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "HuggingFaceTB-SmolLM3-3B.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "MiMo-VL.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "Qwen3.5-4B.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "Qwen3-Coder.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "Qwen-Qwen2.5-7B-Instruct.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "Qwen-Qwen3-0.6B.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "Qwen-QwQ-32B.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "StepFun3.5-Flash.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + { "stepfun-ai-Step-3.5-Flash.jinja", "<|im_start|>user", "<|im_start|>assistant" }, + + // DeepSeek family + { "deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", "<|User|>", "<|Assistant|>" }, + { "deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", "<|User|>", "<|Assistant|>" }, + { "deepseek-ai-DeepSeek-V3.1.jinja", "<|User|>", "<|Assistant|>" }, + { "llama-cpp-deepseek-r1.jinja", "<|User|>", "<|Assistant|>" }, + + // Llama 3 header family + { "meetkai-functionary-medium-v3.1.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" }, + { "meta-llama-Llama-3.1-8B-Instruct.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" }, + { "meta-llama-Llama-3.2-3B-Instruct.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" }, + { "meta-llama-Llama-3.3-70B-Instruct.jinja", "<|start_header_id|>user<|end_header_id|>", 
"<|start_header_id|>assistant<|end_header_id|>" }, + // fireworks-ai forces a trailing assistant header even without add_generation_prompt, + // so the marker is absorbed into the common suffix and assistant_start is detected as empty. + { "fireworks-ai-llama-3-firefunction-v2.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" }, + + // Phi/GLM/Apriel-style: <|user|> / <|assistant|> + { "microsoft-Phi-3.5-mini-instruct.jinja", "<|user|>", "<|assistant|>" }, + { "GLM-4.6.jinja", "<|user|>", "<|assistant|>" }, + { "unsloth-Apriel-1.5.jinja", "<|user|>", "<|assistant|>" }, + { "GLM-4.7-Flash.jinja", "<|user|>", "<|assistant|>" }, + + // Gemma 2: {user|model} + { "google-gemma-2-2b-it.jinja", "user", "model" }, + + // IBM Granite + { "ibm-granite-granite-3.3-2B-Instruct.jinja", "<|start_of_role|>user<|end_of_role|>", "<|start_of_role|>assistant<|end_of_role|>" }, + { "ibm-granite-granite-4.0.jinja", "<|start_of_role|>user<|end_of_role|>", "<|start_of_role|>assistant<|end_of_role|>" }, + + // Cohere R-series + { "CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja", + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|START_RESPONSE|>" }, + { "CohereForAI-c4ai-command-r-plus-tool_use.jinja", + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" }, + + // Mistral: assistant content follows [/INST] immediately, no header + { "mistralai-Mistral-Nemo-Instruct-2407.jinja", "[INST]", "" }, + { "Mistral-Small-3.2-24B-Instruct-2506.jinja", "[INST]", "" }, + + // Apertus uses <|user_start|> / <|assistant_start|> but the user diff + // carries the preceding <|assistant_end|> from the previous turn. + { "Apertus-8B-Instruct.jinja", "<|user_start|>", "<|assistant_start|>" }, + + // Apriel 1.6 wraps the assistant body with <|begin_assistant|>, but + // <|begin_assistant|> is also the detected reasoning start, so the + // assistant_start is trimmed back to the preceding newline. 
+ { "Apriel-1.6-15b-Thinker-fixed.jinja", "<|begin_user|>", "<|begin_assistant|>" }, + + // ByteDance Seed-OSS: {role} + { "ByteDance-Seed-OSS.jinja", "user", "assistant" }, + + // GigaChat 3.1: {role}<|role_sep|> + { "GigaChat3.1-10B-A1.8B.jinja", "user<|role_sep|>", "assistant<|role_sep|>" }, + + // MiniMax M2: ]~b]{user|ai} + { "MiniMax-M2.jinja", "]~b]user", "]~b]ai" }, + + // Nemotron Nano v2: {User|Assistant}; assistant marker + // is followed by a prefilled block that gets included. + { "NVIDIA-Nemotron-Nano-v2.jinja", "User", "Assistant" }, + + // Reka Edge: "human: " / "assistant: " — but the rendered preamble + // depends on enable_thinking, which currently confuses the user-start + // diff and trims the marker down. Lock in the observed value. + { "Reka-Edge.jinja", "human:", "assistant:" }, + + // RWKV-world chat preset: "User: " / "Assistant: " + { "llama-cpp-rwkv-world.jinja", "User:", "Assistant:" }, + + // Upstage Solar 100B: <|begin|>{role}... the detected reasoning marker absorbs + // the "<|begin|>assistant" prefix, so the Solar Open workaround sets assistant_start explicitly. 
+ { "upstage-Solar-Open-100B.jinja", "<|begin|>user<|content|>", "<|begin|>assistant" }, + }; + + for (const auto & c : cases) { + t.test(c.template_file, [&](testing & t) { + common_chat_template tmpl = load_template(t, "models/templates/" + c.template_file); + struct autoparser ap; + ap.analyze_template(tmpl); + t.assert_equal("user_start", c.expected_user_start, ap.user_start); + t.assert_equal("assistant_start", c.expected_assistant_start, ap.assistant_start); + }); + } +} + // Test that reproduces the Seed-OSS template issue with embedded quotes static void test_tagged_args_with_embedded_quotes(testing & t) { json tools = build_edit_tool(); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index a428ef35c..1a5161cc1 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1548,6 +1548,40 @@ static void test_msgs_oaicompat_json_conversion() { } } +static void test_split_by_role() { + LOG_DBG("%s\n", __func__); + + // Empty inputs + assert_equals(0, common_chat_split_by_role("", {}).size()); + assert_equals(0, common_chat_split_by_role("hello", {}).size()); + assert_equals(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size()); + + // Multi-role conversation, no leading/trailing content + { + const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye"; + const auto splits = common_chat_split_by_role(prompt, { + { "user", "<|user|>" }, + { "assistant", "<|assistant|>" }, + }); + assert_equals(3, splits.size()); + + assert_equals("user", splits[0].role); + assert_equals(0, splits[0].pos); + assert_equals(10, splits[0].len); + assert_equals("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len)); + + assert_equals("assistant", splits[1].role); + assert_equals(10, splits[1].pos); + assert_equals(18, splits[1].len); + assert_equals("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len)); + + assert_equals("user", splits[2].role); + assert_equals(28, splits[2].pos); + assert_equals(11, splits[2].len); + 
assert_equals("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len)); + } +} + static void test_tools_oaicompat_json_conversion() { LOG_DBG("%s\n", __func__); std::vector tools{ @@ -4338,16 +4372,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Format: [{"name": "func", "arguments": {...}}] { auto tst = peg_tester("models/templates/NVIDIA-Nemotron-Nano-v2.jinja", detailed_debug); - tst.test("[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]") + tst.test("[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]") .tools({ special_function_tool }) .expect(message_assist_call) .run(); // Continuation tests tst.test("world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) .messages({ message_user, message_assist_prefill_content }) .add_generation_prompt(false) .continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT) + .expect_reasoning("I'm thinking") .expect_content("Hello, world!\nWhat's up?") .run(); } @@ -5593,6 +5630,7 @@ int main(int argc, char ** argv) { { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); + test_split_by_role(); test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); diff --git a/tools/cli/README.md b/tools/cli/README.md index bab65d505..add4021e2 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -147,7 +147,6 @@ | `--display-prompt, --no-display-prompt` | whether to print prompt at generation (default: true) | | `-co, --color [on\|off\|auto]` | Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal | | `-ctxcp, --ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 32)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | -| `-cpent, --checkpoint-every-n-tokens N` | create a checkpoint every n tokens during prefill (processing), -1 to disable (default: 8192)
(env: LLAMA_ARG_CHECKPOINT_EVERY_NT) | | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-sys, --system-prompt PROMPT` | system prompt to use with model (if applicable, depending on chat template) | diff --git a/tools/server/README.md b/tools/server/README.md index f2f73f6dc..0b7f9f994 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -163,7 +163,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `-lcs, --lookup-cache-static FNAME` | path to static lookup cache to use for lookup decoding (not updated by generation) | | `-lcd, --lookup-cache-dynamic FNAME` | path to dynamic lookup cache to use for lookup decoding (updated by generation) | | `-ctxcp, --ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 32)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | -| `-cpent, --checkpoint-every-n-tokens N` | create a checkpoint every n tokens during prefill (processing), -1 to disable (default: 8192)
(env: LLAMA_ARG_CHECKPOINT_EVERY_NT) | +| `-cms, --checkpoint-min-step N` | minimum spacing between context checkpoints in tokens (default: 256, 0 = no minimum)
(env: LLAMA_ARG_CHECKPOINT_MIN_SPACING_NT) | | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)
(env: LLAMA_ARG_KV_UNIFIED) | | `--cache-idle-slots, --no-cache-idle-slots` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)
(env: LLAMA_ARG_CACHE_IDLE_SLOTS) | diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index dc00edfa8..fb71792fe 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1110,6 +1110,16 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = chat_params.parser; } + llama_params["message_spans"] = json::array(); + + for (const auto & span : chat_params.message_spans) { + llama_params["message_spans"].push_back({ + { "role", span.role }, + { "pos", span.pos }, + { "len", span.len }, + }); + } + // Reasoning budget: pass parameters through to sampling layer { int reasoning_budget = opt.reasoning_budget; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c3daafd0d..9fecc4247 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1103,6 +1103,13 @@ private: } SRV_INF("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + if (params_base.n_ctx_checkpoints > 0) { + SRV_INF("context checkpoints enabled, max = %d, min spacing = %d\n", + params_base.n_ctx_checkpoints, params_base.checkpoint_min_step); + } else { + SRV_INF("%s", "context checkpoints disabled\n"); + } + if (!params_base.model_alias.empty()) { // backward compat: use first alias as model name model_name = *params_base.model_alias.begin(); @@ -2758,8 +2765,6 @@ private: } if (pos_min >= pos_min_thold) { - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); - // search for a context checkpoint const auto it = std::find_if( slot.prompt.checkpoints.rbegin(), @@ -2776,7 +2781,6 @@ private: if (!do_reset) { // restore the context checkpoint - it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); @@ -2912,6 +2916,9 @@ private: has_mtmd = true; } + const int32_t 
n_before_user = slot.task->params.n_before_user; + const bool n_before_user_known = n_before_user > 0; + // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { // get next token to process @@ -2940,6 +2947,13 @@ private: slot.n_prompt_tokens_processed++; + // stop the prompt batch exactly before the latest user input, so a checkpoint + // can be created after the previous messages + if (n_before_user_known && + slot.prompt.n_tokens() == n_before_user) { + break; + } + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. // create checkpoints that many tokens before the end of the prompt: // - 4 + n_ubatch @@ -2965,6 +2979,8 @@ private: // the number of tokens added to the batch for the current slot const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; @@ -2979,39 +2995,49 @@ private: slot.init_sampler(); } else { - if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) { - // near the end of the prompt - do_checkpoint = do_checkpoint && true; - } else { - // only do non-end checkpoints if the "checkpoint every n tokens" option is set - do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0; - - if (do_checkpoint) { - llama_pos last_checkpoint = 0; - if (!slot.prompt.checkpoints.empty()) { - last_checkpoint = slot.prompt.checkpoints.back().n_tokens; - } - - do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt; - - if (do_checkpoint) { - SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, 
slot.prompt.n_tokens()); - } - } + // skip ordinary mid-prompt checkpoints + if (!n_before_user_known && !near_prompt_end) { + do_checkpoint = false; } } const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); + // checkpoints are created before the current batch is decoded, so + // their token position is the batch start rather than the prompt end + const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; + + { + const bool is_on_user = + n_before_user_known && + n_tokens_start == n_before_user; + + const bool is_after_user = + n_before_user_known && + n_tokens_start > n_before_user; + + const bool is_allowed = + !n_before_user_known || + is_on_user || + (is_after_user && near_prompt_end); + + if (do_checkpoint && !is_allowed) { + do_checkpoint = false; + } + } + + // nothing to checkpoint yet + // TODO: is this check needed? + if (do_checkpoint && pos_min < 0) { + do_checkpoint = false; + } // do not checkpoint after mtmd chunks do_checkpoint = do_checkpoint && !has_mtmd; // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64); + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? 
"yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not @@ -3528,6 +3554,53 @@ void server_context::on_sleeping_changed(std::function callback) { impl->queue_tasks.on_sleeping_state(std::move(callback)); } +// compute the number of tokens before the last user message in the prompt +static int32_t prompt_get_n_before_user( + const json & message_spans, + const std::string & prompt, + const std::vector & files, + const llama_vocab * vocab, + mtmd_context * mctx) { + int32_t result = -1; + int32_t byte_pos = -1; + + for (const auto & span : message_spans) { + const std::string role = json_value(span, "role", std::string()); + + if (role == "user") { + byte_pos = json_value(span, "pos", -1); + } + } + + if (byte_pos >= 0) { + GGML_ASSERT((size_t) byte_pos <= prompt.size()); + + const std::string prefix = prompt.substr(0, (size_t) byte_pos); + + const std::string marker = get_media_marker(); + size_t n_prefix_media = 0; + for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { + n_prefix_media++; + } + + GGML_ASSERT(n_prefix_media <= files.size()); + + if (mctx != nullptr && n_prefix_media > 0) { + // TODO: this makes a copy - avoid it + std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); + + result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); + } else { + result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); + } + + SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", + byte_pos, n_prefix_media, result); + } + + return result; +} + // // server_routes @@ -3577,6 +3650,18 @@ std::unique_ptr server_routes::handle_completions_impl( meta->slot_n_ctx, meta->logit_bias_eog, data); + + const auto message_spans = json_value(data, "message_spans", json::array()); + if (prompt.is_string() && message_spans.is_array()) { + task.params.n_before_user = + 
prompt_get_n_before_user( + message_spans, + prompt.get<std::string>(), + files, + ctx_server.vocab, + ctx_server.mctx); + } + task.id_slot = json_value(data, "id_slot", -1); // OAI-compat diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 0978bb6ff..60e216e79 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -61,6 +61,9 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) + // number of prompt tokens before the latest user message + int32_t n_before_user = -1; + int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit