server : refactor "use checkpoint" logic (#22114)

This commit is contained in:
Georgi Gerganov 2026-04-20 08:42:37 +03:00 committed by GitHub
parent 788fcbc5dd
commit de71b5f81c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 93 additions and 92 deletions

View file

@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
hf_tag = "default"; hf_tag = "default";
} }
std::string model_endpoint = get_model_endpoint(); std::string model_endpoint = common_get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini"; auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching // prepare local path for caching

View file

@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result::~common_init_result() = default; common_init_result::~common_init_result() = default;
std::string get_model_endpoint() { std::string common_get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
const char * hf_endpoint_env = getenv("HF_ENDPOINT"); const char * hf_endpoint_env = getenv("HF_ENDPOINT");
@ -1397,6 +1397,42 @@ std::string get_model_endpoint() {
return model_endpoint; return model_endpoint;
} }
// probe what kind of sequence removal the given context supports
// note: clears the memory of the context (both before and after the probe)
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
    auto * mem = llama_get_memory(ctx);
    if (mem == nullptr) {
        // no memory module -> sequences cannot be removed at all
        return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
    }

    llama_memory_clear(mem, true);

    // eval 2 tokens to check if the context is compatible
    std::vector<llama_token> tmp(2, 0);

    common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;

    const int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
    if (ret != 0) {
        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
        res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
    } else if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
        // try to remove the last tokens - failure means only full-sequence removal works
        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
        res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
    }

    // restore a clean, synchronized state before returning
    llama_memory_clear(mem, true);
    llama_synchronize(ctx);

    return res;
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) { void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
std::vector<llama_adapter_lora *> loras; std::vector<llama_adapter_lora *> loras;
std::vector<float> scales; std::vector<float> scales;

View file

@ -308,10 +308,9 @@ struct common_params_speculative {
// ngram-based speculative decoding // ngram-based speculative decoding
uint16_t ngram_size_n = 12; // ngram size for lookup uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
std::shared_ptr<common_ngram_mod> ngram_mod; std::shared_ptr<common_ngram_mod> ngram_mod;
@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
// clear LoRA adapters from context, then apply new list of adapters // clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora); void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
std::string get_model_endpoint(); // model endpoint from env
std::string common_get_model_endpoint();
//
// Context utils
//
// result of probing a llama_context for sequence-removal support (see common_context_can_seq_rm)
enum common_context_seq_rm_type {
    COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
};
// check if the llama_context can remove sequences
// note: clears the memory of the context
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
// //
// Batch utils // Batch utils

View file

@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url,
static std::string get_repo_commit(const std::string & repo_id, static std::string get_repo_commit(const std::string & repo_id,
const std::string & token) { const std::string & token) {
try { try {
auto endpoint = get_model_endpoint(); auto endpoint = common_get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token); auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
if (!json.is_object() || if (!json.is_object() ||
@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id,
hf_files files; hf_files files;
try { try {
auto endpoint = get_model_endpoint(); auto endpoint = common_get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
if (!json.is_array()) { if (!json.is_array()) {

View file

@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft; llama_context * ctx_dft;
bool use_ckpt = false;
struct common_speculative_checkpoint ckpt; struct common_speculative_checkpoint ckpt;
bool use_checkpoint;
common_sampler * smpl; common_sampler * smpl;
@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt, llama_context * ctx_tgt,
llama_context * ctx_dft, llama_context * ctx_dft,
const std::vector<std::pair<std::string, std::string>> & replacements, const std::vector<std::pair<std::string, std::string>> & replacements,
bool use_checkpoint) bool use_ckpt)
: common_speculative_state(type) : common_speculative_state(type)
, ctx_tgt(ctx_tgt) , ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft) , ctx_dft(ctx_dft)
, use_checkpoint(use_checkpoint) , use_ckpt(use_ckpt)
{ {
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
smpl = nullptr; smpl = nullptr;
@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
} }
void begin(const llama_tokens & prompt) override { void begin(const llama_tokens & prompt) override {
if (use_checkpoint && ckpt.size() > 0) { if (use_ckpt && ckpt.size() > 0) {
// delete checkpoint // delete checkpoint
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n", LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) { if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
__func__, reuse_i, reuse_n); __func__, reuse_i, reuse_n);
reuse_i = 0; reuse_i = 0;
@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
result.clear(); result.clear();
result.reserve(params.n_max); result.reserve(params.n_max);
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0; bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) { if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
llama_memory_clear(mem_dft, false); llama_memory_clear(mem_dft, false);
prompt_dft.clear(); prompt_dft.clear();
} else { } else {
@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
} }
if (reuse_n < (int) prompt_dft.size() || do_restore) { if (reuse_n < (int) prompt_dft.size() || do_restore) {
if (use_checkpoint) { if (use_ckpt) {
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n", LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
return it->second; return it->second;
} }
// probe whether ctx_tgt is compatible with speculative decoding
// note: clears the memory of the context as part of the probe
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
    auto * mem = llama_get_memory(ctx_tgt);
    if (mem == nullptr) {
        // no memory module -> speculative decoding cannot be used with this context
        return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
    }
    common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
    llama_memory_clear(mem, true);
    // eval 2 tokens to check if the context is compatible
    std::vector<llama_token> tmp;
    tmp.push_back(0);
    tmp.push_back(0);
    int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
    if (ret != 0) {
        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
        res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
        goto done;
    }
    // try to remove the last tokens
    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
        // partial removal unavailable -> fall back to checkpoint-based operation
        res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
        goto done;
    }
done:
    // restore a clean, synchronized state before returning
    llama_memory_clear(mem, true);
    llama_synchronize(ctx_tgt);
    return res;
}
// initialization of the speculative decoding system // initialization of the speculative decoding system
// //
common_speculative * common_speculative_init( common_speculative * common_speculative_init(
@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
case COMMON_SPECULATIVE_TYPE_NONE: case COMMON_SPECULATIVE_TYPE_NONE:
break; break;
case COMMON_SPECULATIVE_TYPE_DRAFT: { case COMMON_SPECULATIVE_TYPE_DRAFT: {
const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type, impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt, /* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft, /* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements, /* .replacements = */ params.replacements,
/* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model! /* .use_ckpt = */ use_ckpt
)); ));
break; break;
} }

View file

@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string // convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type); std::string common_speculative_type_to_str(enum common_speculative_type type);
// compatibility of a llama_context with speculative decoding (see common_speculative_is_compat)
enum common_speculative_compat_type {
    COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0, // not compatible (no memory module, or decode failed)
    COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1, // fully compatible (supports partial sequence removal)
    COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2, // compatible via checkpoints (no partial sequence removal)
};
// check if the llama_context is compatible for speculative decoding
// note: clears the memory of the context
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
common_speculative * common_speculative_init( common_speculative * common_speculative_init(
common_params_speculative & params, common_params_speculative & params,
llama_context * ctx_tgt); llama_context * ctx_tgt);

View file

@ -78,9 +78,10 @@ enum server_state {
struct server_slot { struct server_slot {
int id; int id;
// TODO: change to unique_ptrs for consistency:
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
// multimodal // multimodal
mtmd_context * mctx = nullptr; mtmd_context * mctx = nullptr;
@ -90,7 +91,6 @@ struct server_slot {
server_prompt_checkpoint spec_ckpt; server_prompt_checkpoint spec_ckpt;
common_speculative_ptr spec; common_speculative_ptr spec;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
std::unique_ptr<const server_task> task; std::unique_ptr<const server_task> task;
@ -343,7 +343,7 @@ struct server_slot {
if (!spec_draft.empty()) { if (!spec_draft.empty()) {
// we have a previous (partial) draft to reuse // we have a previous (partial) draft to reuse
if (task->params.speculative.use_checkpoints) { if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
GGML_ASSERT(!spec_ckpt.empty()); GGML_ASSERT(!spec_ckpt.empty());
} }
} else { } else {
@ -362,15 +362,13 @@ struct server_slot {
spec_draft.clear(); spec_draft.clear();
} }
if (!spec_draft.empty() && params_spec.use_checkpoints) { if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size(); const auto n_tokens = prompt.tokens.size();
auto & ckpt = spec_ckpt; spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024); spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
} }
} }
@ -871,14 +869,13 @@ private:
slots.clear(); slots.clear();
const auto spec_type = common_speculative_is_compat(ctx); const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) { if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
SRV_WRN("%s", "speculative decoding not supported by this context\n"); SRV_WRN("%s", "speculative decoding not supported by this context\n");
} }
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) { if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
SRV_WRN("%s", "speculative decoding will use checkpoints\n"); SRV_WRN("%s", "speculative decoding will use checkpoints\n");
params_base.speculative.use_checkpoints = true;
} }
// initialize slots // initialize slots
@ -893,11 +890,13 @@ private:
slot.ctx = ctx; slot.ctx = ctx;
slot.n_ctx = n_ctx_slot; slot.n_ctx = n_ctx_slot;
slot.ctx_seq_rm_type = ctx_seq_rm_type;
slot.mctx = mctx; slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr; slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding // try speculative decoding
if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) { if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
if (slot.spec) { if (slot.spec) {
@ -2588,15 +2587,11 @@ private:
// make a checkpoint of the parts of the memory that cannot be rolled back. // make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if: // checkpoints are created only if:
// - the model does not support partial sequence removal
// - the model uses SWA and we are not using `swa_full` // - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
do_checkpoint = do_checkpoint && ( do_checkpoint = do_checkpoint && (
llama_model_is_recurrent(model) || (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
llama_model_is_hybrid(model) || (llama_model_n_swa(model) > 0 && !params_base.swa_full));
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
);
bool has_mtmd = false; bool has_mtmd = false;
@ -2965,8 +2960,6 @@ private:
// verify and try to accept the draft // verify and try to accept the draft
{ {
const auto & params_spec = slot.task->params.speculative;
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
@ -2979,13 +2972,14 @@ private:
// check for partial draft acceptance // check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) { if (accepted.size() < slot.spec_draft.size() + 1) {
if (params_spec.use_checkpoints) { if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
// partial acceptance is not supported by the context -> truncate the draft and restore the state // partial acceptance is not supported by the context -> truncate the draft and restore the state
slot.spec_draft = std::move(accepted); slot.spec_draft = std::move(accepted);
auto & ckpt = slot.spec_ckpt; const auto & ckpt = slot.spec_ckpt;
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
ckpt.pos_min, ckpt.pos_max, ckpt.size());
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt.size()) { if (n != ckpt.size()) {