Checkpoint every n tokens: squash (#20087)

This commit is contained in:
Piotr Wilkin (ilintar) 2026-03-06 11:39:26 +01:00 committed by GitHub
parent f6235a41ef
commit f5ddcd1696
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 77 additions and 48 deletions

View file

@ -12,6 +12,7 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include <algorithm>
#include <cstddef>
#include <cinttypes>
#include <memory>
@ -2348,8 +2349,10 @@ private:
const auto it = std::find_if(
slot.prompt.checkpoints.rbegin(),
slot.prompt.checkpoints.rend(),
[&](const auto & cur) {
[&, func_name = __func__](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
return cur.pos_min < pos_min_thold;
}
);
@ -2533,48 +2536,66 @@ private:
slot.i_batch = batch.n_tokens - 1;
slot.init_sampler();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
} else {
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
if (do_checkpoint) {
llama_pos last_checkpoint = 0;
if (!slot.prompt.checkpoints.empty()) {
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
}
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
if (do_checkpoint) {
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
}
}
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot,
"erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size =
llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
}
if (!slot_batched) {