server: fix checkpoints creation (#22929)

* common : add common_chat_split_by_role

* cont : fix spans to reach end of message

* server: fix checkpoints creation

- extract message_spans from chat templates
- find the prompt token position before the latest user message
- split prompt batching at that position
- create a context checkpoint before the latest user input
- avoid periodic mid-prompt checkpoints when that position is known
- handle multimodal prompts when mapping text/template positions to server prompt tokens
- add --checkpoint-min-step to control minimum spacing between checkpoints

* cont : clean-up

* Support autoparser detection for message barriers

* server: fix message span delimiter and update docs

---------

Co-authored-by: Alde Rojas <hello@alde.dev>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
This commit is contained in:
jacekpoplawski 2026-05-25 07:56:18 +02:00 committed by GitHub
parent 6d57c26ef8
commit e2ef8fe42c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 586 additions and 37 deletions

View file

@ -1103,6 +1103,13 @@ private:
}
SRV_INF("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
if (params_base.n_ctx_checkpoints > 0) {
SRV_INF("context checkpoints enabled, max = %d, min spacing = %d\n",
params_base.n_ctx_checkpoints, params_base.checkpoint_min_step);
} else {
SRV_INF("%s", "context checkpoints disabled\n");
}
if (!params_base.model_alias.empty()) {
// backward compat: use first alias as model name
model_name = *params_base.model_alias.begin();
@ -2758,8 +2765,6 @@ private:
}
if (pos_min >= pos_min_thold) {
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
const auto it = std::find_if(
slot.prompt.checkpoints.rbegin(),
@ -2776,7 +2781,6 @@ private:
if (!do_reset) {
// restore the context checkpoint
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
@ -2912,6 +2916,9 @@ private:
has_mtmd = true;
}
const int32_t n_before_user = slot.task->params.n_before_user;
const bool n_before_user_known = n_before_user > 0;
// add prompt tokens for processing in the current batch
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
// get next token to process
@ -2940,6 +2947,13 @@ private:
slot.n_prompt_tokens_processed++;
// stop the prompt batch exactly before the latest user input, so a checkpoint
// can be created after the previous messages
if (n_before_user_known &&
slot.prompt.n_tokens() == n_before_user) {
break;
}
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
// create checkpoints that many tokens before the end of the prompt:
// - 4 + n_ubatch
@ -2965,6 +2979,8 @@ private:
// the number of tokens added to the batch for the current slot
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@ -2979,39 +2995,49 @@ private:
slot.init_sampler();
} else {
if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
// near the end of the prompt
do_checkpoint = do_checkpoint && true;
} else {
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
if (do_checkpoint) {
llama_pos last_checkpoint = 0;
if (!slot.prompt.checkpoints.empty()) {
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
}
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
if (do_checkpoint) {
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
}
}
// skip ordinary mid-prompt checkpoints
if (!n_before_user_known && !near_prompt_end) {
do_checkpoint = false;
}
}
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
// checkpoints are created before the current batch is decoded, so
// their token position is the batch start rather than the prompt end
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
{
const bool is_on_user =
n_before_user_known &&
n_tokens_start == n_before_user;
const bool is_after_user =
n_before_user_known &&
n_tokens_start > n_before_user;
const bool is_allowed =
!n_before_user_known ||
is_on_user ||
(is_after_user && near_prompt_end);
if (do_checkpoint && !is_allowed) {
do_checkpoint = false;
}
}
// nothing to checkpoint yet
// TODO: is this check needed?
if (do_checkpoint && pos_min < 0) {
do_checkpoint = false;
}
// do not checkpoint after mtmd chunks
do_checkpoint = do_checkpoint && !has_mtmd;
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
@ -3528,6 +3554,53 @@ void server_context::on_sleeping_changed(std::function<void(bool)> callback) {
impl->queue_tasks.on_sleeping_state(std::move(callback));
}
// compute the number of tokens before the last user message in the prompt
static int32_t prompt_get_n_before_user(
const json & message_spans,
const std::string & prompt,
const std::vector<raw_buffer> & files,
const llama_vocab * vocab,
mtmd_context * mctx) {
int32_t result = -1;
int32_t byte_pos = -1;
for (const auto & span : message_spans) {
const std::string role = json_value(span, "role", std::string());
if (role == "user") {
byte_pos = json_value(span, "pos", -1);
}
}
if (byte_pos >= 0) {
GGML_ASSERT((size_t) byte_pos <= prompt.size());
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
const std::string marker = get_media_marker();
size_t n_prefix_media = 0;
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
n_prefix_media++;
}
GGML_ASSERT(n_prefix_media <= files.size());
if (mctx != nullptr && n_prefix_media > 0) {
// TODO: this makes a copy - avoid it
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
} else {
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
}
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
byte_pos, n_prefix_media, result);
}
return result;
}
//
// server_routes
@ -3577,6 +3650,18 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
meta->slot_n_ctx,
meta->logit_bias_eog,
data);
const auto message_spans = json_value(data, "message_spans", json::array());
if (prompt.is_string() && message_spans.is_array()) {
task.params.n_before_user =
prompt_get_n_before_user(
message_spans,
prompt.get<std::string>(),
files,
ctx_server.vocab,
ctx_server.mctx);
}
task.id_slot = json_value(data, "id_slot", -1);
// OAI-compat