server : refactor "use checkpoint" logic (#22114)

Georgi Gerganov 2026-04-20 08:42:37 +03:00 committed by GitHub
parent 788fcbc5dd
commit de71b5f81c
GPG key ID: B5690EEEBB952194
7 changed files with 93 additions and 92 deletions

View file

@@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
         hf_tag = "default";
     }
 
-    std::string model_endpoint = get_model_endpoint();
+    std::string model_endpoint = common_get_model_endpoint();
 
     auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
     // prepare local path for caching
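
For illustration, with the default endpoint the line above builds a URL of the following shape (the repository name here is hypothetical):

    // hf_repo = "ggml-org/example-GGUF" (made-up repo, for illustration only)
    // preset_url -> "https://huggingface.co/ggml-org/example-GGUF/resolve/main/preset.ini"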

View file

@@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
 common_init_result::~common_init_result() = default;
 
-std::string get_model_endpoint() {
+std::string common_get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
     const char * hf_endpoint_env = getenv("HF_ENDPOINT");
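
The selection logic itself is elided by the diff, but the comment implies MODEL_ENDPOINT is the primary variable and HF_ENDPOINT the legacy fallback. A minimal usage sketch under that assumption (the mirror URLs are made up):

    #include <cstdlib>
    #include <string>

    #include "common.h" // common_get_model_endpoint()

    int main() {
        // hypothetical mirrors, for illustration only
        setenv("HF_ENDPOINT",    "https://hf-mirror.example.org/", 1);
        setenv("MODEL_ENDPOINT", "https://models.example.org/",    1);

        // assumption: MODEL_ENDPOINT takes precedence over the legacy HF_ENDPOINT
        const std::string endpoint = common_get_model_endpoint(); // "https://models.example.org/"
        return endpoint.empty() ? 1 : 0;
    }
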
@@ -1397,6 +1397,42 @@ std::string get_model_endpoint() {
     return model_endpoint;
 }
 
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
+    auto * mem = llama_get_memory(ctx);
+    if (mem == nullptr) {
+        return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+    }
+
+    common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
+
+    llama_memory_clear(mem, true);
+
+    // eval 2 tokens to check if the context is compatible
+    std::vector<llama_token> tmp;
+    tmp.push_back(0);
+    tmp.push_back(0);
+
+    int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
+    if (ret != 0) {
+        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+        goto done;
+    }
+
+    // try to remove the last tokens
+    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+        goto done;
+    }
+
+done:
+    llama_memory_clear(mem, true);
+    llama_synchronize(ctx);
+
+    return res;
+}
+
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
     std::vector<llama_adapter_lora *> loras;
     std::vector<float> scales;
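
A minimal sketch of how a caller is expected to branch on the new probe (the helper name and comments are illustrative; note that the probe clears the context's memory, so it should run before any real decoding):

    #include "common.h"
    #include "llama.h"

    // illustrative helper: classify a freshly created context
    static void check_seq_rm(llama_context * ctx) {
        switch (common_context_can_seq_rm(ctx)) {
            case COMMON_CONTEXT_SEQ_RM_TYPE_PART:
                // partial removal works -> cached prefixes can be trimmed in place
                break;
            case COMMON_CONTEXT_SEQ_RM_TYPE_FULL:
                // only whole sequences can be removed -> rewinding needs checkpoints
                break;
            case COMMON_CONTEXT_SEQ_RM_TYPE_NO:
                // no memory module at all -> seq_rm-based features are unavailable
                break;
        }
    }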

View file

@@ -308,10 +308,9 @@ struct common_params_speculative {
     // ngram-based speculative decoding
-    uint16_t ngram_size_n = 12; // ngram size for lookup
-    uint16_t ngram_size_m = 48; // mgram size for speculative tokens
-    uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
-    bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
+    uint16_t ngram_size_n = 12; // ngram size for lookup
+    uint16_t ngram_size_m = 48; // mgram size for speculative tokens
+    uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
 
     std::shared_ptr<common_ngram_mod> ngram_mod;
@@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-std::string get_model_endpoint();
+// model endpoint from env
+std::string common_get_model_endpoint();
+
+//
+// Context utils
+//
+
+enum common_context_seq_rm_type {
+    COMMON_CONTEXT_SEQ_RM_TYPE_NO   = 0, // seq_rm not supported (e.g. no memory module)
+    COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
+    COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
+};
+
+// check if the llama_context can remove sequences
+// note: clears the memory of the context
+common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
+
 //
 // Batch utils
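
The enum mirrors the two llama_memory_seq_rm call shapes the probe exercises. A sketch (range semantics follow llama.cpp's [p0, p1) convention, with negative bounds meaning "from the start" / "to the end"):

    #include "llama.h"

    // sketch: the two removal shapes, on a context whose state is disposable
    static void probe_seq_rm_shapes(llama_context * ctx) {
        auto * mem = llama_get_memory(ctx);

        // partial removal: drop positions [1, end) of sequence 0, keep position 0;
        // returns false on memories that cannot roll back (e.g. recurrent state)
        const bool can_partial = llama_memory_seq_rm(mem, 0, 1, -1);

        // full removal: negative bounds select the whole sequence;
        // this is the only shape a _FULL context supports
        const bool can_full = llama_memory_seq_rm(mem, 0, -1, -1);

        (void) can_partial; (void) can_full;
    }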

View file

@@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url,
 static std::string get_repo_commit(const std::string & repo_id,
                                    const std::string & token) {
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
 
         if (!json.is_object() ||
@@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id,
     hf_files files;
 
     try {
-        auto endpoint = get_model_endpoint();
+        auto endpoint = common_get_model_endpoint();
         auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
 
         if (!json.is_array()) {
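
With the default endpoint, the two Hugging Face API URLs constructed above look like this (repository and commit values are made up):

    // repo_id = "ggml-org/example-GGUF", commit = "main" (illustrative)
    //
    // refs: https://huggingface.co/api/models/ggml-org/example-GGUF/refs
    // tree: https://huggingface.co/api/models/ggml-org/example-GGUF/tree/main?recursive=true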

View file

@@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
     llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
     llama_context * ctx_dft;
 
+    bool use_ckpt = false;
 
     struct common_speculative_checkpoint ckpt;
-    bool use_checkpoint;
 
     common_sampler * smpl;
@@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
             llama_context * ctx_tgt,
             llama_context * ctx_dft,
             const std::vector<std::pair<std::string, std::string>> & replacements,
-            bool use_checkpoint)
+            bool use_ckpt)
         : common_speculative_state(type)
         , ctx_tgt(ctx_tgt)
         , ctx_dft(ctx_dft)
-        , use_checkpoint(use_checkpoint)
+        , use_ckpt(use_ckpt)
     {
         batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
         smpl = nullptr;
@@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
     }
 
     void begin(const llama_tokens & prompt) override {
-        if (use_checkpoint && ckpt.size() > 0) {
+        if (use_ckpt && ckpt.size() > 0) {
             // delete checkpoint
             LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
                     __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
@@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
                 __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
 
-        if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
+        if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
             LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
                     __func__, reuse_i, reuse_n);
             reuse_i = 0;
@@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
         result.clear();
         result.reserve(params.n_max);
 
-        bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
-        if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
+        bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
+        if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
             llama_memory_clear(mem_dft, false);
             prompt_dft.clear();
         } else {
@@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
         }
 
         if (reuse_n < (int) prompt_dft.size() || do_restore) {
-            if (use_checkpoint) {
+            if (use_ckpt) {
                 if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
                     LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
                             __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
@@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
     return it->second;
 }
 
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
-    auto * mem = llama_get_memory(ctx_tgt);
-    if (mem == nullptr) {
-        return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-    }
-
-    common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
-
-    llama_memory_clear(mem, true);
-
-    // eval 2 tokens to check if the context is compatible
-    std::vector<llama_token> tmp;
-    tmp.push_back(0);
-    tmp.push_back(0);
-
-    int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
-    if (ret != 0) {
-        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
-        goto done;
-    }
-
-    // try to remove the last tokens
-    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
-        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
-        res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
-        goto done;
-    }
-
-done:
-    llama_memory_clear(mem, true);
-    llama_synchronize(ctx_tgt);
-
-    return res;
-}
-
 // initialization of the speculative decoding system
 //
 common_speculative * common_speculative_init(
@@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
         case COMMON_SPECULATIVE_TYPE_NONE:
             break;
         case COMMON_SPECULATIVE_TYPE_DRAFT: {
+            const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+
             impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
-                /* .ctx_tgt        = */ ctx_tgt,
-                /* .ctx_dft        = */ ctx_dft,
-                /* .replacements   = */ params.replacements,
-                /* .use_checkpoint = */ params.use_checkpoints // TODO: this should be based on the draft model!
+                /* .ctx_tgt      = */ ctx_tgt,
+                /* .ctx_dft      = */ ctx_dft,
+                /* .replacements = */ params.replacements,
+                /* .use_ckpt     = */ use_ckpt
             ));
             break;
         }
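
This is the key behavioral change: the caller-supplied use_checkpoints flag (and its TODO) is gone, and the draft context is probed directly. A summary of the mapping implied by the line above:

    // common_context_can_seq_rm(ctx_dft) -> use_ckpt
    //   COMMON_CONTEXT_SEQ_RM_TYPE_PART  -> false (partial removal works; no checkpoints needed)
    //   COMMON_CONTEXT_SEQ_RM_TYPE_FULL  -> true  (can only rewind by restoring a checkpoint)
    //   COMMON_CONTEXT_SEQ_RM_TYPE_NO    -> false here, though callers such as the server
    //                                       skip speculative setup for such contexts anyway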

View file

@@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-enum common_speculative_compat_type {
-    COMMON_SPECULATIVE_COMPAT_TYPE_NO   = 0,
-    COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
-    COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
-};
-
-// check if the llama_context is compatible for speculative decoding
-// note: clears the memory of the context
-common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
-
 common_speculative * common_speculative_init(
         common_params_speculative & params,
         llama_context * ctx_tgt);
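
For code that still references the removed API, the replacement lives in common.h. Note that the shared name FULL flips meaning between the two enums, so a mechanical rename is not enough:

    // removed (speculative.h)            ->  replacement (common.h)
    // common_speculative_is_compat(ctx)  ->  common_context_can_seq_rm(ctx)
    // ..._COMPAT_TYPE_NO                 ->  ..._SEQ_RM_TYPE_NO   (not supported)
    // ..._COMPAT_TYPE_FULL               ->  ..._SEQ_RM_TYPE_PART (fully compatible: partial seq_rm works)
    // ..._COMPAT_TYPE_CKPT               ->  ..._SEQ_RM_TYPE_FULL (checkpoints needed: full seq_rm only)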

View file

@@ -78,9 +78,10 @@ enum server_state {
 struct server_slot {
     int id;
 
     // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;
+    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
 
     // multimodal
     mtmd_context * mctx = nullptr;
@@ -90,7 +91,6 @@ struct server_slot {
     server_prompt_checkpoint spec_ckpt;
 
     common_speculative_ptr spec;
-
     // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
     // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<const server_task> task;
@@ -343,7 +343,7 @@ struct server_slot {
         if (!spec_draft.empty()) {
             // we have a previous (partial) draft to reuse
-            if (task->params.speculative.use_checkpoints) {
+            if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                 GGML_ASSERT(!spec_ckpt.empty());
             }
         } else {
@@ -362,15 +362,13 @@ struct server_slot {
             spec_draft.clear();
         }
 
-        if (!spec_draft.empty() && params_spec.use_checkpoints) {
+        if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();
 
-            auto & ckpt = spec_ckpt;
-
-            ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
+            spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
 
             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                    ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
+                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
         }
     }
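
server_get_checkpoint itself is outside this diff; judging from the restore path near the end of this commit, its core is presumably a partial-only sequence state dump along these lines (a sketch, not the actual helper):

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // sketch: capture only the non-rollbackable part of a sequence's state
    static std::vector<uint8_t> checkpoint_data(llama_context * ctx, llama_seq_id seq_id) {
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        std::vector<uint8_t> data(size);
        llama_state_seq_get_data_ext(ctx, data.data(), data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        return data;
    }
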
@@ -871,14 +869,13 @@ private:
         slots.clear();
 
-        const auto spec_type = common_speculative_is_compat(ctx);
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+        const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
             SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }
 
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             SRV_WRN("%s", "speculative decoding will use checkpoints\n");
-
-            params_base.speculative.use_checkpoints = true;
         }
 
         // initialize slots
@@ -893,11 +890,13 @@ private:
             slot.ctx   = ctx;
             slot.n_ctx = n_ctx_slot;
 
+            slot.ctx_seq_rm_type = ctx_seq_rm_type;
+
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
             // try speculative decoding
-            if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+            if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
                 slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
 
                 if (slot.spec) {
@@ -2588,15 +2587,11 @@ private:
                 // make a checkpoint of the parts of the memory that cannot be rolled back.
                 // checkpoints are created only if:
+                //  - the model does not support partial sequence removal
                 //  - the model uses SWA and we are not using `swa_full`
-                //  - the model architecture is marked as recurrent or hybrid
-                //
-                // TODO: try to make this conditional on the context or the memory module, instead of the model type
                 do_checkpoint = do_checkpoint && (
-                    llama_model_is_recurrent(model) ||
-                    llama_model_is_hybrid(model) ||
-                    (llama_model_n_swa(model) > 0 && !params_base.swa_full)
-                );
+                    (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
+                    (llama_model_n_swa(model) > 0 && !params_base.swa_full));
 
                 bool has_mtmd = false;
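
The SWA term survives the refactor presumably because a SWA context can still pass the partial seq_rm probe (reporting _PART) even though positions that have slid out of the attention window cannot be re-decoded after a rollback; hence checkpoints remain necessary for SWA models unless swa_full is set.
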
@@ -2965,8 +2960,6 @@ private:
             // verify and try to accept the draft
             {
-                const auto & params_spec = slot.task->params.speculative;
-
                 common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
 
                 GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
@@ -2979,13 +2972,14 @@ private:
                 // check for partial draft acceptance
                 if (accepted.size() < slot.spec_draft.size() + 1) {
-                    if (params_spec.use_checkpoints) {
+                    if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                         // partial acceptance is not supported by the context -> truncate the draft and restore the state
                         slot.spec_draft = std::move(accepted);
 
-                        auto & ckpt = slot.spec_ckpt;
+                        const auto & ckpt = slot.spec_ckpt;
 
-                        SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
+                        SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
+                                ckpt.pos_min, ckpt.pos_max, ckpt.size());
 
                         const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                         if (n != ckpt.size()) {