Fix batch metadata chain forwarding in distributed perplexity

leeetao 2025-07-18 14:05:34 +00:00
parent da31acbe6a
commit 50a916f123


@@ -525,26 +525,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     if (n_world > 1) {
         sync_meta meta;
-        if (my_rank == 0) {
-            meta.tokens_size = tokens_size;
-            meta.n_chunks = params.n_chunks;
-            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
-            llama_send_meta(ctx, &meta);
-            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
-        } else {
-            LOG_INF("%s: rank %d waiting 7 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
-            std::this_thread::sleep_for(std::chrono::milliseconds(7000));
-            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+        if (my_rank != 0) {
             if (llama_recv_meta(ctx, &meta) == -1) {
                 return { {}, -1.0, {}, {} };
             }
             tokens_size = meta.tokens_size;
             n_chunks = meta.n_chunks;
-            if (!is_last_dev) {
-                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
-                llama_send_meta(ctx, &meta);
-            }
+        }
+        if (!is_last_dev) {
+            if (my_rank == 0) {
+                meta.tokens_size = tokens_size;
+                meta.n_chunks = params.n_chunks;
+            }
+            llama_send_meta(ctx, &meta);
         }
     }
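
Note on the first hunk: the old code ordered rank 0's send ahead of the other ranks' receives with a hard-coded 7-second sleep, and forwarding happened only inside the receiving branch. The new code restructures this as a uniform receive-then-forward chain and drops the sleep, since a blocking receive already provides the ordering the sleep was approximating. Below is a minimal sketch of that pattern as read from the diff, not the actual implementation; send_meta()/recv_meta() are hypothetical stand-ins for llama_send_meta()/llama_recv_meta(), assumed to transmit to the next rank and receive from the previous rank in the device pipeline:

#include <cstddef>

struct sync_meta_t {
    size_t tokens_size;   // field names taken from the diff
    int    n_chunks;
};

int send_meta(const sync_meta_t * meta);  // hypothetical: push to the next rank
int recv_meta(sync_meta_t * meta);        // hypothetical: pull from the previous rank

// Every rank except 0 first receives from its predecessor; every rank except
// the last then sends to its successor. Rank 0 originates the values, middle
// ranks relay them, and the last device only consumes them.
bool sync_chain(int my_rank, bool is_last_dev,
                size_t tokens_size, int n_chunks) {
    sync_meta_t meta;
    if (my_rank != 0) {
        if (recv_meta(&meta) == -1) {    // blocks until the previous rank sends
            return false;
        }
    }
    if (!is_last_dev) {
        if (my_rank == 0) {
            meta.tokens_size = tokens_size;  // rank 0 fills in its own values
            meta.n_chunks    = n_chunks;
        }
        send_meta(&meta);   // originate (rank 0) or forward (middle ranks)
    }
    return true;
}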
@@ -708,13 +702,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         meta.seq_id = batch.seq_id;
         meta.logits = batch.logits;
         meta.n_outputs = n_outputs;
+    }
 
-        if (n_world > 1) {
-            llama_send_meta(ctx, &meta);
-        }
-    } else {
-        if (n_world > 1) {
-            // comms: other ranks receive the batch meta data
+    // Chain forwarding pattern: consistent with tokens_size communication
+    if (n_world > 1) {
+        if (my_rank != 0) {
+            // Non-rank 0 devices receive batch meta data
             if (llama_recv_meta(ctx, &meta) == -1) {
                 LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
                 return {tokens, -1.0, {}, {}};
@@ -733,7 +726,11 @@
             }
             if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
         }
+    }
 
+    if (!is_last_dev) {
+        // Non-last devices forward the batch meta data
+        llama_send_meta(ctx, &meta);
     }
 }
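
The second and third hunks apply the same shape to the per-batch metadata. Previously only rank 0 sent, and every other rank received; if llama_send_meta() only reaches the next device in the pipeline, ranks beyond the first hop never saw the batch metadata, hence the chain forwarding of the commit title. An illustrative sketch of the fixed flow, under the same assumptions and hypothetical stand-ins as above (only the field names n_outputs and logits come from the diff):

#include <cstdint>

struct batch_meta_t {
    int32_t        n_tokens;
    int32_t        n_outputs;
    const int8_t * logits;   // per-token output flags, memcpy'd in the diff
};

int send_batch_meta(const batch_meta_t * meta);  // hypothetical transport
int recv_batch_meta(batch_meta_t * meta);        // hypothetical transport

// Returns false where the real code returns an error result.
bool relay_batch_meta(int my_rank, int n_world, bool is_last_dev,
                      batch_meta_t & meta) {
    if (n_world > 1 && my_rank != 0) {
        // ranks 1..n-1 receive what rank 0 originated; copying the received
        // fields into the local batch (the memcpy in the diff) is elided here
        if (recv_batch_meta(&meta) == -1) {
            return false;
        }
    }
    if (!is_last_dev) {
        // every non-last device forwards, whether it originated the metadata
        // (rank 0) or just received it; this is the core of the fix
        send_batch_meta(&meta);
    }
    return true;
}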