Removed some unnecessary synchronization logic and added n_chunks to the sync_meta communication

author leeetao
date   2025-06-27 07:04:10 +00:00
parent a3becb586a
commit 48b7f53abb
4 changed files with 97 additions and 218 deletions
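To make the change easier to follow outside the diff, here is a minimal, self-contained sketch of the forward-only metadata broadcast the new code path uses: rank 0 packs tokens_size and n_chunks into a sync_meta, sends it to the next rank, and every other rank adopts the values and relays the message unless it is the last device. The sync_meta fields and the queue-based send/recv helpers below are simplified stand-ins for the project's llama_send_meta / llama_recv_meta, not the actual API.

```cpp
// Single-process simulation of the forward-only metadata broadcast.
// send_meta/recv_meta are in-memory stand-ins for llama_send_meta/llama_recv_meta.
#include <cstddef>
#include <cstdio>
#include <queue>
#include <vector>

struct sync_meta {
    size_t tokens_size = 0;
    int    n_chunks    = 0;
};

// links[r] is the mailbox rank r reads from (written by rank r - 1).
static std::vector<std::queue<sync_meta>> links;

static void send_meta(int to_rank, const sync_meta & m) {
    links[to_rank].push(m);
}

static bool recv_meta(int my_rank, sync_meta & m) {
    if (links[my_rank].empty()) return false;  // real code would block or time out
    m = links[my_rank].front();
    links[my_rank].pop();
    return true;
}

int main() {
    const int n_world = 4;
    links.resize(n_world);

    for (int my_rank = 0; my_rank < n_world; ++my_rank) {
        const bool is_last_dev = (my_rank == n_world - 1);
        size_t tokens_size = 0;  // per-rank local state
        int    n_chunks    = 0;

        sync_meta meta;
        if (my_rank == 0) {
            tokens_size = 8192;  // pretend rank 0 tokenized 8192 tokens
            n_chunks    = 16;
            meta.tokens_size = tokens_size;
            meta.n_chunks    = n_chunks;
            send_meta(my_rank + 1, meta);          // broadcast downstream
        } else {
            if (!recv_meta(my_rank, meta)) {
                std::fprintf(stderr, "rank %d: failed to receive meta\n", my_rank);
                return 1;
            }
            tokens_size = meta.tokens_size;        // adopt rank 0's values
            n_chunks    = meta.n_chunks;
            if (!is_last_dev) {
                send_meta(my_rank + 1, meta);      // relay to the next rank
            }
        }
        std::printf("rank %d: tokens_size = %zu, n_chunks = %d\n",
                    my_rank, tokens_size, n_chunks);
    }
    return 0;
}
```

In the diff below, this one-way relay also replaces the old reverse-direction chunk_done handshake back to rank 0, which is why the reverse arguments to llama_send_meta / llama_recv_meta disappear from the call sites.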


@@ -489,9 +489,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
// only last device store logits file
std::ofstream logits_stream;
if (my_rank == n_world - 1){
if (my_rank == 0) {
if (!params.logits_file.empty()) {
logits_stream.open(params.logits_file.c_str(), std::ios::binary);
if (!logits_stream.is_open()) {
@@ -506,9 +505,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<llama_token> tokens;
size_t tokens_size = 0;
int n_chunks = params.n_chunks;
// maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
if (my_rank == 0 || is_last_dev) {
if (my_rank == 0) {
auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);
@@ -519,26 +518,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
}
{
if (n_world > 1) {
sync_meta meta;
if (my_rank == 0) {
if (my_rank != 0) {
LOG_INF("perplexity: rank %d waiting for rank 0 to be ready\n", my_rank);
}
if (n_world > 1) {
sync_meta meta;
if (my_rank == 0) {
meta.tokens_size = tokens_size;
llama_send_meta(ctx, &meta, false);
} else {
if (llama_recv_meta(ctx, &meta, false) == -1) {
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
meta.n_chunks = params.n_chunks;
LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
llama_send_meta(ctx, &meta);
LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
} else {
LOG_INF("%s: rank %d waiting 5 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
std::this_thread::sleep_for(std::chrono::milliseconds(5000));
LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
if (llama_recv_meta(ctx, &meta) == -1) {
return { {}, -1.0, {}, {} };
}
if (is_last_dev) {
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
} else {
tokens_size = meta.tokens_size;
llama_send_meta(ctx, &meta, false);
}
tokens_size = meta.tokens_size;
n_chunks = meta.n_chunks;
if (!is_last_dev) {
LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
llama_send_meta(ctx, &meta);
}
}
}
LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size);
if (my_rank == 0) {
@@ -553,14 +562,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> logit_history;
std::vector<float> prob_history;
if (is_last_dev) {
if (my_rank == 0) {
logit_history.resize(tokens_size);
prob_history.resize(tokens_size);
}
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk_max = tokens_size / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_chunk = n_chunks < 0 ? n_chunk_max : std::min(n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
@@ -578,9 +587,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> logits;
if(is_last_dev){
if((my_rank == 0 || is_last_dev)){
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
LOG_INF("%s: rank %d reserved logits space for %zu elements\n",
__func__, my_rank, (size_t)n_ctx * n_vocab);
}
}
@@ -590,9 +601,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<uint16_t> log_probs; // the log probabilities of logits
// rank 0 and last rank store log_probs
// only rank 0 or last device stores logits/log_probs
if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) {
// only rank 0 stores logits/log_probs
if (!params.logits_file.empty() && (my_rank == 0)) {
const int nv = 2*((n_vocab + 1)/2) + 4;
log_probs.resize(n_ctx * nv);
@@ -634,14 +644,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
clear_meta.clear_kv_cache = true;
if (my_rank == 0) {
llama_send_meta(ctx, &clear_meta, false);
llama_send_meta(ctx, &clear_meta);
} else {
if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
if (llama_recv_meta(ctx, &clear_meta) == -1) {
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
if (!is_last_dev) {
llama_send_meta(ctx, &clear_meta, false);
llama_send_meta(ctx, &clear_meta);
}
}
}
@@ -700,12 +710,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
meta.n_outputs = n_outputs;
if (n_world > 1) {
llama_send_meta(ctx, &meta, false); // reverse = false
llama_send_meta(ctx, &meta);
}
} else {
if (n_world > 1) {
// comms: other ranks receive the batch meta data
if (llama_recv_meta(ctx, &meta, false) == -1) {
if (llama_recv_meta(ctx, &meta) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
@@ -732,7 +742,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
return {tokens, -1, logit_history, prob_history};
}
if (is_last_dev && num_batches > 1 ) {
if (my_rank == 0 && num_batches > 1 && n_outputs > 0) {
const int n_outputs_synced = meta.n_outputs;
if (n_outputs_synced > 0) {
const auto * batch_logits = llama_get_logits(ctx);
@@ -754,7 +764,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
LOG("%.2f minutes\n", total_seconds / 60.0);
}
if (is_last_dev) {
if (my_rank == 0) {
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
@@ -785,37 +795,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
logits.clear();
}
if (n_world > 1) {
sync_meta done_meta;
done_meta.chunk_done = true;
if (my_rank == 0) {
double current_ppl = std::exp(nll / count);
double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
if (is_last_dev) {
// Last device sends completion signal upstream (reverse direction)
LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
} else if (my_rank == 0) {
// Rank 0 waits for completion signal from downstream
LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
} else {
// Intermediate ranks: receive from downstream, relay upstream
LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
}
LOG_INF("Rank 0: Chunk %d/%d (%.1f%%) completed, current_ppl = %.4lf\n",
i + n_seq_batch, n_chunk, progress, current_ppl);
} else {
double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
LOG_INF("Rank %d: Chunk %d/%d (%.1f%%) completed\n",
my_rank, i + n_seq_batch, n_chunk, progress);
}
}
LOG("\n");
if (is_last_dev) {
if (my_rank == 0) {
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
@@ -2221,19 +2216,22 @@ int main(int argc, char ** argv) {
LOG("\n");
if (is_last_dev) {
write_logfile(ctx, params, model, results);
}
if (my_rank == 0) {
write_logfile(ctx, params, model, results);
llama_perf_context_print(ctx);
}
}
if (n_world > 1) {
LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
if (my_rank == 0) {
llama_free_sockets(ctx, nullptr);
char * rank0_stop_signal = nullptr;
llama_free_sockets(ctx, &rank0_stop_signal);
if (rank0_stop_signal) {
LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, rank0_stop_signal);
delete[] rank0_stop_signal;
}
}
if (my_rank != 0 && signal_thread.joinable()) {