server : refactor "use checkpoint" logic (#22114)

Georgi Gerganov 2026-04-20 08:42:37 +03:00 committed by GitHub
parent 788fcbc5dd
commit de71b5f81c
7 changed files with 93 additions and 92 deletions

@@ -78,9 +78,10 @@ enum server_state {
 struct server_slot {
     int id;

     // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
+    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;

     // multimodal
     mtmd_context * mctx = nullptr;
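
The new ctx_seq_rm_type member drives every checkpoint decision in the hunks below. As a reading aid, here is a minimal sketch of the assumed enum: only the _NO and _FULL values actually appear in this diff, and the _PARTIAL value is a guess at the remaining case.

// hypothetical reconstruction -- only _NO and _FULL are confirmed by this diff
enum common_context_seq_rm_type {
    COMMON_CONTEXT_SEQ_RM_TYPE_NO,      // sequence removal unsupported   -> speculative decoding disabled
    COMMON_CONTEXT_SEQ_RM_TYPE_FULL,    // whole-sequence removal only    -> fall back to checkpoints
    COMMON_CONTEXT_SEQ_RM_TYPE_PARTIAL, // assumed: partial removal works -> no checkpoints needed
};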
@@ -90,7 +91,6 @@ struct server_slot {
     server_prompt_checkpoint spec_ckpt;
     common_speculative_ptr spec;

     // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
     // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<const server_task> task;
@@ -343,7 +343,7 @@ struct server_slot {
         if (!spec_draft.empty()) {
             // we have a previous (partial) draft to reuse
-            if (task->params.speculative.use_checkpoints) {
+            if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                 GGML_ASSERT(!spec_ckpt.empty());
             }
         } else {
@@ -362,15 +362,13 @@ struct server_slot {
             spec_draft.clear();
         }

-        if (!spec_draft.empty() && params_spec.use_checkpoints) {
+        if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();

-            auto & ckpt = spec_ckpt;
-
-            ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
+            spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);

             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
+                spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
         }
     }
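
server_get_checkpoint itself is not part of this diff. The following is a hedged sketch of what it plausibly does, assuming the llama.cpp ext state API that the restore path further down already uses (llama_state_seq_set_data_ext with LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); the server_prompt_checkpoint field names are inferred from the log statement above.

// hypothetical sketch, not the implementation from this commit
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, llama_seq_id seq_id, size_t n_tokens) {
    server_prompt_checkpoint ckpt;

    // record the position range the checkpoint covers (reported by the debug log above)
    ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id);
    ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);

    // serialize only the non-rollbackable part of the sequence state, mirroring
    // the PARTIAL_ONLY flag used when the checkpoint is restored
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    ckpt.data.resize(size);
    llama_state_seq_get_data_ext(ctx, ckpt.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    GGML_UNUSED(n_tokens); // the caller logs n_tokens; whether the checkpoint stores it is unknown
    return ckpt;
}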
@@ -871,14 +869,13 @@ private:
         slots.clear();

-        const auto spec_type = common_speculative_is_compat(ctx);
+        const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);

-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
             SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }

-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             SRV_WRN("%s", "speculative decoding will use checkpoints\n");
-            params_base.speculative.use_checkpoints = true;
         }

         // initialize slots
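
The query replacing common_speculative_is_compat is likewise outside this diff; judging from the call site, its declaration is presumably along these lines (an assumption, not shown in this commit):

common_context_seq_rm_type common_context_can_seq_rm(const llama_context * ctx);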
@@ -893,11 +890,13 @@ private:
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;

+            slot.ctx_seq_rm_type = ctx_seq_rm_type;
+
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;

             // try speculative decoding
-            if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+            if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
                 slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));

                 if (slot.spec) {
@@ -2588,15 +2587,11 @@ private:
                 // make a checkpoint of the parts of the memory that cannot be rolled back.
                 // checkpoints are created only if:
                 // - the model does not support partial sequence removal
                 // - the model uses SWA and we are not using `swa_full`
-                // - the model architecture is marked as recurrent or hybrid
                 //
                 // TODO: try to make this conditional on the context or the memory module, instead of the model type
                 do_checkpoint = do_checkpoint && (
-                    llama_model_is_recurrent(model) ||
-                    llama_model_is_hybrid(model) ||
-                    (llama_model_n_swa(model) > 0 && !params_base.swa_full)
-                );
+                    (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
+                    (llama_model_n_swa(model) > 0 && !params_base.swa_full));

                 bool has_mtmd = false;
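
The net effect of the rewritten gate is easiest to see in isolation. A standalone restatement using the same names as the hunk above (illustrative only):

// checkpoint only when rollback via partial sequence removal is impossible:
// either the context can only drop whole sequences (e.g. recurrent/hybrid state),
// or the model uses SWA without keeping the full attention window
static bool need_checkpoint(common_context_seq_rm_type rm_type, int32_t n_swa, bool swa_full) {
    return rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || (n_swa > 0 && !swa_full);
}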
@@ -2965,8 +2960,6 @@ private:
             // verify and try to accept the draft
             {
-                const auto & params_spec = slot.task->params.speculative;
-
                 common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));

                 GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
@@ -2979,13 +2972,14 @@ private:
                 // check for partial draft acceptance
                 if (accepted.size() < slot.spec_draft.size() + 1) {
-                    if (params_spec.use_checkpoints) {
+                    if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                         // partial acceptance is not supported by the context -> truncate the draft and restore the state
                         slot.spec_draft = std::move(accepted);

-                        auto & ckpt = slot.spec_ckpt;
+                        const auto & ckpt = slot.spec_ckpt;

-                        SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
+                        SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
+                            ckpt.pos_min, ckpt.pos_max, ckpt.size());

                         const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                         if (n != ckpt.size()) {
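
The view cuts off inside the size check. A hedged sketch of how the rollback presumably completes, tying the state restore to the sampler clone taken in the previous hunk; both the error branch and the sampler restore are assumptions, not code from this commit:

// hypothetical continuation of the rollback path
if (n != ckpt.size()) {
    // restoring fewer bytes than the checkpoint holds would leave the context
    // state inconsistent; presumably the draft is dropped and an error is raised
    SLT_ERR(slot, "failed to restore speculative checkpoint (%zu / %zu bytes)\n", n, ckpt.size());
}

// the sampler must be rolled back too, using the clone made before verification
slot.smpl = std::move(smpl_save);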