fix speculative decoding

This commit is contained in:
Li, Zonghang 2025-06-13 08:18:12 +04:00
parent e50b3aa473
commit dc875bbef9
4 changed files with 75 additions and 28 deletions

View file

@ -12,7 +12,7 @@
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
@ -65,23 +65,29 @@ int main(int argc, char ** argv) {
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the draft model
// make a hard copy of params to use for the draft model
gpt_params params_draft = params;
params_draft.model = params_draft.model_draft;
params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
params_draft.n_world = 1; // do not split the draft model across devices
params_draft.rank = 0; // always load the draft model on the head device
std::fill_n(params_draft.n_layer_window, params.n_world, 0);
if (params_draft.draft_cpuparams.n_threads > 0) {
params_draft.cpuparams.n_threads = params_draft.draft_cpuparams.n_threads;
}
params_draft.cpuparams_batch.n_threads = params_draft.draft_cpuparams_batch.n_threads;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_draft);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
// load the target model
llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
model_tgt = llama_init_tgt.model;
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.draft_cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
@ -161,9 +167,6 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
@ -180,8 +183,6 @@ int main(int argc, char ** argv) {
// target model sampling context (reuse the llama_context's sampling instance)
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
struct llama_sampler * softmax = llama_sampler_init_softmax();
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
@ -258,10 +259,13 @@ int main(int argc, char ** argv) {
float r = u_dist(rng);
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// if (dist_tgt.size > dist_dft.size) {
// LOG_ERR("dist_tgt.size (%zu) must be less than or equal to dist_dft.size (%zu)\n", dist_tgt.size, dist_dft.size);
// GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// }
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
for (size_t i = 0; i < dist_tgt.size && i < dist_dft.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
p_tgt = dist_tgt.data[i].p;
}
@ -406,7 +410,6 @@ int main(int argc, char ** argv) {
{
LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
// TODO: simplify
{
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
@ -418,6 +421,12 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
// notify other devices to manage the KV cache in the same way
llama_send_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_send_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_send_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_send_kv_cache_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
@ -435,7 +444,6 @@ int main(int argc, char ** argv) {
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
++n_past_dft;
@ -575,12 +583,13 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_cache_seq_keep (ctx_tgt, 0);
llama_send_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_kv_cache_seq_cp (ctx_tgt, 0, s, -1, -1);
llama_send_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}
@ -612,7 +621,7 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("draft:\n\n");
// TODO: print sampling/grammar timings for all drafts
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
@ -624,7 +633,6 @@ int main(int argc, char ** argv) {
gpt_sampler_free(drafts[s].smpl);
}
llama_sampler_free(softmax);
llama_batch_free(batch_dft);
llama_free(ctx_tgt);