mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-18 06:19:19 +00:00
spec : parallel drafting support (#22838)
* spec : refactor * spec : drop support for incompatible vocabs * spec : update common_speculative_init() * cont : pass seq_id * cont : dedup ctx_seq_rm_type * server : sketch the ctx_dft decode loop * server : draft prompt cache and checkpoints * server : improve ctx names * server, spec : transition to unified spec context * cont : sync main and drft contexts * cont : async drft eval when possible * cont : handle non-ckpt models * cont : pass correct n_past for drafting * cont : process images throught the draft context * spec : handle draft running out of context * server : fix mtmd draft processing * server : fix URL for draft model * server : add comment * server : clean-up + dry * speculative-simple : update * spec : fix n_past type * server : fix slot ctx_drft ptr * tools : update readme * naming : improve consistency * spec : refactor for multi-sequence speculative context * cont : prepare params * cont : prepare params * spec : support parallel drafts * server : support parallel drafting * llama : reuse device buffers when possible * server, spec : clean-up * cont : clean-up * cont : minor * spec : reset `drafting` flag at the end * spec : introduce `common_speculative_process()` * spec : allow for multiple spec types (chain of speculators) * replace old type field of type common_speculative_type in the common_params_speculative struct with a vector to allow multiple types to be specified * introduce common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>) to figure out which implementations the user has enabled * introduce common_speculative_type_from_names(const std::vector<std::string> & names) to parse the already user provided spec types * all speculators run sequentially, best one wins (we verify its drafted tokens) * maximize expected accepted tokens for current round by calculating the product between the probability of accepting current token (n_acc_tokens / n_gen_drafts) and the draft's length --------- Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
This commit is contained in:
parent
928b486b0c
commit
68e7ea3eab
14 changed files with 1286 additions and 1071 deletions
|
|
@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.types", common_speculative_type_name_str(speculative.types)},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
|
|
@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const {
|
|||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.types", common_speculative_type_name_str(speculative.types)},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
|
|
@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl(
|
|||
|
||||
params.speculative = defaults.speculative;
|
||||
|
||||
// TODO: to keep things simple, we disable speculative parameter adjustments for now
|
||||
#if 0
|
||||
// TODO: for now, be able to adjust only the draft-model based speculative parameters
|
||||
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
|
||||
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
|
||||
|
|
@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
|
|||
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
|
||||
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
|
||||
|
||||
#if 0
|
||||
// for debugging and research purposes
|
||||
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
||||
|
||||
|
|
@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
|
|||
return res;
|
||||
}
|
||||
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
|
||||
// first check if the current state is contained fully in the cache
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
|
|
@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data;
|
||||
std::vector<uint8_t> state_data_tgt;
|
||||
std::vector<uint8_t> state_data_dft;
|
||||
|
||||
// check if we can allocate enough memory for the new state
|
||||
try {
|
||||
state_data.resize(state_size);
|
||||
state_data_tgt.resize(state_size_tgt);
|
||||
state_data_dft.resize(state_size_dft);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
|
||||
|
||||
|
|
@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto & cur = states.emplace_back();
|
||||
cur = {
|
||||
states.push_back({
|
||||
/*.tokens =*/ prompt.tokens.clone(),
|
||||
/*.data =*/ std::move(state_data),
|
||||
/*.data =*/ {
|
||||
/*.main =*/ std::move(state_data_tgt),
|
||||
/*.drft =*/ std::move(state_data_dft),
|
||||
},
|
||||
/*.checkpoints =*/ prompt.checkpoints,
|
||||
};
|
||||
});
|
||||
|
||||
return &cur;
|
||||
return &states.back();
|
||||
}
|
||||
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
|
||||
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
||||
|
||||
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
|
||||
|
|
@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
|
|||
if (it_best != states.end()) {
|
||||
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
{
|
||||
auto & data = it_best->data.main;
|
||||
|
||||
return false;
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
|
||||
it_best->data.clear();
|
||||
it_best->data.shrink_to_fit();
|
||||
{
|
||||
auto & data = it_best->data.drft;
|
||||
|
||||
if (!data.empty()) {
|
||||
GGML_ASSERT(ctx_dft);
|
||||
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
prompt = std::move(*it_best);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue