Fix batch metadata chain forwarding in distributed perplexity

This commit is contained in:
leeetao 2025-07-18 14:05:34 +00:00
parent da31acbe6a
commit 50a916f123

View file

@ -525,26 +525,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
if (n_world > 1) {
sync_meta meta;
if (my_rank == 0) {
meta.tokens_size = tokens_size;
meta.n_chunks = params.n_chunks;
LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
llama_send_meta(ctx, &meta);
LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
} else {
LOG_INF("%s: rank %d waiting 7 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
std::this_thread::sleep_for(std::chrono::milliseconds(7000));
LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
if (my_rank != 0) {
if (llama_recv_meta(ctx, &meta) == -1) {
return { {}, -1.0, {}, {} };
}
tokens_size = meta.tokens_size;
n_chunks = meta.n_chunks;
if (!is_last_dev) {
LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
llama_send_meta(ctx, &meta);
n_chunks = meta.n_chunks;
}
if (!is_last_dev) {
if (my_rank == 0) {
meta.tokens_size = tokens_size;
meta.n_chunks = params.n_chunks;
}
llama_send_meta(ctx, &meta);
}
}
@ -708,13 +702,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
meta.seq_id = batch.seq_id;
meta.logits = batch.logits;
meta.n_outputs = n_outputs;
}
if (n_world > 1) {
llama_send_meta(ctx, &meta);
}
} else {
if (n_world > 1) {
// comms: other ranks receive the batch meta data
// Chain forwarding pattern: consistent with tokens_size communication
if (n_world > 1) {
if (my_rank != 0) {
// Non-rank 0 devices receive batch meta data
if (llama_recv_meta(ctx, &meta) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
@ -733,7 +726,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
}
}
if (!is_last_dev) {
// Non-last devices forward the batch meta data
llama_send_meta(ctx, &meta);
}
}