fix speculative decoding

This commit is contained in:
Li, Zonghang 2025-06-13 08:18:12 +04:00
parent e50b3aa473
commit dc875bbef9
4 changed files with 75 additions and 28 deletions

View file

@ -17841,6 +17841,9 @@ struct sync_meta {
llama_pos cp_p0 = 0;
llama_pos cp_p1 = 0;
bool kv_seq_keep = false;
llama_seq_id keep_seq_id = 0;
// signal to divide the kv cache range
bool kv_seq_div = false;
llama_seq_id div_seq_id = 0;
@ -17943,8 +17946,14 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
return 0;
}
if (cmd == "kv_seq_keep" && recv_msgs.size() == 2) {
meta->kv_seq_keep = true;
std::memcpy(&meta->keep_seq_id, recv_msgs[idx++].data(), sizeof(meta->keep_seq_id));
return 0;
}
if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
meta->kv_seq_div = true;
meta->kv_seq_div = true;
std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
std::memcpy(&meta->div_p0, recv_msgs[idx++].data(), sizeof(meta->div_p0));
std::memcpy(&meta->div_p1, recv_msgs[idx++].data(), sizeof(meta->div_p1));
@ -18331,6 +18340,14 @@ static int llama_decode_internal(
return -1;
}
if (kv_cache_op(meta.kv_seq_keep,
[&]{ llama_kv_cache_seq_keep (&lctx, meta.keep_seq_id); },
[&]{ llama_send_kv_cache_seq_keep(&lctx, meta.keep_seq_id); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_keep\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_div,
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
@ -22349,6 +22366,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
}
void llama_send_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
if (ctx->send_socket == nullptr) {
return;
}
try {
std::vector<zmq::message_t> msgs;
msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep"));
msgs.emplace_back(&seq_id, sizeof(seq_id));
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t & e) {
LLAMA_LOG_WARN("Failed to send kv_seq_keep: %s\n", e.what());
}
}
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (delta == 0) {
return;