mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-06 00:29:02 +00:00
add batch_all.n_seq_id and batch_all.seq_id to sync_meta
This commit is contained in:
parent
68ecc8509d
commit
a1a2238831
1 changed files with 78 additions and 19 deletions
|
@ -17805,13 +17805,16 @@ struct input_tensors {
|
|||
};
|
||||
|
||||
struct sync_meta {
|
||||
int32_t n_tokens = 0;
|
||||
llama_pos * pos = nullptr;
|
||||
int32_t n_tokens = 0;
|
||||
llama_pos * pos = nullptr;
|
||||
int32_t * n_seq_id = nullptr;
|
||||
llama_seq_id ** seq_id = nullptr;
|
||||
int8_t * logits = nullptr;
|
||||
|
||||
llama_pos all_pos_0;
|
||||
llama_pos all_pos_1;
|
||||
uint32_t n_ctx = 0;
|
||||
int8_t * logits = nullptr;
|
||||
|
||||
uint32_t n_ctx = 0;
|
||||
|
||||
// signal to clear the kv cache
|
||||
bool clear_kv_cache = false;
|
||||
|
||||
|
@ -17857,11 +17860,19 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
|
||||
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
|
||||
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
|
||||
if (meta->n_seq_id != nullptr) {
|
||||
GGML_ASSERT(meta->seq_id != nullptr);
|
||||
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
|
||||
send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
|
||||
|
||||
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
||||
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
||||
for (size_t i = 0; i < meta->n_ctx; ++i) {
|
||||
const size_t n_seq = meta->n_seq_id[i];
|
||||
if (n_seq > 0) {
|
||||
send_msgs.emplace_back("seq_id", strlen("seq_id"));
|
||||
send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (meta->logits != nullptr) {
|
||||
GGML_ASSERT(meta->n_tokens > 0);
|
||||
|
@ -17869,6 +17880,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
|
||||
}
|
||||
|
||||
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
|
||||
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
|
||||
|
||||
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
||||
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
||||
|
||||
zmq::send_multipart(socket, send_msgs);
|
||||
} catch (const zmq::error_t& e) {
|
||||
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
||||
|
@ -17928,6 +17945,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
size_t seq_idx = 0;
|
||||
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
|
||||
std::string key = recv_msgs[i].to_string();
|
||||
zmq::message_t & data_msg = recv_msgs[i + 1];
|
||||
|
@ -17942,6 +17960,41 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
|
||||
if (key == "n_seq_id") {
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
|
||||
meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
|
||||
memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
|
||||
|
||||
meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
|
||||
for (size_t j = 0; j < meta->n_ctx; ++j) {
|
||||
meta->seq_id[j] = nullptr;
|
||||
}
|
||||
seq_idx = 0;
|
||||
}
|
||||
|
||||
if (key == "seq_id") {
|
||||
while (seq_idx < meta->n_ctx && meta->n_seq_id[seq_idx] <= 0) {
|
||||
++seq_idx;
|
||||
}
|
||||
GGML_ASSERT(seq_idx < meta->n_ctx);
|
||||
|
||||
const size_t n_seq = meta->n_seq_id[seq_idx];
|
||||
GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id));
|
||||
|
||||
meta->seq_id[seq_idx] = (llama_seq_id *) malloc(n_seq * sizeof(llama_seq_id));
|
||||
memcpy(meta->seq_id[seq_idx], data_msg.data(), n_seq * sizeof(llama_seq_id));
|
||||
|
||||
++seq_idx;
|
||||
}
|
||||
|
||||
if (key == "logits") {
|
||||
GGML_ASSERT(meta->n_tokens > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
|
||||
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
|
||||
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
|
||||
}
|
||||
|
||||
if (key == "all_pos_0") {
|
||||
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
|
||||
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
|
||||
|
@ -17951,13 +18004,6 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
|
||||
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
|
||||
}
|
||||
|
||||
if (key == "logits") {
|
||||
GGML_ASSERT(meta->n_tokens > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
|
||||
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
|
||||
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -18237,12 +18283,23 @@ static int llama_decode_internal(
|
|||
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
|
||||
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
|
||||
}
|
||||
batch_all.all_pos_0 = meta.all_pos_0;
|
||||
batch_all.all_pos_1 = meta.all_pos_1;
|
||||
if (meta.n_seq_id != nullptr) {
|
||||
batch_all.n_seq_id = (int32_t *) malloc(cparams.n_ctx * sizeof(int32_t));
|
||||
std::memcpy(batch_all.n_seq_id, meta.n_seq_id, cparams.n_ctx * sizeof(int32_t));
|
||||
}
|
||||
if (meta.seq_id != nullptr) {
|
||||
batch_all.seq_id = (llama_seq_id **) malloc(cparams.n_ctx * sizeof(llama_seq_id *));
|
||||
for (size_t i = 0; i < cparams.n_ctx; ++i) {
|
||||
batch_all.seq_id[i] = (llama_seq_id *) malloc(meta.n_seq_id[i] * sizeof(llama_seq_id));
|
||||
std::memcpy(batch_all.seq_id[i], meta.seq_id[i], meta.n_seq_id[i] * sizeof(llama_seq_id));
|
||||
}
|
||||
}
|
||||
if (meta.logits != nullptr) {
|
||||
batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
|
||||
std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
|
||||
}
|
||||
batch_all.all_pos_0 = meta.all_pos_0;
|
||||
batch_all.all_pos_1 = meta.all_pos_1;
|
||||
}
|
||||
|
||||
if (kv_cache_op(meta.clear_kv_cache,
|
||||
|
@ -18289,9 +18346,11 @@ static int llama_decode_internal(
|
|||
if (!is_last_dev) {
|
||||
meta.n_tokens = batch_all.n_tokens;
|
||||
meta.pos = batch_all.pos;
|
||||
meta.n_seq_id = batch_all.n_seq_id;
|
||||
meta.seq_id = batch_all.seq_id;
|
||||
meta.logits = batch_all.logits;
|
||||
meta.all_pos_0 = batch_all.all_pos_0;
|
||||
meta.all_pos_1 = batch_all.all_pos_1;
|
||||
meta.logits = batch_all.logits;
|
||||
llama_send_meta(*lctx.send_socket, &meta);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue