mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-10 16:44:40 +00:00
fix speculative decoding
This commit is contained in:
parent
e50b3aa473
commit
dc875bbef9
4 changed files with 75 additions and 28 deletions
|
@ -17841,6 +17841,9 @@ struct sync_meta {
|
|||
llama_pos cp_p0 = 0;
|
||||
llama_pos cp_p1 = 0;
|
||||
|
||||
bool kv_seq_keep = false;
|
||||
llama_seq_id keep_seq_id = 0;
|
||||
|
||||
// signal to divide the kv cache range
|
||||
bool kv_seq_div = false;
|
||||
llama_seq_id div_seq_id = 0;
|
||||
|
@ -17943,8 +17946,14 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
if (cmd == "kv_seq_keep" && recv_msgs.size() == 2) {
|
||||
meta->kv_seq_keep = true;
|
||||
std::memcpy(&meta->keep_seq_id, recv_msgs[idx++].data(), sizeof(meta->keep_seq_id));
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (cmd == "kv_seq_div" && recv_msgs.size() == 5) {
|
||||
meta->kv_seq_div = true;
|
||||
meta->kv_seq_div = true;
|
||||
std::memcpy(&meta->div_seq_id, recv_msgs[idx++].data(), sizeof(meta->div_seq_id));
|
||||
std::memcpy(&meta->div_p0, recv_msgs[idx++].data(), sizeof(meta->div_p0));
|
||||
std::memcpy(&meta->div_p1, recv_msgs[idx++].data(), sizeof(meta->div_p1));
|
||||
|
@ -18331,6 +18340,14 @@ static int llama_decode_internal(
|
|||
return -1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.kv_seq_keep,
|
||||
[&]{ llama_kv_cache_seq_keep (&lctx, meta.keep_seq_id); },
|
||||
[&]{ llama_send_kv_cache_seq_keep(&lctx, meta.keep_seq_id); },
|
||||
is_last_dev)) {
|
||||
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_keep\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.kv_seq_div,
|
||||
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
|
||||
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
|
||||
|
@ -22349,6 +22366,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
|
|||
llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
void llama_send_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||
if (ctx->send_socket == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
std::vector<zmq::message_t> msgs;
|
||||
msgs.emplace_back("kv_seq_keep", strlen("kv_seq_keep"));
|
||||
msgs.emplace_back(&seq_id, sizeof(seq_id));
|
||||
zmq::send_multipart(*ctx->send_socket, msgs);
|
||||
} catch (const zmq::error_t & e) {
|
||||
LLAMA_LOG_WARN("Failed to send kv_seq_keep: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
if (delta == 0) {
|
||||
return;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue