mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-10 17:14:37 +00:00
fix llama-cli pos sync
This commit is contained in:
parent
c54a6a0132
commit
421b3deca5
2 changed files with 32 additions and 2 deletions
|
@ -348,6 +348,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// remove any "future" tokens that we might have inherited from the previous session
|
// remove any "future" tokens that we might have inherited from the previous session
|
||||||
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||||
|
if (my_rank == 0) {
|
||||||
|
llama_send_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
|
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
|
||||||
|
@ -593,6 +596,11 @@ int main(int argc, char ** argv) {
|
||||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||||
|
|
||||||
|
if (my_rank == 0) {
|
||||||
|
llama_send_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||||
|
llama_send_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||||
|
}
|
||||||
|
|
||||||
n_past -= n_discard;
|
n_past -= n_discard;
|
||||||
|
|
||||||
LOG_DBG("after swap: n_past = %d\n", n_past);
|
LOG_DBG("after swap: n_past = %d\n", n_past);
|
||||||
|
|
|
@ -17788,6 +17788,8 @@ struct input_tensors {
|
||||||
struct sync_meta {
|
struct sync_meta {
|
||||||
int32_t n_tokens = 0;
|
int32_t n_tokens = 0;
|
||||||
llama_pos * pos = nullptr;
|
llama_pos * pos = nullptr;
|
||||||
|
llama_pos all_pos_0;
|
||||||
|
llama_pos all_pos_1;
|
||||||
uint32_t n_ctx = 0;
|
uint32_t n_ctx = 0;
|
||||||
|
|
||||||
// signal to clear the kv cache
|
// signal to clear the kv cache
|
||||||
|
@ -17835,6 +17837,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||||
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
|
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
|
||||||
|
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
|
||||||
|
|
||||||
|
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
||||||
|
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
||||||
|
|
||||||
zmq::send_multipart(socket, send_msgs);
|
zmq::send_multipart(socket, send_msgs);
|
||||||
} catch (const zmq::error_t& e) {
|
} catch (const zmq::error_t& e) {
|
||||||
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
||||||
|
@ -17907,6 +17915,16 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||||
meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
|
meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
|
||||||
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
|
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (key == "all_pos_0") {
|
||||||
|
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
|
||||||
|
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (key == "all_pos_1") {
|
||||||
|
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
|
||||||
|
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -18185,6 +18203,8 @@ static int llama_decode_internal(
|
||||||
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
|
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
|
||||||
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
|
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
|
||||||
}
|
}
|
||||||
|
batch_all.all_pos_0 = meta.all_pos_0;
|
||||||
|
batch_all.all_pos_1 = meta.all_pos_1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kv_cache_op(meta.clear_kv_cache,
|
if (kv_cache_op(meta.clear_kv_cache,
|
||||||
|
@ -18229,8 +18249,10 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!is_last_dev) {
|
if (!is_last_dev) {
|
||||||
meta.n_tokens = batch_all.n_tokens;
|
meta.n_tokens = batch_all.n_tokens;
|
||||||
meta.pos = batch_all.pos;
|
meta.pos = batch_all.pos;
|
||||||
|
meta.all_pos_0 = batch_all.all_pos_0;
|
||||||
|
meta.all_pos_1 = batch_all.all_pos_1;
|
||||||
llama_send_meta(*lctx.send_socket, &meta);
|
llama_send_meta(*lctx.send_socket, &meta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue