From dc875bbef9d832e30cc910f44f89e85bd0aeb84a Mon Sep 17 00:00:00 2001
From: "Li, Zonghang" <870644199@qq.com>
Date: Fri, 13 Jun 2025 08:18:12 +0400
Subject: [PATCH] fix speculative decoding

---
 Makefile                             |  2 +
 examples/speculative/speculative.cpp | 62 ++++++++++++++++------------
 include/llama.h                      |  5 +++
 src/llama.cpp                        | 34 +++++++++++++++-
 4 files changed, 75 insertions(+), 28 deletions(-)

diff --git a/Makefile b/Makefile
index 06d91984..8d9f7410 100644
--- a/Makefile
+++ b/Makefile
@@ -2,6 +2,8 @@
 BUILD_TARGETS = \
 	llama-server \
 	llama-cli \
+	llama-speculative \
+	llama-gguf-split \
 	profile-tool
 
 # BUILD_TARGETS = \
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index adf6255e..3716579e 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -12,7 +12,7 @@
 #include <string>
 #include <vector>
 
-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct seq_draft {
@@ -65,23 +65,29 @@ int main(int argc, char ** argv) {
     llama_context * ctx_tgt = NULL;
     llama_context * ctx_dft = NULL;
 
+    // load the draft model
+    // make a hard copy of params to use for the draft model
+    gpt_params params_draft   = params;
+    params_draft.model        = params_draft.model_draft;
+    params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
+    params_draft.n_world      = 1; // do not split the draft model across devices
+    params_draft.rank         = 0; // always load the draft model on the head device
+    std::fill_n(params_draft.n_layer_window, params.n_world, 0);
+
+    if (params_draft.draft_cpuparams.n_threads > 0) {
+        params_draft.cpuparams.n_threads = params_draft.draft_cpuparams.n_threads;
+    }
+
+    params_draft.cpuparams_batch.n_threads = params_draft.draft_cpuparams_batch.n_threads;
+    llama_init_result llama_init_dft = llama_init_from_gpt_params(params_draft);
+    model_dft = llama_init_dft.model;
+    ctx_dft   = llama_init_dft.context;
+
     // load the target model
     llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
     model_tgt = llama_init_tgt.model;
     ctx_tgt = llama_init_tgt.context;
 
-    // load the draft model
-    params.model = params.model_draft;
-    params.n_gpu_layers = params.n_gpu_layers_draft;
-    if (params.draft_cpuparams.n_threads > 0) {
-        params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
-    }
-
-    params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
-    llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
-    model_dft = llama_init_dft.model;
-    ctx_dft = llama_init_dft.context;
-
     const bool vocab_type_tgt = llama_vocab_type(model_tgt);
     LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
 
@@ -161,9 +167,6 @@ int main(int argc, char ** argv) {
 
     const auto t_enc_end = ggml_time_us();
 
-    // the 2 models should have the same vocab
-    //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
-
     // how many tokens to draft each time
     int n_draft = params.n_draft;
 
@@ -180,8 +183,6 @@ int main(int argc, char ** argv) {
     // target model sampling context (reuse the llama_context's sampling instance)
     struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
 
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
-
     // draft sequence data
    std::vector<seq_draft> drafts(n_seq_dft);
 
@@ -258,10 +259,13 @@ int main(int argc, char ** argv) {
                     float r = u_dist(rng);
                     llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
 
-                    //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+                    // if (dist_tgt.size > dist_dft.size) {
+                    //     LOG_ERR("dist_tgt.size (%zu) must be less than or equal to dist_dft.size (%zu)\n", dist_tgt.size, dist_dft.size);
+                    //     GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+                    // }
 
                     // acquire the token probabilities assigned by the draft and target models
-                    for (size_t i = 0; i < dist_tgt.size; i++) {
+                    for (size_t i = 0; i < dist_tgt.size && i < dist_dft.size; i++) {
                         if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
                             p_tgt = dist_tgt.data[i].p;
                         }
@@ -406,7 +410,6 @@ int main(int argc, char ** argv) {
         {
             LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
 
-            // TODO: simplify
             {
                 LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
@@ -418,6 +421,12 @@ int main(int argc, char ** argv) {
                 llama_kv_cache_seq_keep(ctx_tgt, s_keep);
                 llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
                 llama_kv_cache_seq_keep(ctx_tgt, 0);
+
+                // notify other devices to manage the KV cache in the same way
+                llama_send_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
+                llama_send_kv_cache_seq_keep(ctx_tgt, s_keep);
+                llama_send_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
+                llama_send_kv_cache_seq_keep(ctx_tgt, 0);
             }
 
             for (int s = 0; s < n_seq_dft; ++s) {
@@ -435,7 +444,6 @@ int main(int argc, char ** argv) {
         llama_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
 
         llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
-        // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
         llama_decode(ctx_dft, batch_dft);
 
         ++n_past_dft;
@@ -575,12 +583,13 @@ int main(int argc, char ** argv) {
 
         // evaluate the target model on the drafted tokens
         {
-            llama_kv_cache_seq_keep(ctx_tgt, 0);
+            llama_kv_cache_seq_keep     (ctx_tgt, 0);
+            llama_send_kv_cache_seq_keep(ctx_tgt, 0);
             for (int s = 1; s < n_seq_dft; ++s) {
-                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
+                llama_kv_cache_seq_cp     (ctx_tgt, 0, s, -1, -1);
+                llama_send_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
             }
 
-            // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
             llama_decode(ctx_tgt, batch_tgt);
             ++n_past_tgt;
         }
@@ -612,7 +621,7 @@ int main(int argc, char ** argv) {
 
     LOG_INF("\n");
     LOG_INF("draft:\n\n");
-    // TODO: print sampling/grammar timings for all drafts
+
     llama_perf_context_print(ctx_dft);
 
     LOG_INF("\n");
@@ -624,7 +633,6 @@ int main(int argc, char ** argv) {
         gpt_sampler_free(drafts[s].smpl);
     }
 
-    llama_sampler_free(softmax);
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
diff --git a/include/llama.h b/include/llama.h
index 86da593c..4c39b063 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -759,6 +759,11 @@ extern "C" {
     LLAMA_API void llama_kv_cache_seq_keep(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
+
+    // Notify other nodes to keep only the specified sequence in their KV cache
+    LLAMA_API void llama_send_kv_cache_seq_keep(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
diff --git a/src/llama.cpp b/src/llama.cpp
index af42f79d..7cd74983 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17841,6 +17841,9 @@ struct sync_meta {
     llama_pos cp_p0 = 0;
     llama_pos cp_p1 = 0;
 
+    bool kv_seq_keep = false;
+    llama_seq_id keep_seq_id = 0;
+
     // signal to divide the kv cache range
     bool kv_seq_div = false;
     llama_seq_id div_seq_id = 0;
@@ -17943,8 +17946,14 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
         return 0;
     }
 
+    if (cmd == "kv_seq_keep" && recv_msgs.size() == 2) {
+        meta->kv_seq_keep = true;
+        std::memcpy(&meta->keep_seq_id, recv_msgs[idx++].data(), sizeof(meta->keep_seq_id));
+        return 0;
+    }
+
     if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
-        meta->kv_seq_div = true;
+        meta->kv_seq_div = true;
         std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
         std::memcpy(&meta->div_p0,     recv_msgs[idx++].data(), sizeof(meta->div_p0));
         std::memcpy(&meta->div_p1,     recv_msgs[idx++].data(), sizeof(meta->div_p1));
@@ -18331,6 +18340,14 @@ static int llama_decode_internal(
             return -1;
         }
 
+        if (kv_cache_op(meta.kv_seq_keep,
+            [&]{ llama_kv_cache_seq_keep     (&lctx, meta.keep_seq_id); },
+            [&]{ llama_send_kv_cache_seq_keep(&lctx, meta.keep_seq_id); },
+            is_last_dev)) {
+            LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_keep\n", __func__);
+            return -1;
+        }
+
         if (kv_cache_op(meta.kv_seq_div,
             [&]{ llama_kv_cache_seq_div     (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
             [&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
@@ -22349,6 +22366,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
     llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
 }
 
+void llama_send_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
+    if (ctx->send_socket == nullptr) {
+        return;
+    }
+
+    try {
+        std::vector<zmq::message_t> msgs;
+        msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep"));
+        msgs.emplace_back(&seq_id, sizeof(seq_id));
+        zmq::send_multipart(*ctx->send_socket, msgs);
+    } catch (const zmq::error_t & e) {
+        LLAMA_LOG_WARN("Failed to send kv_seq_keep: %s\n", e.what());
+    }
+}
+
 void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
     if (delta == 0) {
         return;
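
Note on the wire format introduced above: the new kv_seq_keep control signal is a two-frame ZeroMQ multipart message, where frame 0 carries the command string and frame 1 carries the raw llama_seq_id bytes that llama_recv_meta copies back out with std::memcpy. The standalone sketch below illustrates only that framing with cppzmq; the PAIR socket, the inproc:// endpoint, and the llama_seq_id_t alias are illustrative assumptions and are not part of this patch.

// kv_seq_keep_sketch.cpp -- illustrative only; llama_seq_id_t and the
// PAIR/inproc endpoint are assumptions for this sketch, not prima.cpp code.
#include <zmq.hpp>
#include <zmq_addon.hpp>

#include <cstdint>
#include <cstring>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

using llama_seq_id_t = int32_t; // stand-in for llama_seq_id

// sender side: same two-frame layout as llama_send_kv_cache_seq_keep()
static void send_kv_seq_keep(zmq::socket_t & sock, llama_seq_id_t seq_id) {
    std::vector<zmq::message_t> msgs;
    msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep")); // frame 0: command string
    msgs.emplace_back(&seq_id, sizeof(seq_id));              // frame 1: raw sequence id
    zmq::send_multipart(sock, msgs);
}

int main() {
    zmq::context_t ctx;
    zmq::socket_t tx(ctx, zmq::socket_type::pair);
    zmq::socket_t rx(ctx, zmq::socket_type::pair);
    rx.bind   ("inproc://kv-signal");
    tx.connect("inproc://kv-signal");

    send_kv_seq_keep(tx, /*seq_id=*/0);

    // receiver side: mirrors the parsing added to llama_recv_meta()
    std::vector<zmq::message_t> recv_msgs;
    const auto n_frames = zmq::recv_multipart(rx, std::back_inserter(recv_msgs));
    if (!n_frames) {
        return 1;
    }

    const std::string cmd = recv_msgs[0].to_string();
    if (cmd == "kv_seq_keep" && recv_msgs.size() == 2) {
        llama_seq_id_t keep_seq_id = 0;
        std::memcpy(&keep_seq_id, recv_msgs[1].data(), sizeof(keep_seq_id));
        std::cout << "keep only sequence " << keep_seq_id << " in the KV cache\n";
    }
    return 0;
}

Because llama_recv_meta dispatches on both the command string and the frame count (recv_msgs.size() == 2), a malformed or truncated signal is ignored rather than corrupting the KV-cache state on the receiving device.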