spec : parallel drafting support (#22838)

* spec : refactor

* spec : drop support for incompatible vocabs

* spec : update common_speculative_init()

* cont : pass seq_id

* cont : dedup ctx_seq_rm_type

* server : sketch the ctx_dft decode loop

* server : draft prompt cache and checkpoints

* server : improve ctx names

* server, spec : transition to unified spec context

* cont : sync main and drft contexts

* cont : async drft eval when possible

* cont : handle non-ckpt models

* cont : pass correct n_past for drafting

* cont : process images throught the draft context

* spec : handle draft running out of context

* server : fix mtmd draft processing

* server : fix URL for draft model

* server : add comment

* server : clean-up + dry

* speculative-simple : update

* spec : fix n_past type

* server : fix slot ctx_drft ptr

* tools : update readme

* naming : improve consistency

* spec : refactor for multi-sequence speculative context

* cont : prepare params

* cont : prepare params

* spec : support parallel drafts

* server : support parallel drafting

* llama : reuse device buffers when possible

* server, spec : clean-up

* cont : clean-up

* cont : minor

* spec : reset `drafting` flag at the end

* spec : introduce `common_speculative_process()`

* spec : allow for multiple spec types (chain of speculators)

* replace old type field of type common_speculative_type in the
  common_params_speculative struct with a vector to allow multiple
  types to be specified

* introduce common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>)
  to figure out which implementations the user has enabled

* introduce common_speculative_type_from_names(const std::vector<std::string> & names)
  to parse the already user provided spec types

* all speculators run sequentially, best one wins (we verify its drafted tokens)

* maximize expected accepted tokens for current round by calculating the
  product between the probability of accepting current token (n_acc_tokens / n_gen_drafts)
  and the draft's length

---------

Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
This commit is contained in:
Georgi Gerganov 2026-05-11 19:09:43 +03:00 committed by GitHub
parent 928b486b0c
commit 68e7ea3eab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1286 additions and 1071 deletions

View file

@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.types", common_speculative_type_name_str(speculative.types)},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.types", common_speculative_type_name_str(speculative.types)},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl(
params.speculative = defaults.speculative;
// TODO: to keep things simple, we disable speculative parameter adjustments for now
#if 0
// TODO: for now, be able to adjust only the draft-model based speculative parameters
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
#if 0
// for debugging and research purposes
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
return res;
}
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
// first check if the current state is contained fully in the cache
for (auto it = states.begin(); it != states.end(); ++it) {
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}
}
std::vector<uint8_t> state_data;
std::vector<uint8_t> state_data_tgt;
std::vector<uint8_t> state_data_dft;
// check if we can allocate enough memory for the new state
try {
state_data.resize(state_size);
state_data_tgt.resize(state_size_tgt);
state_data_dft.resize(state_size_dft);
} catch (const std::bad_alloc & e) {
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
auto & cur = states.emplace_back();
cur = {
states.push_back({
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.data =*/ {
/*.main =*/ std::move(state_data_tgt),
/*.drft =*/ std::move(state_data_dft),
},
/*.checkpoints =*/ prompt.checkpoints,
};
});
return &cur;
return &states.back();
}
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
if (it_best != states.end()) {
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
{
auto & data = it_best->data.main;
return false;
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
it_best->data.clear();
it_best->data.shrink_to_fit();
{
auto & data = it_best->data.drft;
if (!data.empty()) {
GGML_ASSERT(ctx_dft);
const size_t size = data.size();
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
if (n != size) {
SRV_WRN("failed to restore state with size %zu\n", size);
return false;
}
data.clear();
data.shrink_to_fit();
}
}
prompt = std::move(*it_best);