diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 39d4b60c..ccff70f2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -348,6 +348,9 @@ int main(int argc, char ** argv) {
 
             // remove any "future" tokens that we might have inherited from the previous session
             llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+            if (my_rank == 0) {
+                // rank 0 drives the pipeline: forward the same cache trim to the downstream ranks
+                llama_send_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+            }
         }
 
         LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -593,6 +596,11 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
             llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
+            if (my_rank == 0) {
+                // mirror the context-swap cache ops on the other ranks so their KV caches stay in sync
+                llama_send_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                llama_send_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+            }
+
             n_past -= n_discard;
 
             LOG_DBG("after swap: n_past = %d\n", n_past);
diff --git a/src/llama.cpp b/src/llama.cpp
index 1f68ced0..cd5a95b1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17788,6 +17788,8 @@ struct input_tensors {
 struct sync_meta {
     int32_t n_tokens = 0;
     llama_pos * pos = nullptr;
+    llama_pos all_pos_0 = 0; // mirrors llama_batch.all_pos_0; zero-init like the other members so an un-set value is never sent as garbage
+    llama_pos all_pos_1 = 0; // mirrors llama_batch.all_pos_1
     uint32_t n_ctx = 0;
 
     // signal to clear the kv cache
@@ -17835,6 +17837,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
         }
 
+        send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
+        send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
+
+        send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
+        send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
+
         zmq::send_multipart(socket, send_msgs);
     } catch (const zmq::error_t& e) {
         LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
@@ -17907,6 +17915,16 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
             meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
             std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
         }
+
+        if (key == "all_pos_0") {
+            GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
+            std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
+        }
+
+        if (key == "all_pos_1") {
+            GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
+            std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
+        }
     }
     return 0;
 }
@@ -18185,6 +18203,8 @@ static int llama_decode_internal(
             batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
             std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
         }
+        batch_all.all_pos_0 = meta.all_pos_0;
+        batch_all.all_pos_1 = meta.all_pos_1;
     }
 
     if (kv_cache_op(meta.clear_kv_cache,
@@ -18229,8 +18249,10 @@ static int llama_decode_internal(
         }
 
         if (!is_last_dev) {
-            meta.n_tokens = batch_all.n_tokens;
-            meta.pos = batch_all.pos;
+            meta.n_tokens  = batch_all.n_tokens;
+            meta.pos       = batch_all.pos;
+            meta.all_pos_0 = batch_all.all_pos_0;
+            meta.all_pos_1 = batch_all.all_pos_1;
             llama_send_meta(*lctx.send_socket, &meta);
         }
 