Refactored the communication-content and timing-control logic

leeetao 2025-06-24 10:40:37 +00:00
parent 4b823775ec
commit a3becb586a
5 changed files with 474 additions and 134 deletions


@@ -524,9 +524,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
sync_meta meta;
if (my_rank == 0) {
meta.tokens_size = tokens_size;
llama_send_meta(ctx, &meta);
llama_send_meta(ctx, &meta, false);
} else {
if (llama_recv_meta(ctx, &meta) == -1) {
if (llama_recv_meta(ctx, &meta, false) == -1) {
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
return { {}, -1.0, {}, {} };
}
@@ -534,7 +534,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
} else {
tokens_size = meta.tokens_size;
llama_send_meta(ctx, &meta);
llama_send_meta(ctx, &meta, false);
}
}
}
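
For orientation: the two hunks above only add a direction flag to the existing metadata exchange, so every call now states explicitly whether it talks downstream (`reverse = false`) or upstream (`reverse = true`, used by the completion relay later in this diff). Below is a minimal, hypothetical sketch of the forward relay those hunks implement; `sync_meta`, `send_meta()`, and `recv_meta()` are simplified stand-ins for this fork's `sync_meta` / `llama_send_meta()` / `llama_recv_meta()`, and only the control flow is taken from the diff.

```cpp
// Hypothetical sketch of the forward metadata relay shown above.
// sync_meta, send_meta() and recv_meta() are simplified stand-ins for the
// fork's sync_meta / llama_send_meta() / llama_recv_meta(); only the control
// flow (originate on rank 0, receive and relay on the other ranks) is real.
#include <cstddef>
#include <cstdio>

struct sync_meta {
    size_t tokens_size    = 0;     // set by rank 0, consumed by the other ranks
    bool   clear_kv_cache = false; // broadcast before each chunk is evaluated
    bool   chunk_done     = false; // used by the reverse relay (see later hunks)
};

// reverse = false: talk to the next rank (downstream)
// reverse = true : talk to the previous rank (upstream)
static int send_meta(const sync_meta & /*meta*/, bool /*reverse*/) { return 0; }
static int recv_meta(sync_meta & /*meta*/, bool /*reverse*/)       { return 0; }

// forward broadcast: rank 0 originates, every non-last rank relays downstream
static int broadcast_meta(sync_meta & meta, int my_rank, int n_world, bool is_last_dev) {
    if (n_world <= 1) {
        return 0; // single device: nothing to synchronize
    }
    if (my_rank == 0) {
        return send_meta(meta, /*reverse=*/false);
    }
    if (recv_meta(meta, /*reverse=*/false) == -1) {
        std::fprintf(stderr, "rank %d: failed to receive meta\n", my_rank);
        return -1;
    }
    if (!is_last_dev) {
        return send_meta(meta, /*reverse=*/false); // relay to the next rank
    }
    return 0;
}

int main() {
    sync_meta meta;
    meta.tokens_size = 1024;
    // pretend to be rank 0 of a 3-device pipeline
    return broadcast_meta(meta, /*my_rank=*/0, /*n_world=*/3, /*is_last_dev=*/false) == 0 ? 0 : 1;
}
```

The same originate-on-rank-0, receive-and-relay pattern is reused for `tokens_size`, the KV cache clear signal, and the per-batch metadata in the hunks that follow; only the `sync_meta` fields being carried differ.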
@@ -628,19 +628,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
{
// synchronize the KV cache clear signal across all ranks
if (n_world > 1) {
sync_meta clear_meta;
clear_meta.clear_kv_cache = true;
if (my_rank == 0) {
llama_send_meta(ctx, &clear_meta);
llama_send_meta(ctx, &clear_meta, false);
} else {
if (llama_recv_meta(ctx, &clear_meta) == -1) {
if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
if (!is_last_dev) {
llama_send_meta(ctx, &clear_meta);
llama_send_meta(ctx, &clear_meta, false);
}
}
}
@@ -648,11 +649,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// clear the KV cache
llama_kv_cache_clear(ctx);
sync_meta meta;
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
// used for communication of the batch meta data
sync_meta meta;
int n_outputs = 0;
@@ -689,38 +690,42 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
}
if (my_rank == 0) {
// Required batch info: Operation scale, KV cache location, Logits calculation location
meta.n_ctx = n_ctx;
// comms: rank 0 now needs to send the batch metadata to the other ranks
meta.n_tokens = batch.n_tokens;
meta.pos = batch.pos;
meta.n_seq_id = batch.n_seq_id;
meta.seq_id = batch.seq_id;
meta.logits = batch.logits;
meta.all_pos_0 = batch.all_pos_0;
meta.all_pos_1 = batch.all_pos_1;
meta.n_outputs = n_outputs;
meta.chunk_start_pos = start;
}
// other ranks need to know batch info
{
if (n_world > 1) {
meta.n_ctx = n_ctx;
if (my_rank == 0) {
llama_send_meta(ctx, &meta);
} else {
if (llama_recv_meta(ctx, &meta) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
if (!is_last_dev) {
llama_send_meta(ctx, &meta);
}
}
llama_send_meta(ctx, &meta, false); // reverse = false
}
}
} else {
if (n_world > 1) {
// comms: other ranks receive the batch meta data
if (llama_recv_meta(ctx, &meta, false) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
// copy the batch meta data to the llama_batch
if (meta.n_tokens > 0) {
batch.n_tokens = meta.n_tokens;
if (meta.pos) { std::memcpy(batch.pos, meta.pos, meta.n_tokens * sizeof(llama_pos)); } // use n_tokens instead of n_batch, n_tokens is the actual number of tokens in the batch
if (meta.n_seq_id) { std::memcpy(batch.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t)); }
if (meta.seq_id) {
const int32_t n_seq_max = 1;
for (int32_t i = 0; i < meta.n_tokens; ++i) {
std::memcpy(batch.seq_id[i], meta.seq_id[i], n_seq_max * sizeof(llama_seq_id));
}
}
if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
}
}
}
if (llama_decode(ctx, batch)) {
LOG_INF("%s : failed to eval\n", __func__);
@@ -753,8 +758,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
int chunk_start_pos = meta.chunk_start_pos;
llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first;
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits,
tokens_data, n_ctx - 1 - first,
@@ -763,8 +767,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
process_logits(n_vocab, all_logits,
tokens_data, n_ctx - 1 - first,
workers, nll, nll2,
logit_history.data() + chunk_start_pos + seq*n_ctx + first,
prob_history.data() + chunk_start_pos + seq*n_ctx + first);
logit_history.data() + start + seq*n_ctx + first,
prob_history.data() + start + seq*n_ctx + first);
}
count += n_ctx - first - 1;
@@ -778,8 +782,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
}
logits.clear();
}
if (n_world > 1) {
sync_meta done_meta;
done_meta.chunk_done = true;
if (is_last_dev) {
// Last device sends completion signal upstream (reverse direction)
LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
} else if (my_rank == 0) {
// Rank 0 waits for completion signal from downstream
LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
} else {
// Intermediate ranks: receive from downstream, relay upstream
LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
}
}
logits.clear();
}
LOG("\n");
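
The block added above introduces an acknowledgment in the opposite direction: after each chunk, the last device sends `chunk_done` upstream, intermediate ranks relay it, and rank 0 blocks until it arrives, so rank 0 does not start the next chunk (and its KV cache clear) before the downstream ranks have finished the current one. A hypothetical, self-contained sketch of that reverse relay follows; the stubs mirror the earlier sketch and are not the fork's real API.

```cpp
// Hypothetical sketch of the reverse (downstream -> upstream) completion relay
// added above. The stubs mirror the earlier sketch; reverse = true means
// "talk to the previous rank".
#include <cstdio>

struct sync_meta {
    bool chunk_done = false;
};

static int send_meta(const sync_meta & /*meta*/, bool /*reverse*/) { return 0; }
static int recv_meta(sync_meta & meta, bool /*reverse*/)           { meta.chunk_done = true; return 0; }

static int wait_chunk_done(int my_rank, int n_world, bool is_last_dev, int chunk_idx) {
    if (n_world <= 1) {
        return 0; // nothing to acknowledge on a single device
    }
    sync_meta done_meta;
    done_meta.chunk_done = true;
    if (is_last_dev) {
        // the last device originates the acknowledgment
        return send_meta(done_meta, /*reverse=*/true);
    }
    // rank 0 and the intermediate ranks wait for the acknowledgment from downstream
    if (recv_meta(done_meta, /*reverse=*/true) == -1 || !done_meta.chunk_done) {
        std::fprintf(stderr, "rank %d: missing chunk_done for chunk %d\n", my_rank, chunk_idx);
        return -1;
    }
    if (my_rank != 0) {
        return send_meta(done_meta, /*reverse=*/true); // intermediate rank: relay upstream
    }
    return 0; // rank 0: safe to start the next chunk
}

int main() {
    // pretend to be rank 0 of a 3-device pipeline waiting on chunk 0
    return wait_chunk_done(/*my_rank=*/0, /*n_world=*/3, /*is_last_dev=*/false, /*chunk_idx=*/0) == 0 ? 0 : 1;
}
```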
@@ -795,12 +827,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
} else {
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
}
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history};
}
llama_batch_free(batch);
return {};
}
@@ -2078,10 +2109,11 @@ int main(int argc, char ** argv) {
params.n_ctx = 512;
params.logits_all = true;
params.escape = false;
params.is_perplexity_eval = true;
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
return 1;
}
}
uint32_t n_world = params.n_world;
uint32_t my_rank = params.rank;
@@ -2141,7 +2173,6 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any
llama_init_result llama_init = llama_init_from_gpt_params(params);
// update rank and world size if any devices removed
my_rank = params.rank;
n_world = params.n_world;
@@ -2189,14 +2220,34 @@ int main(int argc, char ** argv) {
LOG("\n");
if (is_last_dev) {
llama_perf_context_print(ctx);
write_logfile(ctx, params, model, results);
}
if (my_rank == 0) {
llama_perf_context_print(ctx);
}
if (n_world > 1) {
LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
if (my_rank == 0) {
llama_free_sockets(ctx, nullptr);
}
if (my_rank != 0 && signal_thread.joinable()) {
signal_thread.join();
}
if (stop_signal) {
LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, stop_signal);
delete[] stop_signal;
}
}
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
return 0;
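
Finally, a hedged sketch of the shutdown ordering added in the last hunk, under the assumption, taken from this diff alone, that non-zero ranks run a `signal_thread` which receives a heap-allocated stop message; `llama_free_sockets()` and the real notification transport are not reproduced, a `std::thread` with a stub lambda stands in for them.

```cpp
// Hypothetical sketch of the distributed shutdown ordering in the last hunk.
// std::thread stands in for signal_thread; the lambda stands in for whatever
// receives the stop notification and allocates stop_signal with new[].
#include <cstdio>
#include <cstring>
#include <thread>

int main() {
    const int my_rank     = 1;       // pretend this process is a non-zero rank
    char *    stop_signal = nullptr;
    std::thread signal_thread;

    if (my_rank != 0) {
        // non-zero ranks wait for the upstream shutdown notification
        signal_thread = std::thread([&stop_signal]() {
            stop_signal = new char[5];
            std::strcpy(stop_signal, "stop");
        });
    }

    if (my_rank == 0) {
        // rank 0 initiates the shutdown (llama_free_sockets(ctx, nullptr) in the diff)
    } else if (signal_thread.joinable()) {
        signal_thread.join();        // block until the notification has arrived
    }

    if (stop_signal) {
        std::printf("rank %d: cleanup signal received: %s\n", my_rank, stop_signal);
        delete[] stop_signal;        // the diff releases stop_signal the same way
    }
    return 0;
}
```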