Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-06 15:29:04 +00:00)
Enable distributed model perplexity measurement for different bit-width models with -lw and -ngl parameters
Commit 82787be7eb (parent 48b7f53abb)
2 changed files with 22 additions and 9 deletions
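The change is exercised through the distributed perplexity path: the head node and each worker are launched with their own layer split (-lw) and GPU offload count (-ngl), so perplexity can be measured for models quantized to different bit-widths. A purely illustrative invocation follows; the binary name, model/file names, and the --world/--rank flags are assumptions, not taken from this commit, and only -lw and -ngl come from the commit message:

    ./llama-perplexity -m model-q4_k_m.gguf -f wiki.test.raw --world 2 --rank 0 -lw "16,16" -ngl 16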
@@ -533,8 +533,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         llama_send_meta(ctx, &meta);
         LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
     } else {
-        LOG_INF("%s: rank %d waiting 5 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
-        std::this_thread::sleep_for(std::chrono::milliseconds(5000));
+        LOG_INF("%s: rank %d waiting 7 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
+        std::this_thread::sleep_for(std::chrono::milliseconds(7000));
         LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
         if (llama_recv_meta(ctx, &meta) == -1) {
             return { {}, -1.0, {}, {} };
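The only functional change in this hunk is the grace period: non-master ranks now wait 7 s instead of 5 s before blocking in llama_recv_meta, giving rank 0 more time to finish tokenizing the evaluation file. A stand-alone sketch of that worker-side wait (illustration only; llama_recv_meta and the surrounding perplexity driver are left as comments):

#include <chrono>
#include <cstdio>
#include <thread>

int main() {
    const int my_rank = 1;  // example: any rank other than 0 takes this path
    std::printf("rank %d waiting 7 seconds for rank 0 to complete tokenization\n", my_rank);
    std::this_thread::sleep_for(std::chrono::milliseconds(7000));
    std::printf("rank %d delay completed, now receiving tokens_size\n", my_rank);
    // in the real code: if (llama_recv_meta(ctx, &meta) == -1) { return { {}, -1.0, {}, {} }; }
    return 0;
}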
@@ -17892,10 +17892,8 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
         send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
         send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));

-        if (meta->n_chunks >= 0) {
-            send_msgs.emplace_back("n_chunks", strlen("n_chunks"));
-            send_msgs.emplace_back(&(meta->n_chunks), sizeof(meta->n_chunks));
-        }
+        send_msgs.emplace_back("n_chunks", strlen("n_chunks"));
+        send_msgs.emplace_back(&(meta->n_chunks), sizeof(meta->n_chunks));

         zmq::send_multipart(*send_socket, send_msgs);
         return;
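With the n_chunks >= 0 guard removed, the metadata message sent by rank 0 always has a fixed four-frame layout: the key "tokens_size", the raw bytes of meta->tokens_size, the key "n_chunks", and the raw bytes of meta->n_chunks. A minimal sender sketch of that framing (the helper name, the socket argument, and the size_t/int field types are assumptions; the frame order and zmq::send_multipart follow the diff):

#include <cstring>
#include <vector>
#include <zmq.hpp>
#include <zmq_addon.hpp>

// Build and send the fixed four-frame metadata message:
//   frame 0: "tokens_size"   frame 1: raw bytes of tokens_size
//   frame 2: "n_chunks"      frame 3: raw bytes of n_chunks
static void send_tokens_meta(zmq::socket_t & sock, size_t tokens_size, int n_chunks) {
    std::vector<zmq::message_t> frames;
    frames.emplace_back("tokens_size", strlen("tokens_size"));
    frames.emplace_back(&tokens_size, sizeof(tokens_size));
    frames.emplace_back("n_chunks", strlen("n_chunks"));
    frames.emplace_back(&n_chunks, sizeof(n_chunks));
    zmq::send_multipart(sock, frames);
}

The fixed layout is what lets the receiver check recv_msgs.size() == 4 (added later in this commit) instead of probing for an optional n_chunks pair.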
@@ -18015,6 +18013,11 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {

     recv_socket->set(zmq::sockopt::rcvtimeo, -1);

+    if (recv_msgs.size() < 2) {
+        LLAMA_LOG_ERROR("Invalid message format: too few messages\n");
+        return -1;
+    }
+
     const std::string cmd = recv_msgs[0].to_string();
     size_t idx = 1;

@@ -18023,8 +18026,6 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
         return 0;
     }

-
-
     if (cmd == "kv_seq_rm" && recv_msgs.size() == 4) {
         meta->kv_seq_rm = true;
         std::memcpy(&meta->rm_seq_id, recv_msgs[idx++].data(), sizeof(meta->rm_seq_id));
@@ -18060,6 +18061,19 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
         return 0;
     }

+    if (cmd == "tokens_size" && recv_msgs.size() == 4) {
+        std::memcpy(&(meta->tokens_size), recv_msgs[1].data(), sizeof(meta->tokens_size));
+
+        std::string chunks_key = recv_msgs[2].to_string();
+        if (chunks_key == "n_chunks") {
+            std::memcpy(&(meta->n_chunks), recv_msgs[3].data(), sizeof(meta->n_chunks));
+        } else {
+            LLAMA_LOG_ERROR("Expected 'n_chunks' key but got '%s'\n", chunks_key.c_str());
+            return -1;
+        }
+        return 0;
+    }
+
     if (recv_msgs.size() % 2 != 0) {
         LLAMA_LOG_ERROR("Invalid message format: odd number of messages\n");
         return -1;
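The receive side mirrors that framing: exactly four frames, where the first key must be "tokens_size" and the third "n_chunks", with the payload frames copied byte-for-byte into the sync_meta fields. A self-contained round-trip sketch follows (PAIR sockets over inproc and the size_t/int field types are assumptions chosen for the example; the validation order and memcpy pattern follow the new branch above):

#include <cstdio>
#include <cstring>
#include <iterator>
#include <vector>
#include <zmq.hpp>
#include <zmq_addon.hpp>

int main() {
    zmq::context_t ctx;
    zmq::socket_t rx(ctx, zmq::socket_type::pair);
    zmq::socket_t tx(ctx, zmq::socket_type::pair);
    rx.bind("inproc://meta");
    tx.connect("inproc://meta");

    // Sender side: the fixed four-frame message (see the llama_send_meta hunk above).
    size_t tokens_size = 287232;  // example values only
    int    n_chunks    = 561;
    std::vector<zmq::message_t> out;
    out.emplace_back("tokens_size", strlen("tokens_size"));
    out.emplace_back(&tokens_size, sizeof(tokens_size));
    out.emplace_back("n_chunks", strlen("n_chunks"));
    out.emplace_back(&n_chunks, sizeof(n_chunks));
    zmq::send_multipart(tx, out);

    // Receiver side: same checks as the new "tokens_size" branch in llama_recv_meta.
    std::vector<zmq::message_t> in;
    const auto n_frames = zmq::recv_multipart(rx, std::back_inserter(in));
    if (!n_frames || in.size() < 2) {
        std::fprintf(stderr, "invalid message format: too few messages\n");
        return 1;
    }

    size_t got_tokens = 0;
    int    got_chunks = 0;
    if (in[0].to_string() == "tokens_size" && in.size() == 4) {
        std::memcpy(&got_tokens, in[1].data(), sizeof(got_tokens));
        if (in[2].to_string() != "n_chunks") {
            std::fprintf(stderr, "expected 'n_chunks' key but got '%s'\n", in[2].to_string().c_str());
            return 1;
        }
        std::memcpy(&got_chunks, in[3].data(), sizeof(got_chunks));
    }
    std::printf("tokens_size=%zu n_chunks=%d\n", got_tokens, got_chunks);
    return 0;
}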
@@ -18357,7 +18371,6 @@ static int llama_decode_internal(

     GGML_ASSERT(!(my_rank == 0 && n_tokens_all == 0) && "n_tokens == 0 on master node");

-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
     if (batch_all.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= model.vocab.n_vocab) {