mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-06 20:19:02 +00:00
add batch_all.logits to sync_meta
This commit is contained in:
parent
500e066a2f
commit
68ecc8509d
1 changed files with 19 additions and 0 deletions
|
@ -17810,6 +17810,7 @@ struct sync_meta {
|
||||||
llama_pos all_pos_0;
|
llama_pos all_pos_0;
|
||||||
llama_pos all_pos_1;
|
llama_pos all_pos_1;
|
||||||
uint32_t n_ctx = 0;
|
uint32_t n_ctx = 0;
|
||||||
|
int8_t * logits = nullptr;
|
||||||
|
|
||||||
// signal to clear the kv cache
|
// signal to clear the kv cache
|
||||||
bool clear_kv_cache = false;
|
bool clear_kv_cache = false;
|
||||||
|
@ -17862,6 +17863,12 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||||
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
||||||
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
||||||
|
|
||||||
|
if (meta->logits != nullptr) {
|
||||||
|
GGML_ASSERT(meta->n_tokens > 0);
|
||||||
|
send_msgs.emplace_back("logits", strlen("logits"));
|
||||||
|
send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
|
||||||
|
}
|
||||||
|
|
||||||
zmq::send_multipart(socket, send_msgs);
|
zmq::send_multipart(socket, send_msgs);
|
||||||
} catch (const zmq::error_t& e) {
|
} catch (const zmq::error_t& e) {
|
||||||
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
LLAMA_LOG_INFO("Failed to send meta data: %s\n", e.what());
|
||||||
|
@ -17944,6 +17951,13 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
|
||||||
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
|
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
|
||||||
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
|
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (key == "logits") {
|
||||||
|
GGML_ASSERT(meta->n_tokens > 0);
|
||||||
|
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
|
||||||
|
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
|
||||||
|
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -18225,6 +18239,10 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
batch_all.all_pos_0 = meta.all_pos_0;
|
batch_all.all_pos_0 = meta.all_pos_0;
|
||||||
batch_all.all_pos_1 = meta.all_pos_1;
|
batch_all.all_pos_1 = meta.all_pos_1;
|
||||||
|
if (meta.logits != nullptr) {
|
||||||
|
batch_all.logits = (int8_t *) malloc(meta.n_tokens * sizeof(int8_t));
|
||||||
|
std::memcpy(batch_all.logits, meta.logits, meta.n_tokens * sizeof(int8_t));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kv_cache_op(meta.clear_kv_cache,
|
if (kv_cache_op(meta.clear_kv_cache,
|
||||||
|
@ -18273,6 +18291,7 @@ static int llama_decode_internal(
|
||||||
meta.pos = batch_all.pos;
|
meta.pos = batch_all.pos;
|
||||||
meta.all_pos_0 = batch_all.all_pos_0;
|
meta.all_pos_0 = batch_all.all_pos_0;
|
||||||
meta.all_pos_1 = batch_all.all_pos_1;
|
meta.all_pos_1 = batch_all.all_pos_1;
|
||||||
|
meta.logits = batch_all.logits;
|
||||||
llama_send_meta(*lctx.send_socket, &meta);
|
llama_send_meta(*lctx.send_socket, &meta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue