Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-05 21:39:02 +00:00)
Fix batch metadata chain forwarding in distributed perplexity
parent da31acbe6a
commit 50a916f123

1 changed file with 20 additions and 23 deletions
@@ -525,26 +525,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params)
     if (n_world > 1) {
         sync_meta meta;
 
-        if (my_rank == 0) {
-            meta.tokens_size = tokens_size;
-            meta.n_chunks = params.n_chunks;
-
-            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
-            llama_send_meta(ctx, &meta);
-            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
-        } else {
-            LOG_INF("%s: rank %d waiting 7 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
-            std::this_thread::sleep_for(std::chrono::milliseconds(7000));
-            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+        if (my_rank != 0) {
             if (llama_recv_meta(ctx, &meta) == -1) {
                 return { {}, -1.0, {}, {} };
             }
             tokens_size = meta.tokens_size;
             n_chunks = meta.n_chunks;
-            if (!is_last_dev) {
-                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
-                llama_send_meta(ctx, &meta);
+        }
+
+        if (!is_last_dev) {
+            if (my_rank == 0) {
+                meta.tokens_size = tokens_size;
+                meta.n_chunks = params.n_chunks;
             }
+            llama_send_meta(ctx, &meta);
         }
     }
 
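Taken together, the hunk above replaces the old handshake, in which rank 0 sent the tokenization metadata while every other rank slept a fixed 7 seconds before receiving, with a chain: each non-zero rank blocks on a receive from its predecessor, and each device that is not the last forwards onward (rank 0 originates the data). What follows is a minimal, self-contained sketch of that control flow; the in-process wire, send_to_next, recv_from_prev, and Meta are hypothetical stand-ins for the socket transport behind llama_send_meta/llama_recv_meta, not prima.cpp's real API.

// Sketch of the chain-forwarding pattern, simulated in one process.
// All names below are hypothetical; in prima.cpp the transport is
// llama_send_meta()/llama_recv_meta() between neighboring ranks.
#include <cstdio>
#include <optional>

struct Meta {                     // stand-in for prima.cpp's sync_meta
    size_t tokens_size = 0;
    int    n_chunks    = 0;
};

static std::optional<Meta> wire;  // the link from rank i to rank i+1

static void send_to_next(const Meta & m) { wire = m; }

static bool recv_from_prev(Meta & m) {
    if (!wire) return false;      // predecessor never sent: report failure
    m = *wire;
    wire.reset();
    return true;
}

// What each rank runs: non-zero ranks block on the predecessor (no fixed
// sleep), and every device except the last forwards downstream.
static bool sync_chain(int my_rank, bool is_last_dev, Meta & m) {
    if (my_rank != 0 && !recv_from_prev(m)) {
        return false;             // mirrors the early `return { {}, -1.0, {}, {} }`
    }
    if (!is_last_dev) {
        send_to_next(m);          // rank 0 originates, middle ranks relay
    }
    return true;
}

int main() {
    const int n_world = 4;
    for (int rank = 0; rank < n_world; ++rank) {
        // rank 0 starts with example values; the others learn them off the wire
        Meta m = (rank == 0) ? Meta{1234, 8} : Meta{};
        if (!sync_chain(rank, rank == n_world - 1, m)) return 1;
        std::printf("rank %d: tokens_size=%zu n_chunks=%d\n", rank, m.tokens_size, m.n_chunks);
    }
    return 0;
}

The point is the ordering guarantee: a rank that tokenizes slowly simply delays its successor's blocking receive, whereas the old fixed delay failed whenever rank 0 needed more than 7 seconds.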
@@ -708,13 +702,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params)
             meta.seq_id = batch.seq_id;
             meta.logits = batch.logits;
             meta.n_outputs = n_outputs;
+        }
 
-            if (n_world > 1) {
-                llama_send_meta(ctx, &meta);
-            }
-        } else {
-            if (n_world > 1) {
-                // comms: other ranks receive the batch meta data
+        // Chain forwarding pattern: consistent with tokens_size communication
+        if (n_world > 1) {
+            if (my_rank != 0) {
+                // Non-rank 0 devices receive batch meta data
                 if (llama_recv_meta(ctx, &meta) == -1) {
                     LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
                     return {tokens, -1.0, {}, {}};
@@ -733,7 +726,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params)
                 }
                 if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
             }
+        }
 
+        if (!is_last_dev) {
+            // Non-last devices forward the batch meta data
+            llama_send_meta(ctx, &meta);
         }
     }
 
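The last two hunks apply the same discipline to the per-batch metadata: rank 0 fills meta from its local batch, every non-zero rank receives and deep-copies the fields into its own batch, and every non-last device forwards before decoding, so the one-time tokens_size sync and the per-batch sync now follow a single pattern. Below is a hedged sketch of the receive-then-apply step on a non-zero rank; BatchMeta and Batch are hypothetical simplifications of prima.cpp's sync_meta and llama_batch.

// Sketch of applying received batch metadata on a non-zero rank, after
// llama_recv_meta() succeeds. BatchMeta/Batch are hypothetical; the nullptr
// guard mirrors `if (meta.logits) { std::memcpy(...) }` in the diff.
#include <cstdint>
#include <cstring>
#include <vector>

struct BatchMeta {
    int32_t        n_tokens  = 0;
    const int8_t * logits    = nullptr;  // per-token "wants logits" flags; may be absent
    int32_t        n_outputs = 0;
};

struct Batch {
    std::vector<int8_t> logits;          // this rank's own copy of the flags
    int32_t             n_outputs = 0;
};

void apply_batch_meta(const BatchMeta & meta, Batch & batch) {
    batch.n_outputs = meta.n_outputs;
    batch.logits.resize(meta.n_tokens);  // value-initialized to 0
    if (meta.logits) {                   // absent flags leave the local copy zeroed
        std::memcpy(batch.logits.data(), meta.logits, meta.n_tokens * sizeof(int8_t));
    }
}

Copying into the local batch rather than aliasing the received buffer matches the patched code, which memcpys the flags first and only then hands the same meta object to llama_send_meta for the next rank.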