mirror of https://github.com/Lizonghang/prima.cpp.git, synced 2025-09-10 19:34:56 +00:00
Removed some unnecessary synchronization logic and added n_chunks communication
parent a3becb586a
commit 48b7f53abb
4 changed files with 97 additions and 218 deletions
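The change is easiest to read with the new synchronization flow in mind: rank 0 tokenizes alone, then pushes `tokens_size` and `n_chunks` down the pipeline in a single `sync_meta`; every later rank adopts the values and relays the struct onward unless it is the last device. Below is a minimal, self-contained sketch of that relay pattern (plain C++; the in-memory queues stand in for prima.cpp's socket-based `llama_send_meta`/`llama_recv_meta`, and every name other than the `sync_meta` fields is illustrative):

```cpp
// Sketch of the forward metadata relay this commit standardizes on:
// rank 0 seeds tokens_size and n_chunks, each later rank adopts the
// values and forwards them unless it is the last device.
#include <cstdio>
#include <queue>
#include <vector>

struct sync_meta {
    size_t tokens_size = 0;  // decided by rank 0 after tokenizing
    int    n_chunks    = -1; // rank 0's params.n_chunks
};

int main() {
    const int n_world = 4;
    // link[r] models the channel into rank r (a socket in prima.cpp)
    std::vector<std::queue<sync_meta>> link(n_world);
    std::vector<size_t> tokens_size(n_world, 0);
    std::vector<int>    n_chunks(n_world, 0);

    for (int my_rank = 0; my_rank < n_world; ++my_rank) {
        const bool is_last_dev = (my_rank == n_world - 1);
        sync_meta meta;
        if (my_rank == 0) {
            meta.tokens_size = 123456;        // only rank 0 tokenizes
            meta.n_chunks    = 32;            // params.n_chunks on rank 0
            tokens_size[0]   = meta.tokens_size;
            n_chunks[0]      = meta.n_chunks;
            if (n_world > 1) {
                link[1].push(meta);           // llama_send_meta(ctx, &meta)
            }
        } else {
            meta = link[my_rank].front();     // llama_recv_meta(ctx, &meta)
            link[my_rank].pop();
            tokens_size[my_rank] = meta.tokens_size;
            n_chunks[my_rank]    = meta.n_chunks;
            if (!is_last_dev) {
                link[my_rank + 1].push(meta); // relay to the next rank
            }
        }
        std::printf("rank %d synchronized tokens_size = %zu, n_chunks = %d\n",
                    my_rank, tokens_size[my_rank], n_chunks[my_rank]);
    }
    return 0;
}
```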
@@ -489,9 +489,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
     GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
 
-    // only last device store logits file
     std::ofstream logits_stream;
-    if (my_rank == n_world - 1){
+    if (my_rank == 0) {
         if (!params.logits_file.empty()) {
             logits_stream.open(params.logits_file.c_str(), std::ios::binary);
             if (!logits_stream.is_open()) {
@@ -506,9 +505,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     std::vector<llama_token> tokens;
     size_t tokens_size = 0;
+    int n_chunks = params.n_chunks;
 
-    // maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
-    if (my_rank == 0 || is_last_dev) {
+    if (my_rank == 0) {
         auto tim1 = std::chrono::high_resolution_clock::now();
         LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);
 
@@ -519,26 +518,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
     }
 
-    {
-        if (n_world > 1) {
-            sync_meta meta;
-            if (my_rank == 0) {
+    if (my_rank != 0) {
+        LOG_INF("perplexity: rank %d waiting for rank 0 to be ready\n", my_rank);
+    }
+
+    if (n_world > 1) {
+        sync_meta meta;
+
+        if (my_rank == 0) {
             meta.tokens_size = tokens_size;
-            llama_send_meta(ctx, &meta, false);
-        } else {
-            if (llama_recv_meta(ctx, &meta, false) == -1) {
-                LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
+            meta.n_chunks = params.n_chunks;
+
+            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
+            llama_send_meta(ctx, &meta);
+            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
+        } else {
+            LOG_INF("%s: rank %d waiting 5 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
+            std::this_thread::sleep_for(std::chrono::milliseconds(5000));
+            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+            if (llama_recv_meta(ctx, &meta) == -1) {
                 return { {}, -1.0, {}, {} };
             }
-            if (is_last_dev) {
-                GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
-            } else {
-                tokens_size = meta.tokens_size;
-                llama_send_meta(ctx, &meta, false);
-            }
+            tokens_size = meta.tokens_size;
+            n_chunks = meta.n_chunks;
+            if (!is_last_dev) {
+                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
+                llama_send_meta(ctx, &meta);
+            }
         }
     }
-    }
 
+    LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size);
+
     if (my_rank == 0) {
@@ -553,14 +562,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     std::vector<float> logit_history;
     std::vector<float> prob_history;
 
-    if (is_last_dev) {
+    if (my_rank == 0) {
         logit_history.resize(tokens_size);
         prob_history.resize(tokens_size);
     }
 
-    const int n_chunk_max = tokens.size() / n_ctx;
+    const int n_chunk_max = tokens_size / n_ctx;
 
-    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+    const int n_chunk = n_chunks < 0 ? n_chunk_max : std::min(n_chunks, n_chunk_max);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
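This hunk is the payoff of the new sync: `n_chunk_max` is now derived from the broadcast `tokens_size` instead of the local `tokens.size()`, which non-zero ranks never populate now that only rank 0 tokenizes, so every rank computes the same chunk count. A worked sketch of that arithmetic, with illustrative values:

```cpp
// How the synced values determine the chunk count on every rank.
// The numbers below are illustrative, not taken from the repository.
#include <algorithm>
#include <cstdio>

int main() {
    const size_t tokens_size = 131072; // broadcast from rank 0 via sync_meta
    const int    n_ctx       = 512;    // tokens per evaluation chunk
    const int    n_chunks    = -1;     // broadcast; < 0 means "no limit"

    const int n_chunk_max = (int)(tokens_size / n_ctx); // 131072/512 = 256
    const int n_chunk     = n_chunks < 0 ? n_chunk_max
                                         : std::min(n_chunks, n_chunk_max);

    std::printf("n_chunk_max = %d, n_chunk = %d\n", n_chunk_max, n_chunk);
    return 0;
}
```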
@@ -578,9 +587,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     std::vector<float> logits;
 
-    if(is_last_dev){
+    if((my_rank == 0 || is_last_dev)){
         if (num_batches > 1) {
             logits.reserve((size_t)n_ctx * n_vocab);
+            LOG_INF("%s: rank %d reserved logits space for %zu elements\n",
+                    __func__, my_rank, (size_t)n_ctx * n_vocab);
         }
     }
 
@@ -590,9 +601,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     std::vector<uint16_t> log_probs; // the log probabilities of logits
 
-    // rank 0 and last rank store log_probs
-    // only rank 0 or last device stores logits/log_probs
-    if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) {
+    // only rank 0 stores logits/log_probs
+    if (!params.logits_file.empty() && (my_rank == 0)) {
         const int nv = 2*((n_vocab + 1)/2) + 4;
         log_probs.resize(n_ctx * nv);
 
@@ -634,14 +644,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         clear_meta.clear_kv_cache = true;
 
         if (my_rank == 0) {
-            llama_send_meta(ctx, &clear_meta, false);
+            llama_send_meta(ctx, &clear_meta);
         } else {
-            if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
+            if (llama_recv_meta(ctx, &clear_meta) == -1) {
                 LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
                 return {tokens, -1.0, {}, {}};
             }
             if (!is_last_dev) {
-                llama_send_meta(ctx, &clear_meta, false);
+                llama_send_meta(ctx, &clear_meta);
             }
         }
     }
 
@@ -700,12 +710,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             meta.n_outputs = n_outputs;
 
             if (n_world > 1) {
-                llama_send_meta(ctx, &meta, false); // reverse = false
+                llama_send_meta(ctx, &meta);
             }
         } else {
             if (n_world > 1) {
                 // comms: other ranks receive the batch meta data
-                if (llama_recv_meta(ctx, &meta, false) == -1) {
+                if (llama_recv_meta(ctx, &meta) == -1) {
                     LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
                     return {tokens, -1.0, {}, {}};
                 }
@@ -732,7 +742,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            if (is_last_dev && num_batches > 1 ) {
+            if (my_rank == 0 && num_batches > 1 && n_outputs > 0) {
                 const int n_outputs_synced = meta.n_outputs;
                 if (n_outputs_synced > 0) {
                     const auto * batch_logits = llama_get_logits(ctx);
@@ -754,7 +764,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             LOG("%.2f minutes\n", total_seconds / 60.0);
         }
 
-        if (is_last_dev) {
+        if (my_rank == 0) {
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
 
@@ -785,37 +795,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             logits.clear();
         }
 
-        if (n_world > 1) {
-            sync_meta done_meta;
-            done_meta.chunk_done = true;
+        if (my_rank == 0) {
+            double current_ppl = std::exp(nll / count);
+            double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
 
-            if (is_last_dev) {
-                // Last device sends completion signal upstream (reverse direction)
-                LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
-                llama_send_meta(ctx, &done_meta, true); // reverse = true
-            } else if (my_rank == 0) {
-                // Rank 0 waits for completion signal from downstream
-                LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
-                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
-                    LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
-                    return {tokens, -1.0, {}, {}};
-                }
-                LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
-            } else {
-                // Intermediate ranks: receive from downstream, relay upstream
-                LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
-                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
-                    LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
-                    return {tokens, -1.0, {}, {}};
-                }
-                LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
-                llama_send_meta(ctx, &done_meta, true); // reverse = true
-            }
-        }
+            LOG_INF("Rank 0: Chunk %d/%d (%.1f%%) completed, current_ppl = %.4lf\n",
+                    i + n_seq_batch, n_chunk, progress, current_ppl);
+        } else {
+            double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
+
+            LOG_INF("Rank %d: Chunk %d/%d (%.1f%%) completed\n",
+                    my_rank, i + n_seq_batch, n_chunk, progress);
+        }
     }
     LOG("\n");
 
-    if (is_last_dev) {
+    if (my_rank == 0) {
         nll2 /= count;
         nll /= count;
         const double ppl = exp(nll);
@@ -2221,19 +2216,22 @@ int main(int argc, char ** argv) {
         LOG("\n");
 
-        if (is_last_dev) {
-            write_logfile(ctx, params, model, results);
-        }
-
         if (my_rank == 0) {
+            write_logfile(ctx, params, model, results);
             llama_perf_context_print(ctx);
         }
     }
 
     if (n_world > 1) {
         LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
 
         if (my_rank == 0) {
-            llama_free_sockets(ctx, nullptr);
+            char * rank0_stop_signal = nullptr;
+            llama_free_sockets(ctx, &rank0_stop_signal);
+
+            if (rank0_stop_signal) {
+                LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, rank0_stop_signal);
+                delete[] rank0_stop_signal;
+            }
         }
 
         if (my_rank != 0 && signal_thread.joinable()) {
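The new shutdown path in `main` uses a callee-allocates, caller-frees handoff: `llama_free_sockets` may return a heap-allocated stop-signal string through its out parameter, which rank 0 checks and `delete[]`s. A minimal sketch of that ownership pattern, with a hypothetical `free_sockets_stub` standing in for the real socket teardown:

```cpp
// Sketch of the out-parameter ownership convention in the shutdown path:
// the callee may hand back a heap-allocated C string that the caller owns.
// free_sockets_stub is illustrative, not the real llama_free_sockets.
#include <cstdio>
#include <cstring>

static void free_sockets_stub(char ** stop_signal) {
    // The real call tears down the sockets; here we only model the
    // optional signal string handed back to rank 0.
    if (stop_signal) {
        const char msg[] = "STOP";
        *stop_signal = new char[sizeof(msg)];
        std::memcpy(*stop_signal, msg, sizeof(msg));
    }
}

int main() {
    char * rank0_stop_signal = nullptr;
    free_sockets_stub(&rank0_stop_signal);

    if (rank0_stop_signal) {
        std::printf("Cleanup signal received: %s\n", rank0_stop_signal);
        delete[] rank0_stop_signal; // caller owns and frees the buffer
        rank0_stop_signal = nullptr;
    }
    return 0;
}
```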