From de71b5f81c3b6b9f8bdaf1b2a21198e1eede3fda Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 20 Apr 2026 08:42:37 +0300
Subject: [PATCH] server : refactor "use checkpoint" logic (#22114)

---
 common/arg.cpp                  |  2 +-
 common/common.cpp               | 38 +++++++++++++++++++-
 common/common.h                 | 25 ++++++++++---
 common/hf-cache.cpp             |  4 +--
 common/speculative.cpp          | 62 ++++++++-------------------
 common/speculative.h            | 10 ------
 tools/server/server-context.cpp | 44 ++++++++++-------------
 7 files changed, 93 insertions(+), 92 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 43fe5a25d..099f0aeab 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
         hf_tag = "default";
     }
 
-    std::string model_endpoint = get_model_endpoint();
+    std::string model_endpoint = common_get_model_endpoint();
 
     auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
 
     // prepare local path for caching
diff --git a/common/common.cpp b/common/common.cpp
index d3f1cee39..6cde71d81 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
 
 common_init_result::~common_init_result() = default;
 
-std::string get_model_endpoint() {
+std::string common_get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
     const char * hf_endpoint_env = getenv("HF_ENDPOINT");
@@ -1397,6 +1397,42 @@ std::string get_model_endpoint() {
     return model_endpoint;
 }
 
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
+    auto * mem = llama_get_memory(ctx);
+    if (mem == nullptr) {
+        return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+    }
+
+    common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
+
+    llama_memory_clear(mem, true);
+
+    // eval 2 tokens to check if the context is compatible
+    std::vector<llama_token> tmp;
+    tmp.push_back(0);
+    tmp.push_back(0);
+
+    int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
+    if (ret != 0) {
+        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+        goto done;
+    }
+
+    // try to remove the last tokens
+    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+        goto done;
+    }
+
+done:
+    llama_memory_clear(mem, true);
+    llama_synchronize(ctx);
+
+    return res;
+}
+
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
     std::vector loras;
     std::vector scales;
diff --git a/common/common.h b/common/common.h
index 027339294..cbcb3bb65 100644
--- a/common/common.h
+++ b/common/common.h
@@ -308,10 +308,9 @@ struct common_params_speculative {
 
     // ngram-based speculative decoding
 
-    uint16_t ngram_size_n = 12;   // ngram size for lookup
-    uint16_t ngram_size_m = 48;   // mgram size for speculative tokens
-    uint16_t ngram_min_hits = 1;  // minimum hits at ngram/mgram lookup for mgram to be proposed
-    bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
+    uint16_t ngram_size_n = 12;   // ngram size for lookup
+    uint16_t ngram_size_m = 48;   // mgram size for speculative tokens
+    uint16_t ngram_min_hits = 1;  // minimum hits at ngram/mgram lookup for mgram to be proposed
 
     std::shared_ptr ngram_mod;
 
@@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-std::string get_model_endpoint();
+// model endpoint from env
+std::string common_get_model_endpoint();
+
+//
+// Context utils
+//
+
+enum common_context_seq_rm_type {
+    COMMON_CONTEXT_SEQ_RM_TYPE_NO   = 0, // seq_rm not supported (e.g. no memory module)
+    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
+    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+};
+
+// check if the llama_context can remove sequences
+// note: clears the memory of the context
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
+
 
 //
 // Batch utils
diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp
index 38a4c17a9..ea5b2150d 100644
--- a/common/hf-cache.cpp
+++ b/common/hf-cache.cpp
@@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url,
 
 static std::string get_repo_commit(const std::string & repo_id, const std::string & token) {
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
 
         if (!json.is_object() ||
@@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id,
     hf_files files;
 
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
 
         if (!json.is_array()) {
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 1789560ee..daa2b5a8a 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     llama_context * ctx_dft;
 
+    bool use_ckpt = false;
     struct common_speculative_checkpoint ckpt;
-    bool use_checkpoint;
 
     common_sampler * smpl;
@@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
             const std::vector<std::pair<std::string, std::string>> & replacements,
-            bool use_checkpoint)
+            bool use_ckpt)
         : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
         , ctx_dft(ctx_dft)
-        , use_checkpoint(use_checkpoint)
+        , use_ckpt(use_ckpt)
     {
         batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
         smpl = nullptr;
@@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
     }
 
     void begin(const llama_tokens & prompt) override {
-        if (use_checkpoint && ckpt.size() > 0) {
+        if (use_ckpt && ckpt.size() > 0) {
             // delete checkpoint
             LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
                     __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
@@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
                 __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
 
-        if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
+        if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
             LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", __func__, reuse_i, reuse_n);
 
             reuse_i = 0;
@@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
 
         result.clear();
         result.reserve(params.n_max);
 
-        bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
-        if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
+        bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
+        if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
             llama_memory_clear(mem_dft, false);
 
             prompt_dft.clear();
         } else {
@@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         }
 
         if (reuse_n < (int) prompt_dft.size() || do_restore) {
-            if (use_checkpoint) {
+            if (use_ckpt) {
                 if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
                     LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
                             __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
@@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
     return it->second;
 }
 
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
-    auto * mem = llama_get_memory(ctx_tgt);
-    if (mem == nullptr) {
-        return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-    }
-
-    common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
-
-    llama_memory_clear(mem, true);
-
-    // eval 2 tokens to check if the context is compatible
-    std::vector<llama_token> tmp;
-    tmp.push_back(0);
-    tmp.push_back(0);
-
-    int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
-    if (ret != 0) {
-        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-        goto done;
-    }
-
-    // try to remove the last tokens
-    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
-        goto done;
-    }
-
-done:
-    llama_memory_clear(mem, true);
-    llama_synchronize(ctx_tgt);
-
-    return res;
-}
-
 // initialization of the speculative decoding system
 //
 common_speculative * common_speculative_init(
@@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
         case COMMON_SPECULATIVE_TYPE_NONE:
             break;
         case COMMON_SPECULATIVE_TYPE_DRAFT: {
+                const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+
                 impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
-                    /* .ctx_tgt      = */ ctx_tgt,
-                    /* .ctx_dft      = */ ctx_dft,
-                    /* .replacements = */ params.replacements,
-                    /* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
+                    /* .ctx_tgt      = */ ctx_tgt,
+                    /* .ctx_dft      = */ ctx_dft,
+                    /* .replacements = */ params.replacements,
+                    /* .use_ckpt     = */ use_ckpt
                 ));
                 break;
             }
diff --git a/common/speculative.h b/common/speculative.h
index cbe6e5bdb..bca78d32b 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-enum common_speculative_compat_type {
-    COMMON_SPECULATIVE_COMPAT_TYPE_NO   = 0,
-    COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
-    COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
-};
-
-// check if the llama_context is compatible for speculative decoding
-// note: clears the memory of the context
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
-
 common_speculative * common_speculative_init(
         common_params_speculative & params,
         llama_context * ctx_tgt);
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 7ffe6a303..99856e6c3 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -78,9 +78,10 @@ enum server_state {
 struct server_slot {
     int id;
 
-    // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
 
+    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+
     // multimodal
     mtmd_context * mctx = nullptr;
@@ -90,7 +91,6 @@ struct server_slot {
 
     server_prompt_checkpoint spec_ckpt;
     common_speculative_ptr spec;
-
     // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
     // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr task;
@@ -343,7 +343,7 @@ struct server_slot {
 
         if (!spec_draft.empty()) {
             // we have a previous (partial) draft to reuse
-            if (task->params.speculative.use_checkpoints) {
+            if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                 GGML_ASSERT(!spec_ckpt.empty());
             }
         } else {
@@ -362,15 +362,13 @@ struct server_slot {
             spec_draft.clear();
         }
 
-        if (!spec_draft.empty() && params_spec.use_checkpoints) {
+        if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();
 
-            auto & ckpt = spec_ckpt;
-
-            ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
+            spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
+                spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
         }
     }
@@ -871,14 +869,13 @@ private:
 
         slots.clear();
 
-        const auto spec_type = common_speculative_is_compat(ctx);
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+        const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
            SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }
 
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
            SRV_WRN("%s", "speculative decoding will use checkpoints\n");
-            params_base.speculative.use_checkpoints = true;
         }
 
         // initialize slots
@@ -893,11 +890,13 @@ private:
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
 
+            slot.ctx_seq_rm_type = ctx_seq_rm_type;
+
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
             // try speculative decoding
-            if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+            if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
                 slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
 
                 if (slot.spec) {
@@ -2588,15 +2587,11 @@ private:
                     // make a checkpoint of the parts of the memory that cannot be rolled back.
                     // checkpoints are created only if:
+                    //  - the model does not support partial sequence removal
                     //  - the model uses SWA and we are not using `swa_full`
-                    //  - the model architecture is marked as recurrent or hybrid
-                    //
-                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
                     do_checkpoint = do_checkpoint && (
-                        llama_model_is_recurrent(model) ||
-                        llama_model_is_hybrid(model) ||
-                        (llama_model_n_swa(model) > 0 && !params_base.swa_full)
-                    );
+                        (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full));
 
                     bool has_mtmd = false;
@@ -2965,8 +2960,6 @@ private:
             // verify and try to accept the draft
             {
-                const auto & params_spec = slot.task->params.speculative;
-
                 common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
 
                 GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
@@ -2979,13 +2972,14 @@ private:
                 // check for partial draft acceptance
                 if (accepted.size() < slot.spec_draft.size() + 1) {
-                    if (params_spec.use_checkpoints) {
+                    if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                         // partial acceptance is not supported by the context -> truncate the draft and restore the state
                         slot.spec_draft = std::move(accepted);
 
-                        auto & ckpt = slot.spec_ckpt;
+                        const auto & ckpt = slot.spec_ckpt;
 
-                        SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
+                        SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
+                                ckpt.pos_min, ckpt.pos_max, ckpt.size());
 
                         const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                         if (n != ckpt.size()) {
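
Reviewer note, not part of the patch: the intended interpretation of the new common_context_can_seq_rm() result can be summarized with a small caller-side sketch. The helper and the enum come from this patch; decide_checkpoint_usage() is a hypothetical, illustrative function, and the switch arms simply mirror what the server code above does after the change.

// Sketch only: how a caller is expected to map common_context_can_seq_rm()
// to a "use checkpoints" decision. decide_checkpoint_usage() is not part of
// llama.cpp; the per-case behavior follows the server changes in this patch.
#include "common.h"
#include "log.h"

static bool decide_checkpoint_usage(llama_context * ctx) {
    bool use_ckpt = false;

    // note: the probe clears the memory of the context, so call it right after
    // context creation, before any real tokens are decoded
    switch (common_context_can_seq_rm(ctx)) {
        case COMMON_CONTEXT_SEQ_RM_TYPE_NO:
            // no usable memory module -> speculative decoding is not supported
            LOG_WRN("speculative decoding not supported by this context\n");
            break;
        case COMMON_CONTEXT_SEQ_RM_TYPE_PART:
            // partial llama_memory_seq_rm() works -> rejected draft tokens can be
            // rolled back in place, no state checkpoints are needed
            break;
        case COMMON_CONTEXT_SEQ_RM_TYPE_FULL:
            // only whole sequences can be removed (e.g. recurrent/hybrid memory,
            // or SWA without swa_full) -> save and restore full sequence state
            use_ckpt = true;
            break;
    }

    return use_ckpt;
}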