mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
server : refactor "use checkpoint" logic (#22114)
This commit is contained in:
parent
788fcbc5dd
commit
de71b5f81c
7 changed files with 93 additions and 92 deletions
|
|
@ -78,9 +78,10 @@ enum server_state {
|
|||
struct server_slot {
|
||||
int id;
|
||||
|
||||
// TODO: change to unique_ptrs for consistency:
|
||||
llama_context * ctx = nullptr;
|
||||
|
||||
common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
|
||||
|
|
@ -90,7 +91,6 @@ struct server_slot {
|
|||
server_prompt_checkpoint spec_ckpt;
|
||||
common_speculative_ptr spec;
|
||||
|
||||
|
||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||
std::unique_ptr<const server_task> task;
|
||||
|
|
@ -343,7 +343,7 @@ struct server_slot {
|
|||
|
||||
if (!spec_draft.empty()) {
|
||||
// we have a previous (partial) draft to reuse
|
||||
if (task->params.speculative.use_checkpoints) {
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
GGML_ASSERT(!spec_ckpt.empty());
|
||||
}
|
||||
} else {
|
||||
|
|
@ -362,15 +362,13 @@ struct server_slot {
|
|||
spec_draft.clear();
|
||||
}
|
||||
|
||||
if (!spec_draft.empty() && params_spec.use_checkpoints) {
|
||||
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
const auto n_tokens = prompt.tokens.size();
|
||||
|
||||
auto & ckpt = spec_ckpt;
|
||||
|
||||
ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
|
||||
|
||||
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
|
||||
ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||
spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -871,14 +869,13 @@ private:
|
|||
|
||||
slots.clear();
|
||||
|
||||
const auto spec_type = common_speculative_is_compat(ctx);
|
||||
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
|
||||
const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
SRV_WRN("%s", "speculative decoding not supported by this context\n");
|
||||
}
|
||||
|
||||
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
|
||||
if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
|
||||
params_base.speculative.use_checkpoints = true;
|
||||
}
|
||||
|
||||
// initialize slots
|
||||
|
|
@ -893,11 +890,13 @@ private:
|
|||
slot.ctx = ctx;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
|
||||
slot.ctx_seq_rm_type = ctx_seq_rm_type;
|
||||
|
||||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
// try speculative decoding
|
||||
if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
|
||||
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
|
||||
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
|
||||
|
||||
if (slot.spec) {
|
||||
|
|
@ -2588,15 +2587,11 @@ private:
|
|||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model does not support partial sequence removal
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
(slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full));
|
||||
|
||||
bool has_mtmd = false;
|
||||
|
||||
|
|
@ -2965,8 +2960,6 @@ private:
|
|||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
const auto & params_spec = slot.task->params.speculative;
|
||||
|
||||
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
|
|
@ -2979,13 +2972,14 @@ private:
|
|||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (params_spec.use_checkpoints) {
|
||||
if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
|
||||
// partial acceptance is not supported by the context -> truncate the draft and restore the state
|
||||
slot.spec_draft = std::move(accepted);
|
||||
|
||||
auto & ckpt = slot.spec_ckpt;
|
||||
const auto & ckpt = slot.spec_ckpt;
|
||||
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n",
|
||||
ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
|
||||
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != ckpt.size()) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue