mirror of https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 19:09:06 +00:00
Fix batch metadata chain forwarding in distributed perplexity
This commit is contained in:
parent da31acbe6a
commit 50a916f123
1 changed file with 20 additions and 23 deletions
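The patch replaces a fixed 7-second sleep on the non-zero ranks with a receive-then-forward chain: rank 0 originates the metadata and sends it downstream, each intermediate rank receives it and relays it, and the last device only receives. Below is a minimal self-contained sketch of that pattern, simulated in-process with one thread per rank; the channel type and the example values are stand-ins for the socket transport behind llama_send_meta()/llama_recv_meta(), not the project's actual API.

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

struct sync_meta { size_t tokens_size = 0; int n_chunks = 0; };

// Blocking one-way channel standing in for the send/recv transport.
struct channel {
    std::mutex m;
    std::condition_variable cv;
    std::queue<sync_meta> q;
    void send(const sync_meta & v) {
        { std::lock_guard<std::mutex> l(m); q.push(v); }
        cv.notify_one();
    }
    sync_meta recv() {
        std::unique_lock<std::mutex> l(m);
        cv.wait(l, [&] { return !q.empty(); });
        sync_meta v = q.front(); q.pop();
        return v;
    }
};

int main() {
    const int n_world = 4;
    std::vector<channel> link(n_world - 1); // link[i]: rank i -> rank i+1

    auto run_rank = [&](int my_rank) {
        const bool is_last_dev = (my_rank == n_world - 1);
        size_t tokens_size = (my_rank == 0) ? 42 : 0; // rank 0 tokenized locally
        int    n_chunks    = (my_rank == 0) ? 7  : 0; // example values
        sync_meta meta;

        // Same structure as the first hunk below: non-zero ranks block on
        // recv (no sleep needed), and every non-last device forwards.
        if (my_rank != 0) {
            meta = link[my_rank - 1].recv();
            tokens_size = meta.tokens_size;
            n_chunks    = meta.n_chunks; // the last rank picks this up too
        }
        if (!is_last_dev) {
            if (my_rank == 0) {
                meta.tokens_size = tokens_size;
                meta.n_chunks    = n_chunks;
            }
            link[my_rank].send(meta);
        }
        std::printf("rank %d: tokens_size=%zu n_chunks=%d\n", my_rank, tokens_size, n_chunks);
    };

    std::vector<std::thread> ranks;
    for (int r = 0; r < n_world; r++) ranks.emplace_back(run_rank, r);
    for (auto & t : ranks) t.join();
    return 0;
}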
@@ -525,26 +525,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     if (n_world > 1) {
         sync_meta meta;
 
-        if (my_rank == 0) {
-            meta.tokens_size = tokens_size;
-            meta.n_chunks = params.n_chunks;
-
-            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
-            llama_send_meta(ctx, &meta);
-            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
-        } else {
-            LOG_INF("%s: rank %d waiting 7 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
-            std::this_thread::sleep_for(std::chrono::milliseconds(7000));
-            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+        if (my_rank != 0) {
             if (llama_recv_meta(ctx, &meta) == -1) {
                 return { {}, -1.0, {}, {} };
             }
             tokens_size = meta.tokens_size;
-            if (!is_last_dev) {
-                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
-                llama_send_meta(ctx, &meta);
-                n_chunks = meta.n_chunks;
-            }
+            n_chunks = meta.n_chunks;
         }
+
+        if (!is_last_dev) {
+            if (my_rank == 0) {
+                meta.tokens_size = tokens_size;
+                meta.n_chunks = params.n_chunks;
+            }
+            llama_send_meta(ctx, &meta);
+        }
     }
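Besides switching from the timed wait to a blocking llama_recv_meta(), this hunk moves n_chunks = meta.n_chunks; out of the !is_last_dev branch, so the last device in the chain now receives the chunk count as well. The failure path returns a sentinel aggregate; here is a hypothetical reduced form of the result type (the field names are guesses, only the four-member shape and the -1.0 sentinel are taken from the diff) to make that literal readable:

#include <vector>

// Hypothetical reduced form of the function's result type. Field names are
// guesses; the diff itself shows a four-member aggregate whose second member
// is the perplexity value, with -1.0 used as the "recv failed" sentinel.
struct results_perplexity {
    std::vector<int>   tokens; // llama_token in the real code
    double             ppl;
    std::vector<float> logits;
    std::vector<float> probs;
};

int main() {
    // The two sentinel shapes used by the patch:
    results_perplexity before_tokenize = { {}, -1.0, {}, {} };        // hunk 1: no tokens yet
    results_perplexity after_tokenize  = { {1, 2, 3}, -1.0, {}, {} }; // hunk 2: tokens already known
    return (before_tokenize.ppl == -1.0 && after_tokenize.ppl == -1.0) ? 0 : 1;
}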
@@ -708,13 +702,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             meta.seq_id = batch.seq_id;
             meta.logits = batch.logits;
             meta.n_outputs = n_outputs;
         }
 
         if (n_world > 1) {
             llama_send_meta(ctx, &meta);
         }
     } else {
-        if (n_world > 1) {
-            // comms: other ranks receive the batch meta data
-            if (llama_recv_meta(ctx, &meta) == -1) {
-                LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
-                return {tokens, -1.0, {}, {}};
+        // comms: other ranks receive the batch meta data
+        // Chain forwarding pattern: consistent with tokens_size communication
+        if (n_world > 1) {
+            if (my_rank != 0) {
+                // Non-rank 0 devices receive batch meta data
+                if (llama_recv_meta(ctx, &meta) == -1) {
+                    LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
+                    return {tokens, -1.0, {}, {}};
@@ -733,7 +726,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 }
                 if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
             }
+
+            if (!is_last_dev) {
+                // Non-last devices forward the batch meta data
+                llama_send_meta(ctx, &meta);
+            }
         }
     }
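On the receive side, the per-token output flags travel as a flat int8_t array, and the memcpy above sizes the copy by meta.n_tokens. A small sketch of that copy step, using trimmed stand-in structs (only the fields this hunk touches are modeled, not the real llama_batch/sync_meta):

#include <cstdint>
#include <cstring>
#include <vector>

// Trimmed stand-ins: only the fields exercised by this hunk.
struct meta_view {
    const int8_t * logits   = nullptr; // per-token "wants logits" flags
    int            n_tokens = 0;
};
struct batch_view {
    std::vector<int8_t> logits;
};

// Mirror of the memcpy in the hunk: take the received flags verbatim. The
// caller then forwards the same metadata if it is not the last device.
static void apply_recv_flags(batch_view & batch, const meta_view & meta) {
    if (meta.logits) {
        batch.logits.resize(meta.n_tokens);
        std::memcpy(batch.logits.data(), meta.logits, meta.n_tokens * sizeof(int8_t));
    }
}

int main() {
    const int8_t flags[4] = {0, 0, 0, 1}; // only the last position needs logits
    meta_view  meta = { flags, 4 };
    batch_view batch;
    apply_recv_flags(batch, meta);
    return batch.logits[3] == 1 ? 0 : 1;
}

The same is_last_dev check as in the tokens_size hunk keeps the chain uniform: every rank except the tail relays whatever it consumed.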