From 50a916f12356b7f09adb0d2ab465766a06d17339 Mon Sep 17 00:00:00 2001
From: leeetao <3122669219@qq.com>
Date: Fri, 18 Jul 2025 14:05:34 +0000
Subject: [PATCH] Fix batch metadata chain forwarding in distributed perplexity

---
 examples/perplexity/perplexity.cpp | 43 ++++++++++++++++++++-----------------------
 1 file changed, 20 insertions(+), 23 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 550b1b20..c33f78c9 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -525,26 +525,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     if (n_world > 1) {
         sync_meta meta;
-        if (my_rank == 0) {
-            meta.tokens_size = tokens_size;
-            meta.n_chunks = params.n_chunks;
-
-            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
-            llama_send_meta(ctx, &meta);
-            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
-        } else {
-            LOG_INF("%s: rank %d waiting 7 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
-            std::this_thread::sleep_for(std::chrono::milliseconds(7000));
-            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+        if (my_rank != 0) {
             if (llama_recv_meta(ctx, &meta) == -1) {
                 return { {}, -1.0, {}, {} };
             }
             tokens_size = meta.tokens_size;
-            n_chunks = meta.n_chunks;
-            if (!is_last_dev) {
-                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
-                llama_send_meta(ctx, &meta);
+            n_chunks    = meta.n_chunks;
+        }
+
+        if (!is_last_dev) {
+            if (my_rank == 0) {
+                meta.tokens_size = tokens_size;
+                meta.n_chunks    = params.n_chunks;
             }
+            llama_send_meta(ctx, &meta);
         }
     }
 
@@ -708,19 +702,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             meta.seq_id = batch.seq_id;
             meta.logits = batch.logits;
             meta.n_outputs = n_outputs;
+        }
 
-            if (n_world > 1) {
-                llama_send_meta(ctx, &meta);
-            }
-        } else {
-            if (n_world > 1) {
-                // comms: other ranks receive the batch meta data
+        // Chain forwarding pattern: consistent with tokens_size communication
+        if (n_world > 1) {
+            if (my_rank != 0) {
+                // Non-rank 0 devices receive batch meta data
                 if (llama_recv_meta(ctx, &meta) == -1) {
                     LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
                     return {tokens, -1.0, {}, {}};
                 }
                 if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
-            }
-
+            }
+        }
+
+        if (!is_last_dev) {
+            // Non-last devices forward the batch meta data
+            llama_send_meta(ctx, &meta);
         }
     }
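
Note on the pattern (illustrative, not part of the patch): both hunks replace a rank-0-sends / other-ranks-receive special case with the same chain-forwarding shape, and the first hunk additionally drops a hard-coded 7-second sleep before the receive. A minimal sketch of that shape, assuming llama_recv_meta() blocks until the predecessor rank has sent (which is what makes the fixed delay unnecessary), with error handling and the per-field copies elided:

    // Chain forwarding across a pipeline of n_world ranks:
    // rank 0 seeds the metadata, middle ranks receive and forward,
    // the last device only receives.
    sync_meta meta;
    if (my_rank != 0) {
        // every rank except the head receives from its predecessor
        if (llama_recv_meta(ctx, &meta) == -1) { /* abort */ }
        // ... copy received fields into local state ...
    }
    if (!is_last_dev) {
        if (my_rank == 0) {
            // the head fills the metadata from its local values,
            // e.g. meta.tokens_size = tokens_size;
        }
        // every rank except the tail forwards to its successor
        llama_send_meta(ctx, &meta);
    }

The same shape now appears twice: once up front for tokens_size/n_chunks, and once per batch for the batch metadata.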