diff --git a/common/arg.cpp b/common/arg.cpp index 55ec9389b..9fefe411e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } - for (auto & pair : params.speculative.draft.replacements) { - string_process_escapes(pair.first); - string_process_escapes(pair.second); - } } if (!params.kv_overrides.empty()) { @@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.p_min = std::stof(value); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN")); - add_opt(common_arg( - {"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N", - string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx), - [](common_params & params, int value) { - params.speculative.draft.n_ctx = value; - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE")); add_opt(common_arg( {"--spec-draft-device", "-devd", "--device-draft"}, "", "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" @@ -3561,32 +3550,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL")); add_opt(common_arg( - {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT", - "translate the string in TARGET into DRAFT if the draft model and main model are not compatible", - [](common_params & params, const std::string & tgt, const std::string & dft) { - params.speculative.draft.replacements.push_back({ tgt, dft }); - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); - add_opt(common_arg( - {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + {"--spec-type"}, common_speculative_all_types_str(), string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", - common_speculative_type_to_str(params.speculative.type).c_str()), + common_speculative_type_name_str(params.speculative.types).c_str()), [](common_params & params, const std::string & value) { - if (value == "none") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - } else if (value == "ngram-cache") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; - } else if (value == "ngram-simple") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; - } else if (value == "ngram-map-k") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; - } else if (value == "ngram-map-k4v") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; - } else if (value == "ngram-mod") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; - } else { - throw std::invalid_argument("unknown speculative decoding type without draft model"); - } + const auto enabled_types = string_split(value, ','); + params.speculative.types = common_speculative_types_from_names(enabled_types); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE")); 
add_opt(common_arg( @@ -4075,7 +4044,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--spec-default"}, string_format("enable default speculative decoding config"), [](common_params & params) { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD }; params.speculative.ngram_mod.n_match = 24; params.speculative.ngram_mod.n_min = 48; params.speculative.ngram_mod.n_max = 64; diff --git a/common/common.cpp b/common/common.cpp index 793b8fee7..352af0b17 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + LOG_WRN("%s: the context does not support partial sequence removal\n", __func__); res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL; goto done; } @@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode( return true; } + +size_t common_prompt_checkpoint::size() const { + return data_tgt.size() + data_dft.size(); +} + +bool common_prompt_checkpoint::empty() const { + return data_tgt.empty(); +} + +void common_prompt_checkpoint::clear() { + n_tokens = 0; + + pos_min = 0; + pos_max = 0; + + data_tgt.clear(); + data_dft.clear(); +} + +void common_prompt_checkpoint::update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max) { + this->n_tokens = n_tokens; + this->pos_min = pos_min; + this->pos_max = pos_max; +} + +void common_prompt_checkpoint::update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_tgt.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_dft.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_tgt.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags); + if (n != data_tgt.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n); + } +} + +void common_prompt_checkpoint::load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_dft.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags); + if (n != data_dft.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); + } +} diff --git a/common/common.h b/common/common.h index a564b3b8c..aafc376f2 100644 --- a/common/common.h 
+++ b/common/common.h @@ -295,8 +295,6 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; -struct common_ngram_mod; - // draft-model-based speculative decoding parameters struct common_params_speculative_draft { int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding @@ -307,11 +305,9 @@ struct common_params_speculative_draft { common_params_model mparams; - llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; - llama_context_params cparams; // these are the parameters for the draft llama_context - - int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K @@ -322,7 +318,6 @@ struct common_params_speculative_draft { std::vector devices; // devices to use for offloading - std::vector> replacements; // main to speculative model replacements std::vector tensor_buft_overrides; }; @@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod { int32_t n_max = 64; int32_t n_min = 48; - - // shared instance of the ngram container for all speculative decoding contexts - std::shared_ptr obj; }; struct common_params_speculative_ngram_map { @@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache { }; struct common_params_speculative { - // TODO: become a vector in order to support "chains of speculators" - common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; + std::vector types = { COMMON_SPECULATIVE_TYPE_NONE }; common_params_speculative_draft draft; @@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std // "adamw" or "sgd" (case insensitive) enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); + +// +// prompt utils +// + +struct common_prompt_checkpoint { + int64_t n_tokens; + + llama_pos pos_min; + llama_pos pos_max; + + std::vector data_tgt; + std::vector data_dft; + + size_t size() const; + + bool empty() const; + void clear(); + + void update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max); + + void update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; + + void load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; +}; diff --git a/common/speculative.cpp b/common/speculative.cpp index e9fa751e2..e487e003d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -10,6 +10,7 @@ #include "sampling.h" #include +#include #include #include #include @@ -18,26 +19,15 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -const std::vector common_speculative_types = { - COMMON_SPECULATIVE_TYPE_NONE, - COMMON_SPECULATIVE_TYPE_DRAFT, - COMMON_SPECULATIVE_TYPE_EAGLE3, - COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, - COMMON_SPECULATIVE_TYPE_NGRAM_MOD, - COMMON_SPECULATIVE_TYPE_NGRAM_CACHE -}; - -const std::map common_speculative_type_from_name_map = { +const std::map common_speculative_type_from_name_map = { {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", 
COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, - {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, - {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, - {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, - {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, - {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} + {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, + {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, + {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, + {"ngram-mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, + {"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} }; struct common_speculative_config { @@ -115,12 +105,16 @@ static bool common_speculative_are_compatible( return true; } +using common_speculative_draft_params_vec = std::vector; + // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific -// in a subclass of common_speculative_state -struct common_speculative_state { - const enum common_speculative_type type; +// in a subclass of common_speculative_impl +struct common_speculative_impl { + const common_speculative_type type; + + uint32_t n_seq; size_t n_call_begin = 0; // number of times this implementation was called for refresh. size_t n_call_draft = 0; // number of times this implementation was called for generation. @@ -138,65 +132,34 @@ struct common_speculative_state { int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. - common_speculative_state(enum common_speculative_type type) : type(type) {} + common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} - virtual ~common_speculative_state() = default; + virtual ~common_speculative_impl() = default; - virtual void begin(const llama_tokens & prompt) = 0; + virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; - virtual void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) = 0; + virtual bool process(const llama_batch & batch) = 0; - virtual void accept(uint16_t n_accepted) = 0; + virtual void draft(common_speculative_draft_params_vec & dparams) = 0; - virtual int32_t n_max(const common_params_speculative & params) const = 0; - virtual int32_t n_min(const common_params_speculative & params) const = 0; + virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; -struct common_speculative_checkpoint { - llama_pos pos_min = 0; - llama_pos pos_max = 0; +struct common_speculative_state_draft : public common_speculative_impl { + common_params_speculative_draft params; - int64_t n_tokens = 0; + llama_batch batch; - std::vector data; + std::vector smpls; - size_t size() const { - return data.size(); - } -}; - -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - bool use_ckpt = false; - common_speculative_checkpoint ckpt; - - common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements, - 
bool use_ckpt) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , use_ckpt(use_ckpt) + common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) + , params(params.draft) { + auto * ctx_dft = this->params.ctx_dft; + auto * ctx_tgt = this->params.ctx_tgt; + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; // TODO: optimize or pass from outside? // { @@ -214,7 +177,9 @@ struct common_speculative_state_draft : public common_speculative_state { // // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); // } - { + + smpls.resize(n_seq); + for (auto & smpl : smpls) { common_params_sampling params; params.no_perf = false; params.top_k = 10; @@ -222,482 +187,321 @@ struct common_speculative_state_draft : public common_speculative_state { COMMON_SAMPLER_TYPE_TOP_K, }; - smpl = common_sampler_init(llama_get_model(ctx_dft), params); + smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } + throw std::runtime_error("draft model vocab type must match target model to use speculation"); + } + + if (n_seq != llama_n_seq_max(ctx_dft)) { + LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); + + throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); } } ~common_speculative_state_draft() override { - llama_perf_context_print(ctx_dft); - - llama_free(ctx_dft); - - common_sampler_free(smpl); - llama_batch_free(batch); } - void begin(const llama_tokens & /*prompt*/) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - size_t create_checkpoint(int n_tokens_prompt) { - int slot_id = 0; - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + bool process(const llama_batch & batch) override { + auto * ctx_dft = params.ctx_dft; - ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); - ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); - ckpt.n_tokens = n_tokens_prompt; - ckpt.data.resize(checkpoint_size); + const int ret = llama_decode(ctx_dft, batch); - const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + if (ret != 0) { + LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret); + + return false; } - LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, - ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); - return n; + return true; } - size_t restore_checkpoint() { - int slot_id = 
0; - LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); - const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size()); - } - llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); + void draft(common_speculative_draft_params_vec & dparams) override { + auto & ctx_dft = params.ctx_dft; - return n; - } + common_batch_clear(batch); - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.draft; + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); - auto * spec = this; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; + if (!dp.drafting) { + continue; + } - auto * mem_dft = llama_get_memory(ctx_dft); + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); - int reuse_i = 0; // index of part to be reused in prompt_dft - int reuse_n = 0; // length of part to be reused in prompt_dft - - const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); } - const llama_tokens & prompt_cur = spec->vocab_cmpt ? 
prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - if (use_ckpt && i_start > 0) { - LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } + int i = 0; - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } + while (n_drafting > 0) { + int i_batch = 0; - if (use_ckpt) { - break; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", - __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); - if (use_ckpt && ckpt.n_tokens > reuse_n) { - LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens); - - reuse_i = 0; - reuse_n = 0; - - ckpt = {}; - } - - result.clear(); - result.reserve(sparams.n_max); - - if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { - llama_memory_clear(mem_dft, false); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (sparams.n_max <= (int) result.size()) { - break; - } - } - - return; - } - - if (reuse_i > 0) { - GGML_ASSERT(!use_ckpt); - - bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); - return; - } - llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } - - if (reuse_n < (int) prompt_dft.size()) { - if (use_ckpt) { - if (ckpt.n_tokens > 0) { - LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - restore_checkpoint(); - reuse_n = ckpt.n_tokens; - prompt_dft.resize(reuse_n); - } - } else { - const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - return; - } - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - 
LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", - __func__, ret, prompt_cur.size()); - } - - if (use_ckpt) { - create_checkpoint(prompt_dft.size()); - } - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, ret, prompt_cur.size(), prompt_dft.size()); - } - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < sparams.n_max; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl, ctx_dft, 0, true); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; + } - const auto * cur_p = common_sampler_get_candidates(smpl, true); + auto * smpl = smpls[seq_id].get(); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + common_sampler_sample(smpl, ctx_dft, i_batch, true); + ++i_batch; + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; + } + + common_sampler_accept(smpl, id, true); + + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; + + result.push_back(id); + + if ((params.n_max <= (int) result.size()) || + (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + drafting[seq_id] = false; + n_drafting--; + continue; + } + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); } - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < sparams.p_min) { + if (batch.n_tokens == 0) { break; } - result.push_back(id); - - if (sparams.n_max <= (int) result.size()) { - break; - } - - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; } - prompt_dft.push_back(id); + ++i; } - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, 
detokenized, false, true); - if (result.size() > (size_t) sparams.n_max) { - result.resize(sparams.n_max); + for (auto & dp : dparams) { + if (!dp.drafting) { + continue; } - } - if (result.size() < (size_t) sparams.n_min) { - result.clear(); + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } } } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; } }; -struct common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} +struct common_speculative_state_eagle3 : public common_speculative_impl { + //common_params_speculative_eagle3 params; - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } + common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {} - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); - } - - void accept(uint16_t n_accepted) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop - GGML_UNUSED(n_accepted); } - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; + void draft(common_speculative_draft_params_vec & /*dparams*/) override { + // TODO: implement + } + + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; // state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { +struct common_speculative_state_ngram_simple : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // shared across all sequences common_ngram_simple_config config; common_speculative_state_ngram_simple( - enum common_speculative_type type, + const common_params_speculative & params, uint32_t n_seq, common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} + : 
common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) + , params(params.ngram_simple) + , config(config) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop - GGML_UNUSED(n_accepted); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + *dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last); + } + } + + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map config; +struct common_speculative_state_ngram_map_k : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // n_seq configs + std::vector config; common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map config) - : common_speculative_state(type), config(std::move(config)) {} - - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(config, prompt); - } - - void draft( const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - common_ngram_map_draft(config, prompt_tgt, id_last, result); - GGML_UNUSED(params); + const common_ngram_map & config, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) + , params(params.ngram_map_k) { + for (uint32_t i = 0; i < n_seq; i++) { + this->config.push_back(config); + } } - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(config, n_accepted); + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + GGML_ASSERT(seq_id < (llama_seq_id) n_seq); + + common_ngram_map_begin(config[seq_id], prompt); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_value; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_value; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result); + } + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT((seq_id < (llama_seq_id) config.size())); + + common_ngram_map_accept(config[seq_id], n_accepted); } }; 
-struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; +struct common_speculative_state_ngram_mod : public common_speculative_impl { + common_params_speculative_ngram_mod params; - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; + // shared across all sequences + common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + struct seq_info { + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + }; + + std::vector sinfos; + + common_speculative_state_ngram_mod( + const common_params_speculative & params, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) + , params(params.ngram_mod) + , mod(params.ngram_mod.n_match, 4*1024*1024) + , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + + LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, + this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); + + if (this->params.n_match < 16) { + LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " + "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match); + } + + sinfos.resize(n_seq); } - void begin(const llama_tokens & prompt) override { - i_last = 0; + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + auto & sinfo = sinfos[seq_id]; - n_draft_last = 0; + sinfo.i_last = 0; + sinfo.n_draft_last = 0; const size_t n = mod.get_n(); - if (prompt.size() < n) { return; } @@ -706,7 +510,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.add(prompt.data() + i); } - i_last = prompt.size() - n; + sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); @@ -719,16 +523,17 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.ngram_mod; + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - n_draft_last = 0; + const auto & prompt = *dparams.prompt; - const size_t cur_len = prompt_tgt.size(); + sinfo.n_draft_last = 0; + + const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } @@ -736,24 +541,24 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { const size_t n = mod.get_n(); // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); + if 
(sinfo.i_last + 32 < cur_len) { + for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { + mod.add(prompt.data() + i); } - i_last = cur_len - n; + sinfo.i_last = cur_len - n; } - result.resize(n + sparams.n_max); + result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; + result[i] = prompt.at(cur_len - n + 1 + i); } - result[n - 1] = id_last; + result[n - 1] = dparams.id_last; - for (int i = 0; i < sparams.n_max; ++i) { + for (int i = 0; i < params.n_max; ++i) { const llama_token token = mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { - if (i < sparams.n_min) { + if (i < params.n_min) { result.clear(); return; } @@ -771,65 +576,92 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); + sinfo.n_draft_last = result.size(); } - void accept(uint16_t n_accepted) override { - // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { - if (verbose) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); - } + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; + } - mod.reset(); - n_low = 0; - i_last = 0; - } - } else { - n_low = 0; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; } + + draft_one(seq_id, dp); } } - int32_t n_max(const common_params_speculative & params) const override { - return params.ngram_mod.n_max; - } + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + auto & sinfo = sinfos[seq_id]; - int32_t n_min(const common_params_speculative & params) const override { - return params.ngram_mod.n_min; + // compute acceptance fraction if we have a recorded draft length + if (sinfo.n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; + if (f_acc < 0.5) { + sinfo.n_low++; + if (sinfo.n_low >= 3) { + if (verbose) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); + } + + mod.reset(); + sinfo.n_low = 0; + sinfo.i_last = 0; + } + } else { + sinfo.n_low = 0; + } + } } }; -struct common_speculative_state_ngram_cache : public common_speculative_state { +struct common_speculative_state_ngram_cache : public common_speculative_impl { + common_params_speculative_ngram_cache params; + uint16_t n_draft; + bool save_dynamic; bool save_static; - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; + struct seq_info { + size_t cache_size = 0; // number of tokens in n-gram cache - size_t cache_size = 0; // number of tokens in n-gram cache + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + }; + + std::vector sinfos; common_speculative_state_ngram_cache( - const enum common_speculative_type type, + const common_params_speculative & params, + uint32_t n_seq, + uint16_t n_draft, const std::string & path_static, const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : 
common_speculative_state(type) + bool save_dynamic, + bool save_static) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) + , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { + sinfos.resize(n_seq); + if (!path_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(path_static); + auto ngram_cache_static = common_ngram_cache_load(path_static); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_static = ngram_cache_static; + } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); @@ -838,7 +670,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { if (!path_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_dynamic = ngram_cache_dynamic; + } } catch (...) { LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); @@ -846,44 +682,48 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - if (cache_size < prompt_tgt.size() + 1) { + const auto & prompt = *dparams.prompt; + + if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); + tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); + for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { + tokens_new.push_back(prompt[j]); } - tokens_new.push_back(id_last); // add the last token + tokens_new.push_back(dparams.id_last); // add the last token - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + // Update context ngram cache with new dparams.prompt: + common_ngram_cache_update( + sinfo.ngram_cache_context, + LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; + sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); + inp.reserve(prompt.size() + 1); + for (size_t j = 0; j < prompt.size(); ++j) { + inp.push_back(prompt[j]); } - inp.push_back(id_last); + inp.push_back(dparams.id_last); - result.push_back(id_last); + result.push_back(dparams.id_last); - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); + common_ngram_cache_draft( + inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + sinfo.ngram_cache_context, + sinfo.ngram_cache_dynamic, + sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) @@ 
-891,24 +731,37 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return n_draft; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + draft_one(seq_id, dp); + } } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return 0; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; struct common_speculative { - std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states + common_speculative_draft_params_vec dparams; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + // list of implementations to use and their states + std::vector<std::unique_ptr<common_speculative_impl>> impls; + + // which implementation was used for a given seq_id + std::vector<common_speculative_impl *> impl_last; }; static common_ngram_map get_common_ngram_map( @@ -923,45 +776,79 @@ static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { + const common_speculative_config & config, + uint32_t n_seq, + const std::string & path_static, + const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic?
bool save_static = false; bool save_dynamic = false; - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } -std::string common_speculative_type_name_str() { +std::string common_speculative_type_name_str(const std::vector & types) { std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { + + for (size_t i = 0; i < types.size(); i++) { if (i > 0) { - result += ", "; + result += ","; } - result += common_speculative_type_to_str(common_speculative_types[i]); + result += common_speculative_type_to_str(types[i]); } return result; } -std::string common_speculative_type_to_str(enum common_speculative_type type) { +const char * common_speculative_all_types_str() { + static std::string all_types_str = []() { + std::vector types; + types.reserve(COMMON_SPECULATIVE_TYPE_COUNT); + for (int i = 0; i < COMMON_SPECULATIVE_TYPE_COUNT; i++) { + types.push_back((common_speculative_type) i); + } + return common_speculative_type_name_str(types); + }(); + return all_types_str.c_str(); +} + +std::string common_speculative_type_to_str(common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram-mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram-cache"; default: return "unknown"; } } -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { +std::vector common_speculative_types_from_names(const std::vector & names) { + std::vector types; + types.reserve(names.size()); + + for (const auto & name : names) { + auto type = common_speculative_type_from_name_map.find(name); + if (type != common_speculative_type_from_name_map.end()) { + if (type->second == COMMON_SPECULATIVE_TYPE_NONE) { + return std::vector { COMMON_SPECULATIVE_TYPE_NONE }; + } + types.push_back(type->second); + continue; + } + throw std::invalid_argument("unknown speculative type: " + name); + } + + return types; +} + +common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; @@ -969,34 +856,39 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } +static uint32_t common_get_enabled_speculative_configs(const std::vector & configs) { + uint32_t result = 0; + for (size_t i = 0; i < configs.size(); i++) { + result |= (1u << configs[i]); + } + return result; +} + // initialization of the speculative decoding system // -common_speculative * common_speculative_init( - common_params_speculative & params, - 
llama_context * ctx_tgt) { - llama_context * ctx_dft = nullptr; - if (params.draft.model) { - ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } - } - +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector<common_speculative_config> configs = {}; // list of speculative configs to try { - bool has_draft = !params.draft.mparams.path.empty(); + uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types); + + bool has_draft = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT)); + bool has_draft_model = !params.draft.mparams.path.empty(); + + // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 - bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); - bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); - bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); - bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); + bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); + bool has_ngram_map_k = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K)); + bool has_ngram_map_k4v = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V)); + bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); - // In a more complex implementation we could use the same implementation but with different parameters. - // This was initially used in PR-18471 but removed to simplify the code. + // when adding a new type - update the logic above accordingly + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + + // this list defines the priority of the speculators + // the ones with the highest priority are listed first if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model.
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); @@ -1009,53 +901,43 @@ common_speculative * common_speculative_init( configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { - auto & sparams = params.ngram_mod; - - if (!sparams.obj) { - sparams.obj = std::make_shared(sparams.n_match, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, - sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024); - - if (sparams.n_match < 16) { - LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " - "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); - } - } - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_draft) { + if (!has_draft_model) { + LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); + has_draft = false; + } + } else if (has_draft_model) { + LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); + has_draft = true; + } + if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } + // TODO: add MTP here if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } } - std::vector> impls = {}; + std::vector> impls = {}; for (const common_speculative_config & config : configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { - const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.draft.replacements, - /* .use_ckpt = */ use_ckpt - )); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { @@ -1069,27 +951,30 @@ common_speculative * common_speculative_init( /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple + /* .params = */ config.params, + /* .n_seq = */ n_seq, + /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config.type, config.params.ngram_map_k) - )); + impls.push_back( + std::make_unique( + config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod.obj); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod.obj)); + impls.push_back( + std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = 
create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config); + auto state = create_state_ngram_cache( + config, n_seq, + params.ngram_cache.lookup_cache_static, + params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } @@ -1104,8 +989,9 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { + /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), - /* .curr_impl = */ nullptr, + /* .impl_last = */ std::vector(n_seq, nullptr) }; return result; @@ -1119,65 +1005,128 @@ void common_speculative_free(common_speculative * spec) { delete spec; } -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { +common_speculative_draft_params & common_speculative_get_draft_params( + common_speculative * spec, + llama_seq_id seq_id) { + GGML_ASSERT(spec); + GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); + + return spec->dparams[seq_id]; +} + +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); + impl->begin(seq_id, prompt); impl->n_call_begin++; } } -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last) { - llama_tokens result; +bool common_speculative_process(common_speculative * spec, const llama_batch & batch) { + bool result = true; - spec->curr_impl = nullptr; // reset current implementation + if (spec == nullptr) { + return result; + } for (auto & impl : spec->impls) { - { - common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(params, prompt_tgt, id_last, result); - impl->n_call_draft++; - } - - { - const int n_min = impl->n_min(params); - - if (!result.empty() && (int) result.size() < n_min) { - LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min); - result.clear(); - } - } - - if (!result.empty()) { - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); // set current implementation for stats - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // we have a draft, so break out of the loop and return it. 
- } + result = result && impl->process(batch); } return result; } -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { +void common_speculative_draft(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + auto & dparams = spec->dparams; + + { + int n_drafting = 0; + + for (auto & dp : dparams) { + GGML_ASSERT(!dp.drafting || dp.result->empty()); + + if (dp.drafting) { + n_drafting++; + } + } + + if (n_drafting == 0) { + return; + } + } + + for (auto & impl : spec->impls) { + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(dparams); + impl->n_call_draft++; + } + + int n_drafting = 0; + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + auto & result = *dp.result; + + // a new draft has been sampled + if (dp.drafting && !result.empty()) { + dp.drafting = false; + + if (dp.n_max > 0) { + if (!result.empty() && (int) result.size() > dp.n_max) { + LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max); + result.resize(dp.n_max); + } + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(), + impl.get()->n_call_draft, result.size()); + + // remember which implementation was used + spec->impl_last[seq_id] = impl.get(); + + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + } + } + + if (dp.drafting) { + n_drafting++; + } + } + + if (n_drafting == 0) { + break; + } + } + + // these sequences failed to generate a draft + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + if (dp.drafting) { + dp.drafting = false; + } + } +} + +void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } - common_speculative_state * impl = spec->curr_impl; + common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); @@ -1188,37 +1137,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->n_acc_tokens += n_accepted; } - impl->accept(n_accepted); + impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_max = 0; - for (const auto & impl : spec->impls) { - n_max = std::max(n_max, impl->n_max(params)); - } - - return n_max; -} - -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_min = 0; - for (const auto & impl : spec->impls) { - n_min = std::max(n_min, impl->n_min(params)); - } - - return n_min; -} - void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 147447631..51f0b059f 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -5,8 +5,14 @@ struct common_speculative; +// comma separated list the provided types +std::string common_speculative_type_name_str(const std::vector & types); + // comma separated list of all types -std::string common_speculative_type_name_str(); +const char * common_speculative_all_types_str(); + +// parse user provided types +std::vector common_speculative_types_from_names(const std::vector & names); // 
convert string to type enum
 enum common_speculative_type common_speculative_type_from_name(const std::string & name);
@@ -14,27 +20,44 @@
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
-common_speculative * common_speculative_init(
-    common_params_speculative & params,
-    llama_context * ctx_tgt);
+common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);
 
 void common_speculative_free(common_speculative * spec);
 
+struct common_speculative_draft_params {
+    // this flag is used to chain the drafts through all the available implementations
+    // after the first successful draft from an implementation, we set it
+    // to false to prevent further drafts for that sequence
+    // at the end of the draft() call, all drafting flags will be reset to false
+    bool drafting = false;
+
+    // overrides individual configurations (-1 = disabled)
+    // can be used to constrain the max draft based on the remaining context size
+    int32_t n_max = -1;
+
+    llama_pos   n_past;
+    llama_token id_last;
+
+    // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
+    const llama_tokens * prompt;
+
+    // the generated draft from the last _draft() call
+    llama_tokens * result;
+};
+
+common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);
+
 // optionally call once at the beginning of a new generation
