mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-26 10:41:25 +00:00
server : refactor "use checkpoint" logic (#22114)
This commit is contained in:
parent
788fcbc5dd
commit
de71b5f81c
7 changed files with 93 additions and 92 deletions
|
|
@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
|
|||
hf_tag = "default";
|
||||
}
|
||||
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
std::string model_endpoint = common_get_model_endpoint();
|
||||
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
|
||||
|
||||
// prepare local path for caching
|
||||
|
|
|
|||
|
|
@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
|||
|
||||
common_init_result::~common_init_result() = default;
|
||||
|
||||
std::string get_model_endpoint() {
|
||||
std::string common_get_model_endpoint() {
|
||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
||||
const char * hf_endpoint_env = getenv("HF_ENDPOINT");
|
||||
|
|
@ -1397,6 +1397,42 @@ std::string get_model_endpoint() {
|
|||
return model_endpoint;
|
||||
}
|
||||
|
||||
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
if (mem == nullptr) {
|
||||
return COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
}
|
||||
|
||||
common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART;
|
||||
|
||||
llama_memory_clear(mem, true);
|
||||
|
||||
// eval 2 tokens to check if the context is compatible
|
||||
std::vector<llama_token> tmp;
|
||||
tmp.push_back(0);
|
||||
tmp.push_back(0);
|
||||
|
||||
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
goto done;
|
||||
}
|
||||
|
||||
done:
|
||||
llama_memory_clear(mem, true);
|
||||
llama_synchronize(ctx);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
||||
std::vector<llama_adapter_lora *> loras;
|
||||
std::vector<float> scales;
|
||||
|
|
|
|||
|
|
@ -308,10 +308,9 @@ struct common_params_speculative {
|
|||
|
||||
// ngram-based speculative decoding
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||
|
||||
|
|
@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
|
|||
// clear LoRA adapters from context, then apply new list of adapters
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
||||
|
||||
std::string get_model_endpoint();
|
||||
// model endpoint from env
|
||||
std::string common_get_model_endpoint();
|
||||
|
||||
//
|
||||
// Context utils
|
||||
//
|
||||
|
||||
enum common_context_seq_rm_type {
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
|
||||
};
|
||||
|
||||
// check if the llama_context can remove sequences
|
||||
// note: clears the memory of the context
|
||||
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
|
||||
|
||||
|
||||
//
|
||||
// Batch utils
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url,
|
|||
static std::string get_repo_commit(const std::string & repo_id,
|
||||
const std::string & token) {
|
||||
try {
|
||||
auto endpoint = get_model_endpoint();
|
||||
auto endpoint = common_get_model_endpoint();
|
||||
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
|
||||
|
||||
if (!json.is_object() ||
|
||||
|
|
@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id,
|
|||
hf_files files;
|
||||
|
||||
try {
|
||||
auto endpoint = get_model_endpoint();
|
||||
auto endpoint = common_get_model_endpoint();
|
||||
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
|
||||
|
||||
if (!json.is_array()) {
|
||||
|
|
|
|||
|
|
@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
|
||||
bool use_ckpt = false;
|
||||
struct common_speculative_checkpoint ckpt;
|
||||
bool use_checkpoint;
|
||||
|
||||
common_sampler * smpl;
|
||||
|
||||
|
|
@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
const std::vector<std::pair<std::string, std::string>> & replacements,
|
||||
bool use_checkpoint)
|
||||
bool use_ckpt)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, use_checkpoint(use_checkpoint)
|
||||
, use_ckpt(use_ckpt)
|
||||
{
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
smpl = nullptr;
|
||||
|
|
@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
if (use_checkpoint && ckpt.size() > 0) {
|
||||
if (use_ckpt && ckpt.size() > 0) {
|
||||
// delete checkpoint
|
||||
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
|
||||
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||
|
|
@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
|
||||
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
|
||||
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
|
||||
if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) {
|
||||
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
|
||||
__func__, reuse_i, reuse_n);
|
||||
reuse_i = 0;
|
||||
|
|
@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
result.clear();
|
||||
result.reserve(params.n_max);
|
||||
|
||||
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
|
||||
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
|
||||
llama_memory_clear(mem_dft, false);
|
||||
prompt_dft.clear();
|
||||
} else {
|
||||
|
|
@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
}
|
||||
|
||||
if (reuse_n < (int) prompt_dft.size() || do_restore) {
|
||||
if (use_checkpoint) {
|
||||
if (use_ckpt) {
|
||||
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
|
||||
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
|
||||
|
|
@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
|||
return it->second;
|
||||
}
|
||||
|
||||
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
|
||||
auto * mem = llama_get_memory(ctx_tgt);
|
||||
if (mem == nullptr) {
|
||||
return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
|
||||
}
|
||||
|
||||
common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
|
||||
|
||||
llama_memory_clear(mem, true);
|
||||
|
||||
// eval 2 tokens to check if the context is compatible
|
||||
std::vector<llama_token> tmp;
|
||||
tmp.push_back(0);
|
||||
tmp.push_back(0);
|
||||
|
||||
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||
res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
|
||||
res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
|
||||
goto done;
|
||||
}
|
||||
|
||||
done:
|
||||
llama_memory_clear(mem, true);
|
||||
llama_synchronize(ctx_tgt);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// initialization of the speculative decoding system
|
||||
//
|
||||
common_speculative * common_speculative_init(
|
||||
|
|
@ -1022,11 +986,13 @@ common_speculative * common_speculative_init(
|
|||
case COMMON_SPECULATIVE_TYPE_NONE:
|
||||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
||||
const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements,
|
||||
/* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements,
|
||||
/* .use_ckpt = */ use_ckpt
|
||||
));
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
|
|||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
enum common_speculative_compat_type {
|
||||
COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0,
|
||||
COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
|
||||
COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
|
||||
};
|
||||
|
||||
// check if the llama_context is compatible for speculative decoding
|
||||
// note: clears the memory of the context
|
||||
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
|
||||
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
|
|
|
|||
|
|
@ -78,9 +78,10 @@ enum server_state {
|
|||
struct server_slot {
|
||||
int id;
|
||||
|
||||
// TODO: change to unique_ptrs for consistency:
|
||||
llama_context * ctx = nullptr;
|
||||
|
||||
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
|
||||
|
|
@ -90,7 +91,6 @@ struct server_slot {
|
|||
server_prompt_checkpoint spec_ckpt;
|
||||
common_speculative_ptr spec;
|
||||
|
||||
|
||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||
std::unique_ptr<const server_task> task;
|
||||
|
|
@ -343,7 +343,7 @@ struct server_slot {
|
|||
|
||||
if (!spec_draft.empty()) {
|
||||
// we have a previous (partial) draft to reuse
|
||||
if (task->params.speculative.use_checkpoints) {
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
GGML_ASSERT(!spec_ckpt.empty());
|
||||
}
|
||||
} else {
|
||||
|
|
@ -362,15 +362,13 @@ struct server_slot {
|
|||
spec_draft.clear();
|
||||
}
|
||||
|
||||
if (!spec_draft.empty() && params_spec.use_checkpoints) {
|
||||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
|
||||
auto & ckpt = spec_ckpt;
|
||||
|
||||
ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||
ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -871,14 +869,13 @@ private:
|
|||
|
||||
slots.clear();
|
||||
|
||||
const auto spec_type = common_speculative_is_compat(ctx);
|
||||
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
|
||||
const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
|
||||
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
|
||||
params_base.speculative.use_checkpoints = true;
|
||||
}
|
||||
|
||||
// initialize slots
|
||||
|
|
@ -893,11 +890,13 @@ private:
|
|||
slot.ctx = ctx;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
|
||||
slot.ctx_seq_rm_type = ctx_seq_rm_type;
|
||||
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
// try speculative decoding
|
||||
if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
|
||||
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
|
||||
|
||||
if (slot.spec) {
|
||||
|
|
@ -2588,15 +2587,11 @@ private:
|
|||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model does not support partial sequence removal
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
(slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full));
|
||||
|
||||
bool has_mtmd = false;
|
||||
|
||||
|
|
@ -2965,8 +2960,6 @@ private:
|
|||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
const auto & params_spec = slot.task->params.speculative;
|
||||
|
||||
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
|
|
@ -2979,13 +2972,14 @@ private:
|
|||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (params_spec.use_checkpoints) {
|
||||
if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
auto & ckpt = slot.spec_ckpt;
|
||||
const auto & ckpt = slot.spec_ckpt;
|
||||
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
|
||||
ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != ckpt.size()) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue