Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-10 20:24:34 +00:00)

Commit a3becb586a (parent 4b823775ec): Refactored the logic related to communication content and timing control

5 changed files with 474 additions and 134 deletions
@@ -524,9 +524,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         sync_meta meta;
         if (my_rank == 0) {
             meta.tokens_size = tokens_size;
-            llama_send_meta(ctx, &meta);
+            llama_send_meta(ctx, &meta, false);
         } else {
-            if (llama_recv_meta(ctx, &meta) == -1) {
+            if (llama_recv_meta(ctx, &meta, false) == -1) {
                 LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
                 return { {}, -1.0, {}, {} };
             }
@@ -534,7 +534,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
             } else {
                 tokens_size = meta.tokens_size;
-                llama_send_meta(ctx, &meta);
+                llama_send_meta(ctx, &meta, false);
             }
         }
     }
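Across these hunks the change is mechanical but meaningful: every metadata transfer now names its direction explicitly, with the new third argument (`false` here) apparently selecting the forward direction of the device chain, from rank 0 toward the last device. The relay shape itself is unchanged: rank 0 originates, each middle rank receives from its predecessor and forwards, and the last device only receives. Below is a minimal sketch of that forward-broadcast pattern; the in-memory `fwd_links` queues and the trimmed `sync_meta` are stand-ins for prima.cpp's socket layer, not its real API:

```cpp
#include <cstdio>
#include <deque>
#include <vector>

// Trimmed stand-in for prima.cpp's sync_meta: only the field the first hunk touches.
struct sync_meta { size_t tokens_size = 0; };

// One in-memory "socket" per directed edge of the chain: fwd_links[r] is rank r -> rank r+1.
static std::vector<std::deque<sync_meta>> fwd_links;

// Forward broadcast (reverse = false): rank 0 originates; every other rank
// receives from rank - 1 and, unless it is the last device, relays to rank + 1.
static void broadcast_meta_forward(int my_rank, int n_world, sync_meta & meta) {
    if (n_world == 1) return;                     // nothing to synchronize
    if (my_rank == 0) {
        fwd_links[0].push_back(meta);             // send to rank 1
    } else {
        meta = fwd_links[my_rank - 1].front();    // recv from rank - 1
        fwd_links[my_rank - 1].pop_front();
        if (my_rank != n_world - 1) {
            fwd_links[my_rank].push_back(meta);   // relay to rank + 1
        }
    }
}

int main() {
    const int n_world = 3;
    fwd_links.assign(n_world, {});
    sync_meta meta;
    for (int rank = 0; rank < n_world; ++rank) {  // the ranks run concurrently in reality
        if (rank == 0) meta.tokens_size = 4096;
        broadcast_meta_forward(rank, n_world, meta);
        std::printf("rank %d sees tokens_size = %zu\n", rank, meta.tokens_size);
    }
}
```

The same forward broadcast carries the `clear_kv_cache` flag in the next hunk, so every rank clears its slice of the KV cache at the same chunk boundary.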
@@ -628,19 +628,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const auto t_start = std::chrono::high_resolution_clock::now();

         {
+            // synchronize the KV cache clear signal across all ranks
             if (n_world > 1) {
                 sync_meta clear_meta;
                 clear_meta.clear_kv_cache = true;

                 if (my_rank == 0) {
-                    llama_send_meta(ctx, &clear_meta);
+                    llama_send_meta(ctx, &clear_meta, false);
                 } else {
-                    if (llama_recv_meta(ctx, &clear_meta) == -1) {
+                    if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
                         LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
                         return {tokens, -1.0, {}, {}};
                     }
                     if (!is_last_dev) {
-                        llama_send_meta(ctx, &clear_meta);
+                        llama_send_meta(ctx, &clear_meta, false);
                     }
                 }
             }
@@ -648,11 +649,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // clear the KV cache
             llama_kv_cache_clear(ctx);

-            sync_meta meta;
-
             for (int j = 0; j < num_batches; ++j) {
                 const int batch_start = start + j * n_batch;
                 const int batch_size  = std::min(end - batch_start, n_batch);
+                // used for communication of the batch meta data
+                sync_meta meta;

                 int n_outputs = 0;
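Moving the `sync_meta meta;` declaration from chunk scope into the batch loop gives every batch a freshly value-initialized struct, so a field populated for one batch (a pointer into a previous `llama_batch`, say) can no longer leak into the next send; it also pairs with the removal of `meta.chunk_start_pos` further down. A small self-contained illustration of the stale-state hazard that per-iteration scoping rules out (the two-field struct here is hypothetical, not prima.cpp's real `sync_meta`):

```cpp
#include <cstdio>

struct sync_meta { int n_tokens = 0; const int * pos = nullptr; };

int main() {
    static const int batch0_pos[] = {0, 1};

    sync_meta outer;                      // chunk-scoped, as before the refactor
    for (int j = 0; j < 2; ++j) {
        if (j == 0) { outer.n_tokens = 2; outer.pos = batch0_pos; }
        // j == 1 still carries batch 0's pointer unless every field is reassigned:
        std::printf("outer j=%d: n_tokens=%d pos=%p\n", j, outer.n_tokens, (const void *) outer.pos);
    }

    for (int j = 0; j < 2; ++j) {
        sync_meta fresh;                  // batch-scoped, as after the refactor
        if (j == 0) { fresh.n_tokens = 2; fresh.pos = batch0_pos; }
        std::printf("fresh j=%d: n_tokens=%d pos=%p\n", j, fresh.n_tokens, (const void *) fresh.pos);
    }
}
```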
@@ -689,38 +690,42 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                     // restore the original token in case it was set to BOS
                     tokens[seq_start] = token_org;
                 }
             }

             if (my_rank == 0) {
                 // Required batch info: Operation scale, KV cache location, Logits calculation location
+                meta.n_ctx = n_ctx;
+
+                // comms: now rank 0 needs to send the batch to the other ranks
                 meta.n_tokens = batch.n_tokens;
                 meta.pos = batch.pos;
                 meta.n_seq_id = batch.n_seq_id;
                 meta.seq_id = batch.seq_id;
                 meta.logits = batch.logits;
                 meta.all_pos_0 = batch.all_pos_0;
                 meta.all_pos_1 = batch.all_pos_1;
                 meta.n_outputs = n_outputs;
-                meta.chunk_start_pos = start;
-            }
-
-            // other ranks need to know batch info
-            {
-                if (n_world > 1) {
-                    meta.n_ctx = n_ctx;
-                    if (my_rank == 0) {
-                        llama_send_meta(ctx, &meta);
-                    } else {
-                        if (llama_recv_meta(ctx, &meta) == -1) {
-                            LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
-                            return {tokens, -1.0, {}, {}};
-                        }
-                        if (!is_last_dev) {
-                            llama_send_meta(ctx, &meta);
-                        }
-                    }
+                if (n_world > 1) {
+                    llama_send_meta(ctx, &meta, false); // reverse = false
+                }
+            } else {
+                if (n_world > 1) {
+                    // comms: the other ranks receive the batch meta data
+                    if (llama_recv_meta(ctx, &meta, false) == -1) {
+                        LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
+                        return {tokens, -1.0, {}, {}};
+                    }
+
+                    // copy the batch meta data into the local llama_batch
+                    if (meta.n_tokens > 0) {
+                        batch.n_tokens = meta.n_tokens;
+                        if (meta.pos)      { std::memcpy(batch.pos,      meta.pos,      meta.n_tokens * sizeof(llama_pos)); } // use n_tokens, not n_batch: n_tokens is the actual number of tokens in this batch
+                        if (meta.n_seq_id) { std::memcpy(batch.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t)); }
+                        if (meta.seq_id) {
+                            const int32_t n_seq_max = 1;
+                            for (int32_t i = 0; i < meta.n_tokens; ++i) {
+                                std::memcpy(batch.seq_id[i], meta.seq_id[i], n_seq_max * sizeof(llama_seq_id));
+                            }
+                        }
+                        if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
+                    }
                 }
             }

             if (llama_decode(ctx, batch)) {
                 LOG_INF("%s : failed to eval\n", __func__);
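Two details of the receive path are worth spelling out. First, on rank 0 the meta fields are aliases of the live `llama_batch` arrays, so sending `meta` effectively serializes the batch; on the receiving ranks the arrived arrays must be deep-copied into the locally allocated batch. Second, every copy is sized by `meta.n_tokens` rather than `n_batch`, because the final batch of a chunk is usually shorter than the batch capacity. A compilable sketch of that copy step, with simplified stand-ins for `llama_batch` and the received view (field names follow the diff; the struct layouts are assumptions):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

using llama_pos = int32_t;

// Received view: pointers into the arrived buffers plus the actual token count.
struct meta_view {
    int32_t     n_tokens = 0;
    llama_pos * pos      = nullptr;
    int32_t   * n_seq_id = nullptr;
    int8_t    * logits   = nullptr;
};

// Local batch: owns capacity for n_batch tokens.
struct local_batch {
    int32_t                n_tokens = 0;
    std::vector<llama_pos> pos;
    std::vector<int32_t>   n_seq_id;
    std::vector<int8_t>    logits;
    explicit local_batch(int32_t n_batch) : pos(n_batch), n_seq_id(n_batch), logits(n_batch) {}
};

// Deep-copy exactly n_tokens entries: sizing the memcpy by n_batch would copy
// garbage past the end of a short final batch.
static void apply_meta(const meta_view & meta, local_batch & batch) {
    if (meta.n_tokens <= 0) return;
    batch.n_tokens = meta.n_tokens;
    if (meta.pos)      std::memcpy(batch.pos.data(),      meta.pos,      meta.n_tokens * sizeof(llama_pos));
    if (meta.n_seq_id) std::memcpy(batch.n_seq_id.data(), meta.n_seq_id, meta.n_tokens * sizeof(int32_t));
    if (meta.logits)   std::memcpy(batch.logits.data(),   meta.logits,   meta.n_tokens * sizeof(int8_t));
}

int main() {
    llama_pos pos[3]    = {0, 1, 2};          // a short final batch: 3 of 512 slots used
    int8_t    logits[3] = {0, 0, 1};
    meta_view meta      = {3, pos, nullptr, logits};
    local_batch batch(512);
    apply_meta(meta, batch);
    std::printf("copied %d tokens, last pos = %d\n", batch.n_tokens, batch.pos[2]);
}
```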
@@ -753,8 +758,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         for (int seq = 0; seq < n_seq_batch; seq++) {
             const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);

-            int chunk_start_pos = meta.chunk_start_pos;
-            llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first;
+            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
             if (!params.logits_file.empty()) {
                 process_logits(logits_stream, n_vocab, all_logits,
                                tokens_data, n_ctx - 1 - first,
@@ -763,8 +767,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 process_logits(n_vocab, all_logits,
                                tokens_data, n_ctx - 1 - first,
                                workers, nll, nll2,
-                               logit_history.data() + chunk_start_pos + seq*n_ctx + first,
-                               prob_history.data() + chunk_start_pos + seq*n_ctx + first);
+                               logit_history.data() + start + seq*n_ctx + first,
+                               prob_history.data() + start + seq*n_ctx + first);
             }
             count += n_ctx - first - 1;
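With `chunk_start_pos` gone from the wire format, both call sites index the shared buffers with the locally computed `start`. Rank 0 and the last device walk the chunk loop in lockstep, so their local values necessarily agree and the extra synced field was redundant. A worked index, under the assumption (borrowed from upstream llama.cpp's perplexity loop, not confirmed by this diff) that `start = i * n_ctx`:

```cpp
#include <cstdio>

int main() {
    // Assumed example values; start = i * n_ctx mirrors upstream llama.cpp.
    const int n_ctx = 512, i = 3, seq = 1, first = 64;
    const int start = i * n_ctx;
    // Flat offset passed to tokens.data() / logit_history.data() / prob_history.data():
    const int idx = start + seq * n_ctx + first;
    std::printf("offset = %d\n", idx);  // 1536 + 512 + 64 = 2112
}
```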
@@ -778,8 +782,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
             }
         }
-        logits.clear();
-    }
+
+        if (n_world > 1) {
+            sync_meta done_meta;
+            done_meta.chunk_done = true;
+
+            if (is_last_dev) {
+                // Last device sends completion signal upstream (reverse direction)
+                LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
+                llama_send_meta(ctx, &done_meta, true); // reverse = true
+            } else if (my_rank == 0) {
+                // Rank 0 waits for completion signal from downstream
+                LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
+                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
+                    LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
+                    return {tokens, -1.0, {}, {}};
+                }
+                LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
+            } else {
+                // Intermediate ranks: receive from downstream, relay upstream
+                LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
+                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
+                    LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
+                    return {tokens, -1.0, {}, {}};
+                }
+                LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
+                llama_send_meta(ctx, &done_meta, true); // reverse = true
+            }
+        }
+        logits.clear();
+    }
     LOG("\n");
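The `chunk_done` handshake is the mirror image of the forward broadcast: the signal originates at the last device and is relayed rank by rank back to rank 0, which blocks until it arrives, so the next chunk starts only once every downstream rank has finished the current one. A compact sketch of the three roles over the same in-memory simulation as above (again an illustration rather than prima.cpp's socket code, and assuming `n_world > 1`):

```cpp
#include <cstdio>
#include <deque>
#include <vector>

struct sync_meta { bool chunk_done = false; };

// rev_links[r]: channel from rank r+1 back to rank r (the reverse direction).
static std::vector<std::deque<sync_meta>> rev_links;

// Returns false on protocol failure (missing or malformed signal).
static bool relay_chunk_done(int my_rank, int n_world) {
    sync_meta done_meta;
    if (my_rank == n_world - 1) {                     // last device: originate upstream
        done_meta.chunk_done = true;
        rev_links[my_rank - 1].push_back(done_meta);
    } else {
        if (rev_links[my_rank].empty()) return false; // nothing arrived from downstream
        done_meta = rev_links[my_rank].front();
        rev_links[my_rank].pop_front();
        if (!done_meta.chunk_done) return false;
        if (my_rank != 0) {                           // intermediate: relay further upstream
            rev_links[my_rank - 1].push_back(done_meta);
        }
    }
    return true;
}

int main() {
    const int n_world = 4;
    rev_links.assign(n_world, {});
    // In reality each rank runs this concurrently; here we run last-to-first.
    for (int rank = n_world - 1; rank >= 0; --rank) {
        std::printf("rank %d: chunk_done %s\n", rank, relay_chunk_done(rank, n_world) ? "ok" : "FAILED");
    }
}
```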
@@ -795,12 +827,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         } else {
             LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
         }

+        llama_batch_free(batch);
         return {tokens, ppl, logit_history, prob_history};
     }

     llama_batch_free(batch);
-
     return {};
 }
@@ -2078,10 +2109,11 @@ int main(int argc, char ** argv) {
     params.n_ctx = 512;
     params.logits_all = true;
     params.escape = false;
+    params.is_perplexity_eval = true;

     if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
         return 1;
     }

     uint32_t n_world = params.n_world;
     uint32_t my_rank = params.rank;
@@ -2141,7 +2173,6 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     llama_init_result llama_init = llama_init_from_gpt_params(params);
-
     // update rank and world size if any devices removed
     my_rank = params.rank;
     n_world = params.n_world;

@@ -2189,14 +2220,34 @@ int main(int argc, char ** argv) {

     LOG("\n");

     if (is_last_dev) {
         llama_perf_context_print(ctx);
         write_logfile(ctx, params, model, results);
     }

+    if (my_rank == 0) {
+        llama_perf_context_print(ctx);
+    }
+
+    if (n_world > 1) {
+        LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
+
+        if (my_rank == 0) {
+            llama_free_sockets(ctx, nullptr);
+        }
+
+        if (my_rank != 0 && signal_thread.joinable()) {
+            signal_thread.join();
+        }
+
+        if (stop_signal) {
+            LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, stop_signal);
+            delete[] stop_signal;
+        }
+    }
+
     llama_free(ctx);
     llama_free_model(model);

     llama_backend_free();

     return 0;
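The new shutdown block makes the exit path explicit: rank 0 actively tears down the links (`llama_free_sockets`), while every other rank has a listener thread (`signal_thread` in the diff) parked on a blocking receive and must join it, then free the received `stop_signal` buffer, before the context is destroyed. A generic sketch of that join-before-free discipline, with the blocking socket read mocked by a `std::promise` (the names mirror the diff; everything else is illustrative):

```cpp
#include <cstdio>
#include <cstring>
#include <future>
#include <thread>

int main() {
    // Mock of the blocking socket read that delivers the master's stop message.
    std::promise<char *> stop_source;
    std::future<char *>  stop_future = stop_source.get_future();

    char * stop_signal = nullptr;

    // Non-zero ranks park a listener thread on the blocking wait.
    std::thread signal_thread([&] {
        stop_signal = stop_future.get();  // blocks until the "master" signals
    });

    // ... the perplexity evaluation would run here ...

    char * msg = new char[5];
    std::strcpy(msg, "stop");
    stop_source.set_value(msg);           // deliver the shutdown message

    // Join the listener before freeing anything it might still touch,
    // then release the received buffer, as the diff's delete[] does.
    if (signal_thread.joinable()) {
        signal_thread.join();
    }
    if (stop_signal) {
        std::printf("Cleanup signal received: %s\n", stop_signal);
        delete[] stop_signal;
    }
    return 0;  // only now is it safe to free ctx / model
}
```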