fix n_seq_id and seq_id

This commit is contained in:
Lizonghang 2025-06-06 23:58:03 +04:00
parent a1a2238831
commit d8aea899d1

View file

@ -17861,16 +17861,14 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
}
if (meta->n_seq_id != nullptr) {
GGML_ASSERT(meta->seq_id != nullptr);
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
send_msgs.emplace_back(meta->n_seq_id, meta->n_ctx * sizeof(int32_t));
send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
for (size_t i = 0; i < meta->n_ctx; ++i) {
for (int32_t i = 0; i < meta->n_tokens; ++i) {
GGML_ASSERT(meta->seq_id[i] != nullptr);
const size_t n_seq = meta->n_seq_id[i];
if (n_seq > 0) {
send_msgs.emplace_back("seq_id", strlen("seq_id"));
send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
}
send_msgs.emplace_back("seq_id", strlen("seq_id"));
send_msgs.emplace_back(meta->seq_id[i], n_seq * sizeof(llama_seq_id));
}
}
@ -17945,7 +17943,7 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
return 0;
}
size_t seq_idx = 0;
int32_t seq_idx = 0;
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
std::string key = recv_msgs[i].to_string();
zmq::message_t & data_msg = recv_msgs[i + 1];
@ -17961,23 +17959,22 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
}
if (key == "n_seq_id") {
GGML_ASSERT(meta->n_ctx > 0);
GGML_ASSERT(data_msg.size() == meta->n_ctx * sizeof(int32_t));
meta->n_seq_id = (int32_t *) malloc(meta->n_ctx * sizeof(int32_t));
memcpy(meta->n_seq_id, data_msg.data(), meta->n_ctx * sizeof(int32_t));
GGML_ASSERT(meta->n_tokens > 0);
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int32_t));
meta->n_seq_id = (int32_t *) malloc(meta->n_tokens * sizeof(int32_t));
memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
meta->seq_id = (llama_seq_id **) malloc(meta->n_ctx * sizeof(llama_seq_id *));
for (size_t j = 0; j < meta->n_ctx; ++j) {
meta->seq_id = (llama_seq_id **) malloc(meta->n_tokens * sizeof(llama_seq_id *));
for (int32_t j = 0; j < meta->n_tokens; ++j) {
meta->seq_id[j] = nullptr;
}
seq_idx = 0;
}
if (key == "seq_id") {
while (seq_idx < meta->n_ctx && meta->n_seq_id[seq_idx] <= 0) {
while (seq_idx < meta->n_tokens && meta->n_seq_id[seq_idx] <= 0) {
++seq_idx;
}
GGML_ASSERT(seq_idx < meta->n_ctx);
const size_t n_seq = meta->n_seq_id[seq_idx];
GGML_ASSERT(data_msg.size() == n_seq * sizeof(llama_seq_id));