server : implement universal assisted decoding (#12635)

* llama-server : implement universal assisted decoding

* Erase prompt tail for kv-cache

* set vocab_dft_compatible in common_speculative

* rename ctx_main to ctx_tgt

* move vocab_dft_compatible to spec struct

* clear mem_dft, remove mem

* detokenize id_last for incompatible models

* update comment

* add --spec-replace flag

* accept special tokens when translating between draft/main models

* Escape spec-replace

* clamp draft result to size to params.n_draft

* fix comment

* clean up code

* restore old example

* log common_speculative_are_compatible in speculative example

* fix

* Update common/speculative.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/speculative.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/speculative.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
g2mt 2025-07-31 05:25:23 -07:00 committed by GitHub
parent c1dacaa99b
commit 94933c8c2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 168 additions and 62 deletions

View file

@ -977,6 +977,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
for (auto & pair : params.speculative.replacements) {
string_process_escapes(pair.first);
string_process_escapes(pair.second);
}
}
if (!params.kv_overrides.empty()) {
@ -3249,6 +3253,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
{"--spec-replace"}, "TARGET", "DRAFT",
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
[](common_params & params, const std::string & tgt, const std::string & dft) {
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(

View file

@ -201,6 +201,7 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V

View file

@ -1,30 +1,39 @@
#include "speculative.h"
#include "ggml.h"
#include "llama.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include <cstring>
#include <algorithm>
#include <map>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct common_speculative {
struct llama_context * ctx;
struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
struct llama_context * ctx_dft;
struct common_sampler * smpl;
llama_batch batch;
llama_tokens prompt;
llama_tokens prompt_dft;
bool vocab_dft_compatible = true; // whether retokenization is needed
std::map<std::string, std::string> tgt_dft_replacements = {};
};
struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft) {
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
/* .prompt_dft = */ {},
/* .vocab_dft_compatible = */ false,
};
// TODO: optimize or pass from outside?
@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
}
#endif
result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
return result;
}
@ -90,31 +102,32 @@ bool common_speculative_are_compatible(
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
if (
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
) {
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
return false;
}
{
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
@ -122,8 +135,8 @@ bool common_speculative_are_compatible(
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", __func__, i,
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(ctx_tgt, i).c_str(),
common_token_to_piece(ctx_dft, i).c_str());
return false;
@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
return true;
}
void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest) {
spec->tgt_dft_replacements[source] = dest;
}
static std::string replace_to_dft(
struct common_speculative * spec,
const std::string& input) {
std::string result = input;
for (const auto & pair : spec->tgt_dft_replacements) {
size_t pos = result.find(pair.first);
while (pos != std::string::npos) {
result.replace(pos, pair.first.length(), pair.second);
pos = result.find(pair.first, pos + pair.second.length());
}
}
return result;
}
static std::string replace_to_tgt(
struct common_speculative * spec,
const std::string& input) {
std::string result = input;
for (const auto& pair : spec->tgt_dft_replacements) {
size_t pos = result.find(pair.second);
while (pos != std::string::npos) {
result.replace(pos, pair.second.length(), pair.first);
pos = result.find(pair.second, pos + pair.first.length());
}
}
return result;
}
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;
auto & prompt_dft = spec->prompt_dft;
auto * mem = llama_get_memory(ctx);
auto * mem_dft = llama_get_memory(ctx_dft);
int reuse_i = 0;
int reuse_n = 0;
const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
llama_tokens prompt_tgt_draft_model;
if (!spec->vocab_dft_compatible) {
std::string text;
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
text = replace_to_dft(spec, text);
LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
// convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
const auto * model_tgt = llama_get_model(ctx_tgt);
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
text.resize(-n_chars);
llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
text = replace_to_dft(spec, text);
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
id_last = common_tokenize(ctx_dft, text, false, true)[0];
}
// prompt_tgt's tokens will always be compatible with ctx_dft
const llama_tokens &prompt_tgt =
spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_tgt.size() &&
i + cur < (int) prompt.size() &&
prompt_tgt[i_start + cur] == prompt[i + cur]) {
i + cur < (int) prompt_dft.size() &&
prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
cur++;
}
@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
llama_tokens result;
result.reserve(params.n_draft);
if (reuse_n == 0) {
llama_memory_clear(mem, false);
prompt.clear();
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
result.push_back(prompt[i]);
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);
if (params.n_draft <= (int) result.size()) {
break;
@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
}
if (reuse_i > 0) {
llama_memory_seq_rm (mem, 0, 0, reuse_i);
llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_memory_seq_rm (mem, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
if (reuse_n < (int) prompt_dft.size()) {
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}
@ -214,28 +286,28 @@ llama_tokens common_speculative_gen_draft(
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
prompt.push_back(prompt_tgt[i]);
prompt_dft.push_back(prompt_tgt[i]);
}
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx, batch);
llama_decode(ctx_dft, batch);
}
const llama_pos n_past = prompt.size();
const llama_pos n_past = prompt_dft.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
prompt.push_back(id_last);
prompt_dft.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
llama_decode(ctx, batch);
llama_decode(ctx_dft, batch);
common_sampler_reset(smpl);
@ -243,13 +315,13 @@ llama_tokens common_speculative_gen_draft(
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
common_sampler_sample(smpl, ctx, 0, true);
common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
llama_decode(ctx_dft, batch);
prompt.push_back(id);
prompt_dft.push_back(id);
}
if (!spec->vocab_dft_compatible) {
std::string detokenized = common_detokenize(ctx_dft, result, true);
detokenized = replace_to_tgt(spec, detokenized);
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
result = common_tokenize(ctx_tgt, detokenized, false, true);
if (result.size() > (size_t)params.n_draft) {
result.resize(params.n_draft);
}
}
return result;
}

View file

@ -12,7 +12,10 @@ struct common_speculative_params {
float p_min = 0.75f; // min probability required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
);
void common_speculative_free(struct common_speculative * spec);
@ -20,6 +23,10 @@ bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,

View file

@ -65,7 +65,7 @@ int main(int argc, char ** argv) {
ctx_dft = llama_init_dft.context.get();
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
return 1;
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
}
// Tokenize the prompt
@ -130,7 +130,10 @@ int main(int argc, char ** argv) {
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
params_spec.p_min = p_min;
struct common_speculative * spec = common_speculative_init(ctx_dft);
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
for (auto &pair : params.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
}
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

View file

@ -1929,6 +1929,7 @@ struct server_context {
mtmd_context * mctx = nullptr;
const llama_vocab * vocab = nullptr;
bool vocab_dft_compatible = true;
llama_model * model_dft = nullptr;
@ -2019,10 +2020,9 @@ struct server_context {
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
return false;
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
if (!vocab_dft_compatible) {
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
}
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
@ -2112,11 +2112,14 @@ struct server_context {
return;
}
slot.spec = common_speculative_init(slot.ctx_dft);
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
}
for (auto &pair : params_base.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);