// Patch summary (src/llama.cpp): carry per-token sequence information in the
// metadata that is exchanged between devices over the zmq sockets.
//
//  * struct sync_meta: adds the pointer members `n_seq_id` and `seq_id`
//    (both default to nullptr) and moves `logits` up next to `pos`.
//  * llama_send_meta(): when `n_seq_id` is non-null, asserts `seq_id` is
//    non-null, appends an "n_seq_id" part holding n_ctx int32_t values, then
//    one "seq_id" part for every index whose count is > 0. The "all_pos_0" /
//    "all_pos_1" parts are moved so they are appended after "logits".
//  * llama_recv_meta(): on "n_seq_id" it mallocs `n_seq_id` (n_ctx entries)
//    and an n_ctx-long `seq_id` table initialised to nullptr, and resets the
//    `seq_idx` cursor. Each "seq_id" part is stored at the next index whose
//    count is > 0. The "logits" handler is moved ahead of the all_pos_*
//    handlers.
//  * llama_decode_internal(): deep-copies `n_seq_id` / `seq_id` from the
//    received meta into batch_all, and hands batch_all's pointers back to
//    meta before calling llama_send_meta().
//
// NOTE(review): the sender converts the count to size_t before testing
//   `n_seq > 0`, so a negative n_seq_id[i] becomes a huge positive size; the
//   receiver tests the signed value with `<= 0`. The two sides would disagree
//   for negative counts -- confirm counts are never negative.
// NOTE(review): the "seq_id" handler reads meta->n_seq_id[...] without
//   asserting that an "n_seq_id" part was received first.
// NOTE(review): none of the malloc() results added here are checked.
// NOTE(review): the copy loop in llama_decode_internal() runs for every index
//   in [0, n_ctx), including entries the receiver left as nullptr (count <= 0),
//   so memcpy() can be called with a null source and size 0 -- confirm this is
//   acceptable or guard on the count.
// NOTE(review): the hunks do not show where the new allocations are freed --
//   verify ownership in the surrounding code.
diff --git a/src/llama.cpp b/src/llama.cpp index 7895b8f6..9b1d2a11 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17805,13 +17805,16 @@ struct input_tensors { }; struct sync_meta { - int32_t n_tokens = 0; - llama_pos * pos = nullptr; + int32_t n_tokens = 0; + llama_pos * pos = nullptr; + int32_t * n_seq_id = nullptr; + llama_seq_id ** seq_id = nullptr; + int8_t * logits = nullptr; + llama_pos all_pos_0; llama_pos all_pos_1; - uint32_t n_ctx = 0; - int8_t * logits = nullptr; - + uint32_t n_ctx = 0; + // signal to clear the kv cache bool clear_kv_cache = false; @@ -17857,11 +17860,19 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) { send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos)); } - send_msgs.emplace_back("all_pos_0", strlen("all_pos_0")); - send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0)); + if (meta->n_seq_id != nullptr) { + GGML_ASSERT(meta->seq_id != nullptr); + send_msgs.emplace_back("n_seq_id", strlen("n_seq_id")); + send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t)); - send_msgs.emplace_back("all_pos_1", strlen("all_pos_1")); - send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1)); + for (size_t i = 0; i < meta->n_ctx; ++i) { + const size_t n_seq = meta->n_seq_id[i]; + if (n_seq > 0) { + send_msgs.emplace_back("seq_id", strlen("seq_id")); + send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id)); + } + } + } if (meta->logits != nullptr) { GGML_ASSERT(meta->n_tokens > 0); @@ -17869,6 +17880,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) { send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t)); } + send_msgs.emplace_back("all_pos_0", strlen("all_pos_0")); + send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0)); + + send_msgs.emplace_back("all_pos_1", strlen("all_pos_1")); + send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1)); + zmq::send_multipart(socket, 
send_msgs); } catch (const zmq::error_t& e) { LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what()); @@ -17928,6 +17945,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) { return 0; } + size_t seq_idx = 0; for (size_t i = 0; i < recv_msgs.size(); i += 2) { std::string key = recv_msgs[i].to_string(); zmq::message_t & data_msg = recv_msgs[i + 1]; @@ -17942,6 +17960,41 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) { std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos)); } + if (key == "n_seq_id") { + GGML_ASSERT(meta->n_ctx > 0); + GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t)); + meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t)); + memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t)); + + meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *)); + for (size_t j = 0; j < meta->n_ctx; ++j) { + meta->seq_id[j] = nullptr; + } + seq_idx = 0; + } + + if (key == "seq_id") { + while (seq_idx < meta->n_ctx && meta->n_seq_id[seq_idx] <= 0) { + ++seq_idx; + } + GGML_ASSERT(seq_idx < meta->n_ctx); + + const size_t n_seq = meta->n_seq_id[seq_idx]; + GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id)); + + meta->seq_id[seq_idx] = (llama_seq_id *) malloc(n_seq * sizeof(llama_seq_id)); + memcpy(meta->seq_id[seq_idx], data_msg.data(), n_seq * sizeof(llama_seq_id)); + + ++seq_idx; + } + + if (key == "logits") { + GGML_ASSERT(meta->n_tokens > 0); + GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t)); + meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t)); + std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t)); + } + if (key == "all_pos_0") { GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0)); std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0)); @@ -17951,13 +18004,6 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) { 
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1)); std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1)); } - - if (key == "logits") { - GGML_ASSERT(meta->n_tokens > 0); - GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t)); - meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t)); - std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t)); - } } return 0; } @@ -18237,12 +18283,23 @@ static int llama_decode_internal( batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos)); std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos)); } - batch_all.all_pos_0 = meta.all_pos_0; - batch_all.all_pos_1 = meta.all_pos_1; + if (meta.n_seq_id != nullptr) { + batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t)); + std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t)); + } + if (meta.seq_id != nullptr) { + batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *)); + for (size_t i = 0; i < cparams.n_ctx; ++i) { + batch_all.seq_id[i] = (llama_seq_id *) malloc(meta.n_seq_id[i] * sizeof(llama_seq_id)); + std::memcpy(batch_all.seq_id[i], meta.seq_id[i], meta.n_seq_id[i] * sizeof(llama_seq_id)); + } + } if (meta.logits != nullptr) { batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t)); std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); } + batch_all.all_pos_0 = meta.all_pos_0; + batch_all.all_pos_1 = meta.all_pos_1; } if (kv_cache_op(meta.clear_kv_cache, @@ -18289,9 +18346,11 @@ static int llama_decode_internal( if (!is_last_dev) { meta.n_tokens = batch_all.n_tokens; meta.pos = batch_all.pos; + meta.n_seq_id = batch_all.n_seq_id; + meta.seq_id = batch_all.seq_id; + meta.logits = batch_all.logits; meta.all_pos_0 = batch_all.all_pos_0; meta.all_pos_1 = batch_all.all_pos_1; - meta.logits = batch_all.logits; llama_send_meta(*lctx.send_socket, &meta); }