mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 20:29:26 +00:00
assume only a single seq_id per token is needed
This commit is contained in:
parent
d8aea899d1
commit
e56be76bdf
1 changed file with 22 additions and 27 deletions
|
@ -17861,15 +17861,20 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
}
|
||||
|
||||
if (meta->n_seq_id != nullptr) {
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
|
||||
send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
|
||||
send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
|
||||
|
||||
for (int32_t i = 0; i < meta->n_tokens; ++i) {
|
||||
GGML_ASSERT(meta->seq_id[i] != nullptr);
|
||||
const size_t n_seq = meta->n_seq_id[i];
|
||||
send_msgs.emplace_back("seq_id", strlen("seq_id"));
|
||||
send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
|
||||
// here we assume only a single seq_id per token is needed
|
||||
// pack all single seq_id values into a contiguous array
|
||||
llama_seq_id * all_seq_ids = (llama_seq_id *) malloc(meta->n_ctx * sizeof(llama_seq_id));
|
||||
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
|
||||
all_seq_ids[i] = meta->seq_id[i][0];
|
||||
}
|
||||
|
||||
send_msgs.emplace_back("seq_id", strlen("seq_id"));
|
||||
send_msgs.emplace_back(all_seq_ids, meta->n_ctx * sizeof(llama_seq_id));
|
||||
free(all_seq_ids);
|
||||
}
|
||||
|
||||
if (meta->logits != nullptr) {
|
||||
|
@ -17943,7 +17948,6 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int32_t seq_idx = 0;
|
||||
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
|
||||
std::string key = recv_msgs[i].to_string();
|
||||
zmq::message_t & data_msg = recv_msgs[i + 1];
|
||||
|
@ -17959,30 +17963,21 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
|||
}
|
||||
|
||||
if (key == "n_seq_id") {
|
||||
GGML_ASSERT(meta->n_tokens > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
|
||||
meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
|
||||
memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
|
||||
|
||||
meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
|
||||
for (int32_t j = 0; j < meta->n_tokens; ++j) {
|
||||
meta->seq_id[j] = nullptr;
|
||||
}
|
||||
seq_idx = 0;
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
|
||||
meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
|
||||
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
|
||||
}
|
||||
|
||||
if (key == "seq_id") {
|
||||
while (seq_idx < meta->n_tokens && meta->n_seq_id[seq_idx] <= 0) {
|
||||
++seq_idx;
|
||||
GGML_ASSERT(meta->n_ctx > 0);
|
||||
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(llama_seq_id));
|
||||
const llama_seq_id * all_seq_ids = (llama_seq_id *) data_msg.data();
|
||||
meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
|
||||
for (uint32_t i = 0; i < meta->n_ctx; ++i) {
|
||||
meta->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
|
||||
meta->seq_id[i][0] = all_seq_ids[i];
|
||||
}
|
||||
|
||||
const size_t n_seq = meta->n_seq_id[seq_idx];
|
||||
GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id));
|
||||
|
||||
meta->seq_id[seq_idx] = (llama_seq_id *) malloc(n_seq * sizeof(llama_seq_id));
|
||||
memcpy(meta->seq_id[seq_idx], data_msg.data(), n_seq * sizeof(llama_seq_id));
|
||||
|
||||
++seq_idx;
|
||||
}
|
||||
|
||||
if (key == "logits") {
|
||||
|
|
Loading…
Add table
Reference in a new issue