fix speculative decoding

commit dc875bbef9 (parent e50b3aa473)
Author: Li, Zonghang
Date:   2025-06-13 08:18:12 +04:00

4 changed files with 75 additions and 28 deletions

View file

@@ -2,6 +2,8 @@
BUILD_TARGETS = \
llama-server \
llama-cli \
llama-speculative \
llama-gguf-split \
profile-tool
# BUILD_TARGETS = \

View file

@@ -12,7 +12,7 @@
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
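The two SPEC_VOCAB_MAX_SIZE_DIFFERENCE lines above are the old and new values: the commit raises the maximum allowed gap between the draft and target vocabulary sizes from 100 to 128.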
@@ -65,23 +65,29 @@ int main(int argc, char ** argv) {
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the draft model
// make a hard copy of params to use for the draft model
gpt_params params_draft = params;
params_draft.model = params_draft.model_draft;
params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
params_draft.n_world = 1; // do not split the draft model across devices
params_draft.rank = 0; // always load the draft model on the head device
std::fill_n(params_draft.n_layer_window, params.n_world, 0);
if (params_draft.draft_cpuparams.n_threads > 0) {
params_draft.cpuparams.n_threads = params_draft.draft_cpuparams.n_threads;
}
params_draft.cpuparams_batch.n_threads = params_draft.draft_cpuparams_batch.n_threads;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_draft);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
// load the target model
llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
model_tgt = llama_init_tgt.model;
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.draft_cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
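With the +/- markers stripped, this hunk reads as duplicated code. The added block is the one at the top: it makes a hard copy of params (params_draft), points it at the draft model, pins it to the head device (n_world = 1, rank = 0, zeroed layer window) so the draft model is never split across devices, and initializes the draft model before the target model. The block at the bottom is the removed code, which mutated the shared params in place after the target model had already been loaded.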
@@ -161,9 +167,6 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
@@ -180,8 +183,6 @@ int main(int argc, char ** argv) {
// target model sampling context (reuse the llama_context's sampling instance)
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
struct llama_sampler * softmax = llama_sampler_init_softmax();
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
@@ -258,10 +259,13 @@ int main(int argc, char ** argv) {
float r = u_dist(rng);
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// if (dist_tgt.size > dist_dft.size) {
// LOG_ERR("dist_tgt.size (%zu) must be less than or equal to dist_dft.size (%zu)\n", dist_tgt.size, dist_dft.size);
// GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// }
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
for (size_t i = 0; i < dist_tgt.size && i < dist_dft.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
p_tgt = dist_tgt.data[i].p;
}
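For context, this loop collects p_dft and p_tgt, the probabilities the draft and target models assign to the drafted token, and the fix bounds the scan by both distribution sizes instead of assuming dist_dft is at least as large as dist_tgt. The probabilities feed the standard speculative-sampling acceptance rule: accept the drafted token when r < p_tgt / p_dft, where r is the uniform sample drawn above. A minimal, self-contained sketch of that rule (toy values, not code from this repository):

    #include <algorithm>
    #include <cstdio>
    #include <random>

    int main() {
        // toy probabilities standing in for the dist_tgt / dist_dft lookups
        const float p_tgt = 0.30f;   // target model's probability for the drafted token
        const float p_dft = 0.45f;   // draft model's probability for the same token

        std::mt19937 rng(42);
        std::uniform_real_distribution<float> u_dist(0.0f, 1.0f);
        const float r = u_dist(rng);

        // standard speculative-sampling acceptance test:
        // accept with probability min(1, p_tgt / p_dft)
        if (r < std::min(1.0f, p_tgt / p_dft)) {
            std::printf("accept drafted token (r = %.3f)\n", r);
        } else {
            // on rejection, the target resamples from max(0, p_tgt - p_dft), renormalized
            std::printf("reject drafted token (r = %.3f)\n", r);
        }
        return 0;
    }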
@@ -406,7 +410,6 @@ int main(int argc, char ** argv) {
{
LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
// TODO: simplify
{
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
@@ -418,6 +421,12 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
// notify other devices to manage the KV cache in the same way
llama_send_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_send_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_send_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_send_kv_cache_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
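The llama_send_kv_cache_* calls are the new lines in this hunk. Devices with rank > 0 hold their own slices of the target model's KV cache, so the head device now broadcasts every rm/keep/cp it performs after a rejected or exhausted draft; without the notifications the remote caches would drift out of sync.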
@@ -435,7 +444,6 @@ int main(int argc, char ** argv) {
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
++n_past_dft;
@@ -575,12 +583,13 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_kv_cache_seq_keep (ctx_tgt, 0);
llama_send_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_kv_cache_seq_cp (ctx_tgt, 0, s, -1, -1);
llama_send_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
++n_past_tgt;
}
@@ -612,7 +621,7 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("draft:\n\n");
// TODO: print sampling/grammar timings for all drafts
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
@@ -624,7 +633,6 @@ int main(int argc, char ** argv) {
gpt_sampler_free(drafts[s].smpl);
}
llama_sampler_free(softmax);
llama_batch_free(batch_dft);
llama_free(ctx_tgt);

View file

@@ -759,6 +759,11 @@ extern "C" {
LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
// Notify other nodes to keep only the specified sequence in their KV cache
LLAMA_API void llama_send_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
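llama_send_kv_cache_seq_keep is the new public entry point here, mirroring llama_kv_cache_seq_keep. The convention in the speculative example is to pair the two on the head device, e.g. llama_kv_cache_seq_keep(ctx_tgt, 0) immediately followed by llama_send_kv_cache_seq_keep(ctx_tgt, 0), so the local and remote caches apply the same operation.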

View file

@@ -17841,6 +17841,9 @@ struct sync_meta {
llama_pos cp_p0 = 0;
llama_pos cp_p1 = 0;
bool kv_seq_keep = false;
llama_seq_id keep_seq_id = 0;
// signal to divide the kv cache range
bool kv_seq_div = false;
llama_seq_id div_seq_id = 0;
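sync_meta gains a kv_seq_keep flag and a keep_seq_id alongside the existing kv_seq_div fields, so the keep signal travels over the same metadata channel that already carries the other KV-cache commands between devices.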
@@ -17943,8 +17946,14 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
return 0;
}
if (cmd == "kv_seq_keep" && recv_msgs.size() == 2) {
meta->kv_seq_keep = true;
std::memcpy(&meta->keep_seq_id, recv_msgs[idx++].data(), sizeof(meta->keep_seq_id));
return 0;
}
if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
meta->kv_seq_div = true;
std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
std::memcpy(&meta->div_p0, recv_msgs[idx++].data(), sizeof(meta->div_p0));
std::memcpy(&meta->div_p1, recv_msgs[idx++].data(), sizeof(meta->div_p1));
@@ -18331,6 +18340,14 @@ static int llama_decode_internal(
return -1;
}
if (kv_cache_op(meta.kv_seq_keep,
[&]{ llama_kv_cache_seq_keep (&lctx, meta.keep_seq_id); },
[&]{ llama_send_kv_cache_seq_keep(&lctx, meta.keep_seq_id); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_keep\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_div,
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
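The kv_cache_op helper itself is not shown in this diff. Judging only from the call sites above, a plausible shape is: do nothing if the signal flag is unset, otherwise apply the local operation, forward the same signal to the next device unless this is the last one, and report that a signal was handled so llama_decode_internal can return early. A self-contained sketch under that assumption (the real definition may differ):

    #include <cstdio>
    #include <functional>

    // Hypothetical stand-in for the kv_cache_op helper used above: apply a
    // KV-cache operation locally and, unless this device is the last in the
    // pipeline, forward the same signal downstream. Returns true when a
    // signal was present and handled.
    static bool kv_cache_op(bool signal,
                            const std::function<void()> & local_op,
                            const std::function<void()> & send_op,
                            bool is_last_dev) {
        if (!signal) {
            return false;
        }
        local_op();
        if (!is_last_dev) {
            send_op();
        }
        return true;
    }

    int main() {
        const bool kv_seq_keep = true;   // pretend a kv_seq_keep signal arrived
        const bool is_last_dev = false;  // and that another device follows this one

        if (kv_cache_op(kv_seq_keep,
                []{ std::printf("keep sequence in the local KV cache\n"); },
                []{ std::printf("forward kv_seq_keep downstream\n"); },
                is_last_dev)) {
            std::printf("signal handled, the decode call returns early\n");
        }
        return 0;
    }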
@@ -22349,6 +22366,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
}
void llama_send_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep"));
msgs.emplace_back(&seq_id, sizeof(seq_id));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_keep: %s\n", e.what());
}
}
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (delta == 0) {
return;
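The kv_seq_keep notification sent by llama_send_kv_cache_seq_keep is a two-frame ZMQ multipart message, a command string plus the raw seq_id bytes, which is exactly what the recv_msgs.size() == 2 check in llama_recv_meta expects. A minimal, self-contained sketch of that framing with cppzmq, using an in-process socket pair rather than the repository's actual endpoints:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <iterator>
    #include <vector>
    #include <zmq.hpp>
    #include <zmq_addon.hpp>

    int main() {
        zmq::context_t zctx(1);
        zmq::socket_t tx(zctx, zmq::socket_type::push);
        zmq::socket_t rx(zctx, zmq::socket_type::pull);
        rx.bind("inproc://kv-sync");
        tx.connect("inproc://kv-sync");

        // send side: command frame + raw payload frame, the same framing as above
        int32_t seq_id = 3;
        std::vector<zmq::message_t> msgs;
        msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep"));
        msgs.emplace_back(&seq_id, sizeof(seq_id));
        zmq::send_multipart(tx, msgs);

        // receive side: read all frames, dispatch on the command, copy out the payload
        std::vector<zmq::message_t> recv_msgs;
        zmq::recv_multipart(rx, std::back_inserter(recv_msgs));

        if (recv_msgs.size() == 2 && recv_msgs[0].to_string() == "kv_seq_keep") {
            int32_t keep_seq_id = 0;
            std::memcpy(&keep_seq_id, recv_msgs[1].data(), sizeof(keep_seq_id));
            std::printf("kv_seq_keep for seq_id %d\n", (int) keep_seq_id);
        }
        return 0;
    }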