-void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
+void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);
 
-// sample up to n_draft tokens and add them to the batch using the draft model
-llama_tokens common_speculative_draft(
-    common_speculative * spec,
-    const common_params_speculative & params,
-    const llama_tokens & prompt,
-    llama_token id_last);
+// process the batch and update the internal state of the speculative context
+bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
 
-// informs the speculative decoder that n_accepted tokens were accepted by the target model
-void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
+// generate drafts for the sequences specified with `common_speculative_get_draft_params`
+void common_speculative_draft(common_speculative * spec);
 
-int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
-int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
+// informs the speculative context that n_accepted tokens were accepted by the target model
+void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted);
 
 // print statistics about the speculative decoding
 void common_speculative_print_stats(const common_speculative * spec);
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 5b61b62a1..5325bcc9e 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -13,20 +13,6 @@
 #include
 #include
 
-struct spec_checkpoint {
-    int64_t n_tokens = 0;
-
-    std::vector data;
-
-    size_t size() const {
-        return data.size();
-    }
-
-    bool empty() const {
-        return data.empty();
-    }
-};
-
 int main(int argc, char **
argv) { std::setlocale(LC_NUMERIC, "C"); @@ -43,11 +29,6 @@ int main(int argc, char ** argv) { return 1; } - if (params.speculative.draft.mparams.path.empty()) { - LOG_ERR("%s: --model-draft is required\n", __func__); - return 1; - } - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -62,18 +43,11 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt->model(); ctx_tgt = llama_init_tgt->context(); - // check if the context supports partial sequence removal - const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); - const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); - - if (use_ckpt) { - LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); - } - const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model llama_model_ptr model_dft; + llama_context_ptr ctx_dft; // TODO: simplify this logic { @@ -81,9 +55,6 @@ int main(int argc, char ** argv) { auto params_dft = params; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -103,8 +74,19 @@ int main(int argc, char ** argv) { return 1; } - params.speculative.draft.model = model_dft.get(); - params.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + params.speculative.draft.ctx_tgt = ctx_tgt; + params.speculative.draft.ctx_dft = ctx_dft.get(); + } + + // check if the context supports partial sequence removal + const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + + if (use_ckpt_tgt) { + LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); } // Tokenize the prompt @@ -136,6 +118,8 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + llama_seq_id seq_id = 0; + // ================================================ // everything until here is standard initialization // the relevant stuff for speculative decoding starts here @@ -146,7 +130,8 @@ int main(int argc, char ** argv) { common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling)); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1)); // note: keep the last token separate! 
llama_token id_last = inp.back(); @@ -160,16 +145,16 @@ int main(int argc, char ** argv) { // init the speculator const auto & params_spec = params.speculative; - struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + struct common_speculative * spec = common_speculative_init(params.speculative, 1); - common_speculative_begin(spec, prompt_tgt); + common_speculative_begin(spec, seq_id, prompt_tgt); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); size_t n_draft = 0; llama_tokens draft; - spec_checkpoint spec_ckpt; + common_prompt_checkpoint ckpt; const auto t_enc_end = ggml_time_us(); @@ -184,40 +169,57 @@ int main(int argc, char ** argv) { // from a cache or lookup tables. // if (draft.empty()) { + ckpt.update_pos( + prompt_tgt.size(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id)); + + if (use_ckpt_dft) { + ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + // generate a new draft - draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + common_speculative_get_draft_params(spec, seq_id) = { + /* .drafting = */ true, + /* .n_max = */ -1, + /* .n_past = */ n_past, + /* .id_last = */ id_last, + /* .prompt = */ &prompt_tgt, + /* .result = */ &draft, // output + }; + common_speculative_draft(spec); // save the original draft size n_draft = draft.size(); // save a checkpoint of the target context before evaluating the draft // this allows us to restore the state if partial draft acceptance occurs - if (!draft.empty() && use_ckpt) { - const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - spec_ckpt.data.resize(ckpt_size); + if (!draft.empty()) { + if (use_ckpt_tgt) { + ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + } - const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == ckpt_size); + { + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); - LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n", - spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); } } else { // we have a previous (partial) draft to reuse from checkpoint restoration - if (use_ckpt) { - GGML_ASSERT(!spec_ckpt.empty()); + if (use_ckpt_tgt) { + GGML_ASSERT(!ckpt.empty()); } } // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); @@ -225,9 +227,15 @@ int main(int argc, char ** argv) { llama_decode(ctx_tgt, batch_tgt); } + // evaluate the same batch with the draft model + { + // TODO: extend to support MTP, Eagle, etc. 
See server code for reference + llama_decode(ctx_dft.get(), batch_tgt); + } + // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(smpl.get())); } @@ -247,17 +255,24 @@ int main(int argc, char ** argv) { // check for partial draft acceptance: // if the context doesn't support partial sequence removal, restore the checkpoint // and make the accepted tokens the new partial draft for the next iteration - if (use_ckpt && ids.size() - 1 < draft.size()) { + if (use_ckpt_tgt && ids.size() - 1 < draft.size()) { LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); draft = std::move(ids); - const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == spec_ckpt.size()); + { + ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1); + } - prompt_tgt.resize(spec_ckpt.n_tokens); + { + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); + } + + prompt_tgt.resize(ckpt.n_tokens); smpl = std::move(smpl_save); n_past = (int) prompt_tgt.size(); @@ -265,7 +280,7 @@ int main(int argc, char ** argv) { continue; } - common_speculative_accept(spec, ids.size() - 1); + common_speculative_accept(spec, seq_id, ids.size() - 1); // full acceptance: consume the draft and commit accepted tokens n_past += ids.size() - 1; @@ -305,7 +320,8 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1); } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/include/llama.h b/include/llama.h index 2ea226726..308e8ba9d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -858,6 +858,8 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 71a59395e..3d9714ab1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2475,11 +2475,29 @@ public: } if (need_alloc) { - mbuf_cur = std::move(mbuf); + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); - mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); - LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = 
std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } } for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { @@ -2559,8 +2577,7 @@ public: mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); - auto & view = mbuf.org.back(); - view->buffer = rinfo.tensor->buffer; + ggml_backend_view_init(mbuf.org.back()); } for (auto & [buft, mbuf] : mbufs_new) { diff --git a/tools/cli/README.md b/tools/cli/README.md index bca4da7ef..02c564a29 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -195,11 +195,9 @@ | `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
 | `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
 | `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
-| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
 | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
 | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
 | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
-| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
 | `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
 | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
 | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
diff --git a/tools/server/README.md b/tools/server/README.md
index 024760da6..77eddb335 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
 | `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
 | `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
 | `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
-| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
 | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
 | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
 | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
-| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
 | `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>
(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 637f8d216..0a51390af 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -36,32 +36,6 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) { - if (pos_min == -1) { - pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); - } - if (pos_max == -1) { - pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id); - } - - auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY; - if (on_device) { - flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE; - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags); - - ckpt.pos_min = pos_min; - ckpt.pos_max = pos_max; - ckpt.n_tokens = n_tokens; - ckpt.data.resize(checkpoint_size); - - const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); - } -} - // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -80,18 +54,19 @@ enum server_state { struct server_slot { int id; - llama_context * ctx = nullptr; - - common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; // multimodal mtmd_context * mctx = nullptr; // speculative decoding + common_speculative * spec; + llama_tokens spec_draft; + llama_tokens spec_prompt; std::vector spec_i_batch; - server_prompt_checkpoint spec_ckpt; - common_speculative_ptr spec; + common_prompt_checkpoint spec_ckpt; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -135,21 +110,27 @@ struct server_slot { void prompt_save(server_prompt_cache & prompt_cache) const { GGML_ASSERT(prompt.data.size() == 0); - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + const size_t cur_size_dft = ctx_dft ? 
llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0; - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + const size_t cur_size = cur_size_tgt + cur_size_dft; - auto * cur = prompt_cache.alloc(prompt, cur_size); + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft); if (cur == nullptr) { return; } - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + if (ctx_dft) { + llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE); + } } bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); + bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id); if (!res) { SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); } @@ -164,7 +145,11 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); + } + prompt.tokens.clear(); } @@ -222,7 +207,7 @@ struct server_slot { task_prev = std::move(task); task.reset(); - llama_set_sampler(ctx, id, nullptr); + llama_set_sampler(ctx_tgt, id, nullptr); // clear alora start alora_invocation_start = -1; @@ -259,7 +244,7 @@ struct server_slot { return !task->need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST); } bool can_batch_with(server_slot & other_slot) const { @@ -310,14 +295,10 @@ struct server_slot { return 0; } - const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative); - // determine the max draft that fits the current slot state - int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative); - // note: slot.prompt is not yet expanded with the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + int n_draft_max = n_ctx - prompt.n_tokens() - 2; if (n_remaining > 0) { n_draft_max = std::min(n_draft_max, n_remaining - 1); @@ -325,61 +306,10 @@ struct server_slot { SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < n_draft_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min); - n_draft_max = 0; - } - return n_draft_max; } void update_batch(llama_batch & batch) { - const int n_draft_max = get_n_draft_max(); - if (n_draft_max > 0) { - GGML_ASSERT(can_speculate()); - - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - const llama_tokens & tokens = prompt.tokens.get_text_tokens(); - - const auto & params_spec = task->params.speculative; - - if 
(!spec_draft.empty()) { - // we have a previous (partial) draft to reuse - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - GGML_ASSERT(!spec_ckpt.empty()); - } - } else { - GGML_ASSERT(spec_i_batch.empty()); - - // generate a new draft - spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled); - n_draft_total += spec_draft.size(); - - if (spec_draft.size() > (size_t) n_draft_max) { - SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); - spec_draft.resize(n_draft_max); - } - - if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - const auto n_tokens = prompt.tokens.size(); - - //const int64_t t_start = ggml_time_us(); - - server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true); - - //const int64_t t_total = ggml_time_us() - t_start; - //printf("checkpoint total: %f ms\n", t_total / 1000.0); - - SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", - spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); - } - } - - GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max); - } - if (spec_draft.empty()) { // no speculative decoding i_batch = batch.n_tokens; @@ -511,7 +441,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec.get()); + common_speculative_print_stats(spec); } json to_json(bool only_metrics = false) const { @@ -539,7 +469,7 @@ struct server_slot { }; if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true); res["generated"] = generated_text.empty() ? debug_generated_text : generated_text; } } @@ -550,8 +480,13 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); + + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); + } other.n_decoded = n_decoded; other.n_remaining = n_remaining; @@ -642,7 +577,8 @@ public: // only use these pointers outside of this class: // - when not in sleeping state // - and, with thread-safe APIs (e.g., tokenizer calls) - llama_model * model = nullptr; + llama_model * model_tgt = nullptr; + mtmd_context * mctx = nullptr; const llama_vocab * vocab = nullptr; @@ -669,11 +605,17 @@ private: // note: keep these alive - they determine the lifetime of the model, context, etc. 
common_init_result_ptr llama_init; - llama_context * ctx = nullptr; + llama_context * ctx_tgt = nullptr; llama_batch batch {}; llama_model_ptr model_dft; + llama_context_ptr ctx_dft; + + common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + + common_speculative_ptr spec; bool add_bos_token = true; @@ -708,18 +650,12 @@ private: void destroy() { llama_init.reset(); - ctx = nullptr; - model = nullptr; + ctx_tgt = nullptr; + model_tgt = nullptr; mtmd_free(mctx); mctx = nullptr; - for (server_slot & slot : slots) { - if (slot.can_speculate()) { - slot.spec.reset(); - } - } - llama_batch_free(batch); } @@ -759,17 +695,17 @@ private: llama_init = common_init_from_params(params_base); - model = llama_init->model(); - ctx = llama_init->context(); + model_tgt = llama_init->model(); + ctx_tgt = llama_init->context(); - if (model == nullptr) { + if (model_tgt == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - vocab = llama_model_get_vocab(model); + vocab = llama_model_get_vocab(model_tgt); - n_ctx = llama_n_ctx(ctx); + n_ctx = llama_n_ctx(ctx_tgt); add_bos_token = llama_vocab_get_add_bos(vocab); @@ -781,9 +717,6 @@ private: auto params_dft = params_base; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -805,8 +738,13 @@ private: return false; } - params_base.speculative.draft.model = model_dft.get(); - params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); } std::string & mmproj_path = params_base.mmproj.path; @@ -826,7 +764,7 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.media_marker = get_media_marker(); - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); return false; @@ -844,7 +782,7 @@ private: } } - if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -856,14 +794,14 @@ private: } } - if (llama_model_n_swa(model) == 0) { + if (llama_model_n_swa(model_tgt) == 0) { if (params_base.swa_full) { params_base.swa_full = false; SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled"); } } - n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model); + n_swa = params_base.swa_full ? 
0 : llama_model_n_swa(model_tgt); // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; @@ -871,9 +809,9 @@ private: // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model_tgt); - int n_ctx_slot = llama_n_ctx_seq(ctx); + int n_ctx_slot = llama_n_ctx_seq(ctx_tgt); if (n_ctx_slot > n_ctx_train) { SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); n_ctx_slot = n_ctx_train; @@ -881,12 +819,12 @@ private: slots.clear(); - const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx); - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt); + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { SRV_WRN("%s", "speculative decoding will use checkpoints\n"); } @@ -895,27 +833,33 @@ private: slots.emplace_back(); } + // try speculative decoding + if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + try { + spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel)); + } catch (const std::exception & e) { + SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + } + } + + if (spec) { + SRV_INF("%s", "speculative decoding context initialized\n"); + } else { + ctx_dft.reset(); + } + for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - - slot.ctx_seq_rm_type = ctx_seq_rm_type; + slot.id = i; + slot.ctx_tgt = ctx_tgt; + slot.ctx_dft = ctx_dft.get(); + slot.spec = spec.get(); + slot.n_ctx = n_ctx_slot; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - // try speculative decoding - if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); - - if (slot.spec) { - SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } - } - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int id_slot) { @@ -946,7 +890,7 @@ private: // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) { - const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_batch = llama_n_batch(ctx_tgt); batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -990,8 +934,9 @@ private: // unlike load_model(), this is only called once during initialization bool init() { - GGML_ASSERT(ctx != nullptr); - GGML_ASSERT(model != nullptr); + GGML_ASSERT(ctx_tgt != nullptr); + GGML_ASSERT(model_tgt != nullptr); + GGML_ASSERT(!sleeping); // wiring up server queues @@ -1037,7 +982,7 @@ private: common_chat_templates_ptr chat_templates; try { - chat_templates = common_chat_templates_init(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template); LOG_INF("%s: chat template, example_format: '%s'\n", __func__, common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); @@ -1300,7 +1245,7 @@ private: } } - if (!task.tokens.validate(ctx)) { + if (!task.tokens.validate(ctx_tgt)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -1310,7 +1255,7 @@ private: // initialize samplers if (task.need_sampling()) { try { - slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling)); } catch (std::exception & e) { std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); @@ -1324,16 +1269,16 @@ private: backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0); + backend_sampling &= !(slot.can_speculate()); // TODO: getting pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_pre_sample_logits; // TODO: tmp until backend sampling is fully implemented if (backend_sampling) { - llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); + llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get())); } else { - llama_set_sampler(ctx, slot.id, nullptr); + llama_set_sampler(ctx_tgt, slot.id, nullptr); } SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); @@ -1512,13 +1457,13 @@ private: result.probs.push_back({ cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), + common_token_to_piece(ctx_tgt, cur_p->data[i].id, special), cur_p->data[i].p }); } } else { // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); @@ -1536,7 +1481,7 @@ private: for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), + common_token_to_piece(ctx_tgt, cur[i].id, special), cur[i].p }); } @@ -1639,7 +1584,7 @@ private: res->tokens = std::move(slot.generated_tokens); } res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); + res->prompt = slot.task->tokens.detokenize(ctx_tgt, true); res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; @@ -1662,7 +1607,7 @@ private: 
// populate res.probs_output if (slot.task->params.sampling.n_probs > 0) { if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); res->probs_output = std::vector( @@ -1687,7 +1632,7 @@ private: res->n_tokens = slot.task->n_tokens(); res->res_type = slot.task->params.res_type; - const int n_embd_out = llama_model_n_embd_out(model); + const int n_embd_out = llama_model_n_embd_out(model_tgt); std::vector embd_res(n_embd_out, 0.0f); @@ -1697,10 +1642,10 @@ private: } const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); + if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(slot.ctx_tgt, i); } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]); } if (embd == nullptr) { @@ -1711,7 +1656,7 @@ private: } // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) { common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; @@ -1736,9 +1681,9 @@ private: continue; } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]); if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); + embd = llama_get_embeddings_ith(ctx_tgt, i); } if (embd == NULL) { @@ -1843,18 +1788,22 @@ private: const auto & cur = slot.prompt.checkpoints.front(); SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } auto & cur = slot.prompt.checkpoints.emplace_back(); - server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max); + + cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); + + cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); } void process_single_task(server_task && task) { @@ -2009,7 +1958,7 @@ private: std::string filepath = task.slot_action.filepath; const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2048,7 +1997,7 @@ private: 
llama_tokens tokens; tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -2207,8 +2156,13 @@ private: SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + } // add generated tokens to cache // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 @@ -2236,12 +2190,10 @@ private: // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + std::vector generating; + std::vector drafting; - // first, add sampled tokens from any ongoing sequences + // determine which slots are generating and drafting for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; @@ -2254,12 +2206,103 @@ private: continue; } + generating.push_back(&slot); + + if (spec) { + common_speculative_get_draft_params(spec.get(), slot.id).drafting = false; + + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + const int n_draft_max = slot.get_n_draft_max(); + + if (n_draft_max > 0) { + GGML_ASSERT(slot.can_speculate()); + + if (!slot.spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (use_ckpt_tgt) { + GGML_ASSERT(!slot.spec_ckpt.empty()); + } + } else { + GGML_ASSERT(slot.spec_i_batch.empty()); + + slot.spec_ckpt.update_pos( + slot.prompt.n_tokens(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + + if (use_ckpt_dft) { + slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + + slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); + + common_speculative_get_draft_params(spec.get(), slot.id) = { + /* .drafting = */ true, + /* .n_max = */ n_draft_max, + /* .n_past = */ slot.prompt.n_tokens(), + /* .id_last = */ slot.sampled, + /* .prompt = */ &slot.spec_prompt, + /* .result = */ &slot.spec_draft, + }; + + drafting.push_back(&slot); + } + } + } + } + + // generate the actual drafts (if any) + { + common_speculative_draft(spec.get()); + } + + // make checkpoints if needed + for 
(auto * slot_ptr : drafting) { + auto & slot = *slot_ptr; + + auto & draft = slot.spec_draft; + auto & ckpt = slot.spec_ckpt; + + slot.n_draft_total += draft.size(); + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + if (ctx_dft) { + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1); + } + + if (!draft.empty()) { + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + if (use_ckpt_tgt) { + //const int64_t t_start = ggml_time_us(); + + ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + //const int64_t t_total = ggml_time_us() - t_start; + //printf("checkpoint total: %f ms\n", t_total / 1000.0); + + SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n", + ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(), + (float) ckpt.size() / 1024 / 1024, + (float) ckpt.data_dft.size() / 1024 / 1024); + } + } + } + + // update the batch with the sampled/drafted tokens + for (auto * slot_ptr : generating) { + auto & slot = *slot_ptr; + slot.update_batch(batch); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); + int32_t n_batch = llama_n_batch(ctx_tgt); + int32_t n_ubatch = llama_n_ubatch(ctx_tgt); float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2303,12 +2346,12 @@ private: /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } } else { // all for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } }*/ @@ -2327,7 +2370,7 @@ private: } // TODO: support memory-less logits computation - if (slot.task->need_logits() && !llama_get_memory(ctx)) { + if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. 
skipping", ERROR_TYPE_SERVER); slot.release(); continue; @@ -2379,7 +2422,7 @@ private: const auto n_cache_reuse = slot.task->params.n_cache_reuse; const bool can_cache_reuse = - llama_memory_can_shift(llama_get_memory(ctx)) && + llama_memory_can_shift(llama_get_memory(ctx_tgt)) && !slot.prompt.tokens.has_mtmd; if (!can_cache_reuse && n_cache_reuse > 0) { @@ -2413,13 +2456,18 @@ private: if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str()); //} const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); + } for (size_t i = 0; i < n_match; i++) { slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); @@ -2446,7 +2494,7 @@ private: const auto pos_min_thold = std::max(0, pos_next - n_swa); if (n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -2475,14 +2523,14 @@ private: { const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss0 << piece; st0 << std::setw(8) << token; } { const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? 
common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss1 << piece; st1 << std::setw(8) << token; } @@ -2514,18 +2562,13 @@ private: if (!do_reset) { // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); - n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024); - } + it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); } if (do_reset) { @@ -2542,7 +2585,7 @@ private: for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_max > pos_next) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -2582,14 +2625,18 @@ private: SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); slot.prompt_clear(true); // there is no common part left slot.n_prompt_tokens_cache = 0; - } + } else { + if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { + GGML_ABORT("failed to truncate draft context\n"); + } + } // If using an alora, there may be uncached tokens that come // before the invocation sequence. 
When this happens, the @@ -2615,7 +2662,7 @@ private: // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) do_checkpoint = do_checkpoint && ( - (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || (n_swa > 0)); bool has_mtmd = false; @@ -2624,7 +2671,7 @@ private: while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); @@ -2632,6 +2679,16 @@ private: continue; } + if (ctx_dft) { + // TODO: in the future, figure out how to infuse target embeddings to the images + // for now, we skip this for simplicity + // maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above? + res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + GGML_ABORT("failed to process multi-modal data on draft context\n"); + } + } + slot.n_prompt_tokens_processed += n_tokens_out; // add the image chunk to cache @@ -2733,8 +2790,8 @@ private: SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); // no need for empty or small checkpoints do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); @@ -2765,9 +2822,14 @@ private: SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + if (slot_batched) { // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); + common_set_adapter_lora(ctx_tgt, slot_batched->lora); // if the lora is temporarily disabled for an alora, re-enable it // for next time @@ -2776,7 +2838,7 @@ private: slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); } if (batch.n_tokens == 0) { @@ -2805,7 +2867,7 @@ private: batch.logits + i, }; - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode(ctx_tgt, batch_view); metrics.on_decoded(slots); @@ -2858,11 +2920,63 @@ private: continue; // continue loop of n_batch } + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: 
https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + // + // | spec type | need re-eval | + // | --- | --- | + // | draft model | no | because the draft model does not use embeddings from the target + // | MTP (std) | yes | + // | MTP Gemma4 | no | because the KV cache is shared + // | Eagle3 | yes | + // | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982 + // + // note: this logic is now moved in `common_speculative_process()` + // keeping the sketch here until for a bit, until the logic is finalized + // + //if (ctx_dft) { + // // TODO: update as needed for MTP, Eagle3, etc. + // const bool need_tgt_embd = false; + + // if (need_tgt_embd) { + // llama_synchronize(ctx_tgt); + // } + + // // the logic here varies depending on the speculative decoding method + // // - some draft contexts require embeddings from the target context, others don't + // // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings + // // TODO: extract this in a function ? + // { + // // TODO: hook the embeddings from the last target batch here + // if (llama_model_has_encoder(model_dft.get())) { + // //llama_encode(ctx_dft, ...); + + // GGML_ABORT("not implemented yet\n"); + // } + + // const int ret = llama_decode(ctx_dft.get(), batch_view); + + // if (ret != 0) { + // SRV_ERR("failed to decode draft batch, ret = %d\n", ret); + + // // TODO: handle error + // break; + // } + // } + //} + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + break; + } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); + n_batch = llama_n_batch(ctx_tgt); // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -2921,7 +3035,7 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2933,7 +3047,7 @@ private: const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); slot.i_batch = -1; @@ -2954,7 +3068,7 @@ private: completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -2985,23 +3099,23 @@ private: // verify and try to accept the draft { - const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = 
common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); slot.spec_i_batch.clear(); GGML_ASSERT(accepted.size() >= 1); // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); } @@ -3011,16 +3125,19 @@ private: const auto & ckpt = slot.spec_ckpt; - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", - ckpt.pos_min, ckpt.pos_max, ckpt.size()); + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n); + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); } - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1); + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); + } slot.prompt.tokens.keep_first(ckpt.n_tokens); slot.smpl = std::move(smpl_save); @@ -3033,7 +3150,7 @@ private: SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); } - common_speculative_accept(slot.spec.get(), accepted.size() - 1); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); slot.spec_draft = std::move(accepted); } @@ -3055,13 +3172,16 @@ private: slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3113,7 +3233,7 @@ void server_context::terminate() { } llama_context * server_context::get_llama_context() const { - return impl->ctx; + return impl->ctx_tgt; } server_response_reader server_context::get_response_reader() { @@ -3123,8 +3243,8 @@ server_response_reader server_context::get_response_reader() { server_context_meta server_context::get_meta() const { auto bos_id = llama_vocab_bos(impl->vocab); auto eos_id = llama_vocab_eos(impl->vocab); - auto bos_token_str = bos_id != 
LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : ""; - auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : ""; + auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : ""; + auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : ""; return server_context_meta { /* build_info */ std::string(llama_build_info()), @@ -3137,7 +3257,7 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* json_webui_settings */ impl->json_webui_settings, /* slot_n_ctx */ impl->get_slot_n_ctx(), - /* pooling_type */ llama_pooling_type(impl->ctx), + /* pooling_type */ llama_pooling_type(impl->ctx_tgt), /* chat_params */ impl->chat_params, /* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()), @@ -3155,10 +3275,10 @@ server_context_meta server_context::get_meta() const { /* model_vocab_type */ llama_vocab_type(impl->vocab), /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), - /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model), - /* model_n_embd_inp */ llama_model_n_embd(impl->model), - /* model_n_params */ llama_model_n_params(impl->model), - /* model_size */ llama_model_size(impl->model), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt), + /* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt), + /* model_n_params */ llama_model_n_params(impl->model_tgt), + /* model_size */ llama_model_size(impl->model_tgt), }; } @@ -4045,7 +4165,7 @@ void server_routes::init_routes() { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.get_new_id(); task.tokens = std::move(tmp); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 45e5168fa..b9b4a704a 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl( params.speculative = defaults.speculative; + // TODO: to keep things simple, we disable speculative parameter adjustments for now +#if 0 // TODO: for now, be able to adjust only the 
draft-model based speculative parameters
     params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
     params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
@@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
     params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
     params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
 
-#if 0
     // for debugging and research purposes
     params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
@@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
     return res;
 }
 
-server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
+server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
     // first check if the current state is contained fully in the cache
     for (auto it = states.begin(); it != states.end(); ++it) {
         const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
         }
     }
 
-    std::vector<uint8_t> state_data;
+    std::vector<uint8_t> state_data_tgt;
+    std::vector<uint8_t> state_data_dft;
 
     // check if we can allocate enough memory for the new state
     try {
-        state_data.resize(state_size);
+        state_data_tgt.resize(state_size_tgt);
+        state_data_dft.resize(state_size_dft);
     } catch (const std::bad_alloc & e) {
         SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
         return nullptr;
     }
 
-    auto & cur = states.emplace_back();
-    cur = {
+    states.push_back({
         /*.tokens =*/ prompt.tokens.clone(),
-        /*.data =*/ std::move(state_data),
+        /*.data =*/ {
+            /*.main =*/ std::move(state_data_tgt),
+            /*.drft =*/ std::move(state_data_dft),
+        },
         /*.checkpoints =*/ prompt.checkpoints,
-    };
+    });
 
-    return &cur;
+    return &states.back();
 }
 
-bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
+bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
     const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
 
     float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
     if (it_best != states.end()) {
         SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
 
-        const size_t size = it_best->data.size();
-        const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
-        if (n != size) {
-            SRV_WRN("failed to restore state with size %zu\n", size);
+        {
+            auto & data = it_best->data.main;
-            return false;
+            const size_t size = data.size();
+            const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
+            if (n != size) {
+                SRV_WRN("failed to restore state with size %zu\n", size);
+
+                return false;
+            }
+
+            data.clear();
+            data.shrink_to_fit();
         }
 
-        it_best->data.clear();
-        it_best->data.shrink_to_fit();
+        {
+            auto & data = it_best->data.drft;
+
+            if (!data.empty()) {
+                GGML_ASSERT(ctx_dft);
+
+                const size_t size = data.size();
+                const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
+                if (n != size) {
+                    SRV_WRN("failed to restore state with size %zu\n", size);
+
+                    return false;
+                }
+
+                data.clear();
+                data.shrink_to_fit();
+            }
+        }
 
         prompt = std::move(*it_best);
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 289e1fb8d..64bdecd79 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
     virtual json to_json() override;
 };
 
-struct server_prompt_checkpoint {
-    llama_pos pos_min;
-    llama_pos pos_max;
-
-    int64_t n_tokens;
-
-    std::vector<uint8_t> data;
+struct server_prompt_data {
+    std::vector<uint8_t> main;
+    std::vector<uint8_t> drft;
 
     size_t size() const {
-        return data.size();
-    }
-
-    bool empty() const {
-        return data.empty();
-    }
-
-    void clear() {
-        pos_min = 0;
-        pos_max = 0;
-        n_tokens = 0;
-        data.clear();
+        return main.size() + drft.size();
     }
 };
 
 struct server_prompt {
     server_tokens tokens;
 
-    std::vector<uint8_t> data;
+    server_prompt_data data;
 
-    std::list<server_prompt_checkpoint> checkpoints;
+    std::list<common_prompt_checkpoint> checkpoints;
 
     size_t size() const {
-        size_t res = data.size();
+        size_t res = 0;
 
-        for (const auto & checkpoint : checkpoints) {
-            res += checkpoint.size();
+        res += data.size();
+
+        for (const auto & ckpt : checkpoints) {
+            res += ckpt.size();
         }
 
         return res;
@@ -614,7 +601,7 @@ struct server_prompt {
         return server_prompt {
             tokens.clone(),
             data,
-            checkpoints
+            checkpoints,
         };
     }
 };
@@ -637,9 +624,9 @@ struct server_prompt_cache {
     size_t n_tokens() const;
 
-    server_prompt * alloc(const server_prompt & prompt, size_t state_size);
+    server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
 
-    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
+    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
 
     void update();
 };
diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py
index eebd3cc8f..84cd77e6f 100644
--- a/tools/server/tests/unit/test_speculative.py
+++ b/tools/server/tests/unit/test_speculative.py
@@ -5,7 +5,7 @@ from utils import *
 
 server = ServerPreset.stories15m_moe()
 
-MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
+MODEL_DRAFT_FILE_URL =
"https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf" def create_server(): global server