Refactored the logic related to communication content and timing control

This commit is contained in:
leeetao 2025-06-24 10:40:37 +00:00
parent 4b823775ec
commit a3becb586a
5 changed files with 474 additions and 134 deletions

View file

@ -1992,6 +1992,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.keep_out_in_metal = params.keep_out_in_metal; cparams.keep_out_in_metal = params.keep_out_in_metal;
cparams.n_gpu_layers = params.n_gpu_layers; cparams.n_gpu_layers = params.n_gpu_layers;
cparams.n_cycles = params.n_cycles; cparams.n_cycles = params.n_cycles;
cparams.is_perplexity_eval= params.is_perplexity_eval;
std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), cparams.n_layer_window); std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), cparams.n_layer_window);
if (cparams.master_ip != nullptr) { if (cparams.master_ip != nullptr) {

View file

@ -178,6 +178,8 @@ struct gpt_params {
int32_t yarn_orig_ctx = 0; // YaRN original context length int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold float defrag_thold = -1.0f; // KV cache defragmentation threshold
bool is_perplexity_eval;
struct cpu_params cpuparams; struct cpu_params cpuparams;
struct cpu_params cpuparams_batch; struct cpu_params cpuparams_batch;
struct cpu_params draft_cpuparams; struct cpu_params draft_cpuparams;

View file

@ -524,9 +524,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
sync_meta meta; sync_meta meta;
if (my_rank == 0) { if (my_rank == 0) {
meta.tokens_size = tokens_size; meta.tokens_size = tokens_size;
llama_send_meta(ctx, &meta); llama_send_meta(ctx, &meta, false);
} else { } else {
if (llama_recv_meta(ctx, &meta) == -1) { if (llama_recv_meta(ctx, &meta, false) == -1) {
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank); LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
return { {}, -1.0, {}, {} }; return { {}, -1.0, {}, {} };
} }
@ -534,7 +534,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!"); GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
} else { } else {
tokens_size = meta.tokens_size; tokens_size = meta.tokens_size;
llama_send_meta(ctx, &meta); llama_send_meta(ctx, &meta, false);
} }
} }
} }
@ -628,19 +628,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
{ {
// synchronize the KV cache clear signal across all ranks
if (n_world > 1) { if (n_world > 1) {
sync_meta clear_meta; sync_meta clear_meta;
clear_meta.clear_kv_cache = true; clear_meta.clear_kv_cache = true;
if (my_rank == 0) { if (my_rank == 0) {
llama_send_meta(ctx, &clear_meta); llama_send_meta(ctx, &clear_meta, false);
} else { } else {
if (llama_recv_meta(ctx, &clear_meta) == -1) { if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank); LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}}; return {tokens, -1.0, {}, {}};
} }
if (!is_last_dev) { if (!is_last_dev) {
llama_send_meta(ctx, &clear_meta); llama_send_meta(ctx, &clear_meta, false);
} }
} }
} }
@ -648,11 +649,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
sync_meta meta;
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch); const int batch_size = std::min(end - batch_start, n_batch);
// used for communication of the batch meta data
sync_meta meta;
int n_outputs = 0; int n_outputs = 0;
@ -689,38 +690,42 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// restore the original token in case it was set to BOS // restore the original token in case it was set to BOS
tokens[seq_start] = token_org; tokens[seq_start] = token_org;
} }
}
// comms: now rank 0 need to send the batch to other ranks
if (my_rank == 0) {
// Required batch info: Operation scale, KV cache location, Logits calculation location
meta.n_ctx = n_ctx;
meta.n_tokens = batch.n_tokens; meta.n_tokens = batch.n_tokens;
meta.pos = batch.pos; meta.pos = batch.pos;
meta.n_seq_id = batch.n_seq_id;
meta.seq_id = batch.seq_id;
meta.logits = batch.logits; meta.logits = batch.logits;
meta.all_pos_0 = batch.all_pos_0;
meta.all_pos_1 = batch.all_pos_1;
meta.n_outputs = n_outputs; meta.n_outputs = n_outputs;
meta.chunk_start_pos = start;
}
// other ranks need to know batch info
{
if (n_world > 1) { if (n_world > 1) {
meta.n_ctx = n_ctx; llama_send_meta(ctx, &meta, false); // reverse = false
if (my_rank == 0) {
llama_send_meta(ctx, &meta);
} else {
if (llama_recv_meta(ctx, &meta) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
if (!is_last_dev) {
llama_send_meta(ctx, &meta);
}
}
} }
} } else {
if (n_world > 1) {
// comms: other ranks receive the batch meta data
if (llama_recv_meta(ctx, &meta, false) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
// copy the batch meta data to the llama_batch
if (meta.n_tokens > 0) {
batch.n_tokens = meta.n_tokens;
if (meta.pos) { std::memcpy(batch.pos, meta.pos, meta.n_tokens * sizeof(llama_pos)); } // use n_tokens instead of n_batch, n_tokens is the actual number of tokens in the batch
if (meta.n_seq_id) { std::memcpy(batch.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t)); }
if (meta.seq_id) {
const int32_t n_seq_max = 1;
for (int32_t i = 0; i < meta.n_tokens; ++i) {
std::memcpy(batch.seq_id[i], meta.seq_id[i], n_seq_max * sizeof(llama_seq_id));
}
}
if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
}
}
}
if (llama_decode(ctx, batch)) { if (llama_decode(ctx, batch)) {
LOG_INF("%s : failed to eval\n", __func__); LOG_INF("%s : failed to eval\n", __func__);
@ -753,8 +758,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
for (int seq = 0; seq < n_seq_batch; seq++) { for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
int chunk_start_pos = meta.chunk_start_pos; llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first;
if (!params.logits_file.empty()) { if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits, process_logits(logits_stream, n_vocab, all_logits,
tokens_data, n_ctx - 1 - first, tokens_data, n_ctx - 1 - first,
@ -763,8 +767,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
process_logits(n_vocab, all_logits, process_logits(n_vocab, all_logits,
tokens_data, n_ctx - 1 - first, tokens_data, n_ctx - 1 - first,
workers, nll, nll2, workers, nll, nll2,
logit_history.data() + chunk_start_pos + seq*n_ctx + first, logit_history.data() + start + seq*n_ctx + first,
prob_history.data() + chunk_start_pos + seq*n_ctx + first); prob_history.data() + start + seq*n_ctx + first);
} }
count += n_ctx - first - 1; count += n_ctx - first - 1;
@ -778,8 +782,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
} }
} }
logits.clear();
}
if (n_world > 1) {
sync_meta done_meta;
done_meta.chunk_done = true;
if (is_last_dev) {
// Last device sends completion signal upstream (reverse direction)
LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
} else if (my_rank == 0) {
// Rank 0 waits for completion signal from downstream
LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
} else {
// Intermediate ranks: receive from downstream, relay upstream
LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
}
} }
logits.clear();
} }
LOG("\n"); LOG("\n");
@ -795,12 +827,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
} else { } else {
LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
} }
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history}; return {tokens, ppl, logit_history, prob_history};
} }
llama_batch_free(batch); llama_batch_free(batch);
return {}; return {};
} }
@ -2078,10 +2109,11 @@ int main(int argc, char ** argv) {
params.n_ctx = 512; params.n_ctx = 512;
params.logits_all = true; params.logits_all = true;
params.escape = false; params.escape = false;
params.is_perplexity_eval = true;
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
return 1; return 1;
} }
uint32_t n_world = params.n_world; uint32_t n_world = params.n_world;
uint32_t my_rank = params.rank; uint32_t my_rank = params.rank;
@ -2141,7 +2173,6 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
llama_init_result llama_init = llama_init_from_gpt_params(params); llama_init_result llama_init = llama_init_from_gpt_params(params);
// update rank and world size if any devices removed
my_rank = params.rank; my_rank = params.rank;
n_world = params.n_world; n_world = params.n_world;
@ -2189,14 +2220,34 @@ int main(int argc, char ** argv) {
LOG("\n"); LOG("\n");
if (is_last_dev) { if (is_last_dev) {
llama_perf_context_print(ctx);
write_logfile(ctx, params, model, results); write_logfile(ctx, params, model, results);
}
if (my_rank == 0) {
llama_perf_context_print(ctx);
}
if (n_world > 1) {
LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
if (my_rank == 0) {
llama_free_sockets(ctx, nullptr);
}
if (my_rank != 0 && signal_thread.joinable()) {
signal_thread.join();
}
if (stop_signal) {
LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, stop_signal);
delete[] stop_signal;
}
} }
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
llama_backend_free(); llama_backend_free();
return 0; return 0;

View file

@ -58,12 +58,15 @@ struct sync_meta {
int8_t * logits = nullptr; int8_t * logits = nullptr;
llama_pos * pos = nullptr; llama_pos * pos = nullptr;
int32_t * n_seq_id = nullptr;
llama_seq_id ** seq_id = nullptr;
llama_pos all_pos_0; llama_pos all_pos_0;
llama_pos all_pos_1; llama_pos all_pos_1;
uint32_t n_ctx = 0; uint32_t n_ctx = 0;
int chunk_start_pos; // used for perplexity evaluation
int32_t n_outputs; // Used to pass the number of logits to be outputted int32_t n_outputs;
bool chunk_done = false; // signal that the chunk is done
// signal to clear the kv cache // signal to clear the kv cache
bool clear_kv_cache= false; bool clear_kv_cache= false;
@ -389,6 +392,7 @@ extern "C" {
int32_t n_threads; // number of threads to use for generation int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing int32_t n_threads_batch; // number of threads to use for batch processing
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings enum llama_attention_type attention_type; // attention type to use for embeddings
@ -422,6 +426,7 @@ extern "C" {
// currently works only with CPU execution // currently works only with CPU execution
ggml_abort_callback abort_callback; ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
bool is_perplexity_eval; // whether to run in perplexity evaluation mode
}; };
// model quantization parameters // model quantization parameters
@ -502,8 +507,8 @@ extern "C" {
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank); LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg); LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta); LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta); LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set); LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info); LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args); LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);

View file

@ -2607,6 +2607,8 @@ struct llama_cparams {
int n_threads; // number of threads to use for generation int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing int n_threads_batch; // number of threads to use for batch processing
bool is_perplexity_eval;
float rope_freq_base; float rope_freq_base;
float rope_freq_scale; float rope_freq_scale;
@ -3490,6 +3492,7 @@ struct llama_context {
// sockets // sockets
std::string master_ip = "localhost"; std::string master_ip = "localhost";
std::string next_node_ip = "localhost"; std::string next_node_ip = "localhost";
std::string prev_node_ip = "localhost";
uint32_t data_port = 9000; uint32_t data_port = 9000;
uint32_t signal_port = 10000; uint32_t signal_port = 10000;
zmq::context_t * sock_context = nullptr; zmq::context_t * sock_context = nullptr;
@ -3497,6 +3500,9 @@ struct llama_context {
zmq::socket_t * recv_socket = nullptr; zmq::socket_t * recv_socket = nullptr;
zmq::socket_t * master_socket = nullptr; zmq::socket_t * master_socket = nullptr;
zmq::socket_t * signal_socket = nullptr; zmq::socket_t * signal_socket = nullptr;
// Add these for reverse communication
zmq::socket_t * reverse_send_socket = nullptr; // Reverse: Rank i -> Rank i-1
zmq::socket_t * reverse_recv_socket = nullptr; // Reverse: Rank i <- Rank i+1
}; };
struct llama_lora_weight { struct llama_lora_weight {
@ -17866,30 +17872,126 @@ struct input_tensors {
ggml_tensor * inp_pos; ggml_tensor * inp_pos;
}; };
void llama_send_meta(llama_context * ctx, struct sync_meta * meta) { void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
GGML_ASSERT(ctx != nullptr); GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(meta != nullptr); GGML_ASSERT(meta != nullptr);
zmq::socket_t * send_socket = ctx->send_socket; zmq::socket_t * send_socket = reverse ? ctx->reverse_send_socket : ctx->send_socket;
GGML_ASSERT(send_socket != nullptr); GGML_ASSERT(send_socket != nullptr);
try { try {
std::vector<zmq::message_t> send_msgs; std::vector<zmq::message_t> send_msgs;
GGML_ASSERT(meta->n_tokens != 0); // Handle chunk_done signal
send_msgs.emplace_back("n_tokens", strlen("n_tokens")); if (meta->chunk_done) {
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens)); send_msgs.emplace_back("chunk_done", strlen("chunk_done"));
send_msgs.emplace_back("1", 1);
if (meta->pos != nullptr) { zmq::send_multipart(*send_socket, send_msgs);
send_msgs.emplace_back("pos", strlen("pos")); return;
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
} }
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0")); if (meta->clear_kv_cache) {
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0)); send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
send_msgs.emplace_back("1", 1);
return;
}
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1")); if (meta->kv_seq_rm) {
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1)); send_msgs.emplace_back("kv_seq_rm", strlen("kv_seq_rm"));
send_msgs.emplace_back(&(meta->rm_seq_id), sizeof(meta->rm_seq_id));
send_msgs.emplace_back(&(meta->rm_p0), sizeof(meta->rm_p0));
send_msgs.emplace_back(&(meta->rm_p1), sizeof(meta->rm_p1));
zmq::send_multipart(*send_socket, send_msgs);
return;
}
if (meta->kv_seq_add) {
send_msgs.emplace_back("kv_seq_add", strlen("kv_seq_add"));
send_msgs.emplace_back(&(meta->add_seq_id), sizeof(meta->add_seq_id));
send_msgs.emplace_back(&(meta->add_p0), sizeof(meta->add_p0));
send_msgs.emplace_back(&(meta->add_p1), sizeof(meta->add_p1));
send_msgs.emplace_back(&(meta->add_delta), sizeof(meta->add_delta));
zmq::send_multipart(*send_socket, send_msgs);
return;
}
if (meta->kv_seq_cp) {
send_msgs.emplace_back("kv_seq_cp", strlen("kv_seq_cp"));
send_msgs.emplace_back(&(meta->cp_src_seq_id), sizeof(meta->cp_src_seq_id));
send_msgs.emplace_back(&(meta->cp_dst_seq_id), sizeof(meta->cp_dst_seq_id));
send_msgs.emplace_back(&(meta->cp_p0), sizeof(meta->cp_p0));
send_msgs.emplace_back(&(meta->cp_p1), sizeof(meta->cp_p1));
zmq::send_multipart(*send_socket, send_msgs);
return;
}
if (meta->kv_seq_div) {
send_msgs.emplace_back("kv_seq_div", strlen("kv_seq_div"));
send_msgs.emplace_back(&(meta->div_seq_id), sizeof(meta->div_seq_id));
send_msgs.emplace_back(&(meta->div_p0), sizeof(meta->div_p0));
send_msgs.emplace_back(&(meta->div_p1), sizeof(meta->div_p1));
send_msgs.emplace_back(&(meta->div_factor), sizeof(meta->div_factor));
zmq::send_multipart(*send_socket, send_msgs);
return;
}
if (meta->tokens_size > 0) {
send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
}
if (meta->n_tokens > 0) {
send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
send_msgs.emplace_back("n_outputs", strlen("n_outputs"));
send_msgs.emplace_back(&(meta->n_outputs), sizeof(meta->n_outputs));
// send_msgs.emplace_back("chunk_start_pos", strlen("chunk_start_pos"));
// send_msgs.emplace_back(&(meta->chunk_start_pos), sizeof(meta->chunk_start_pos));
send_msgs.emplace_back("n_ctx", strlen("n_ctx"));
send_msgs.emplace_back(&(meta->n_ctx), sizeof(meta->n_ctx));
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
// batch.pos
if (meta->pos != nullptr) {
send_msgs.emplace_back("pos", strlen("pos"));
send_msgs.emplace_back(meta->pos, meta->n_tokens * sizeof(llama_pos));
}
// batch.n_seq_id
if (meta->n_seq_id != nullptr) {
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
}
// batch.seq_id
if (meta->seq_id != nullptr) {
const int32_t n_tokens = meta->n_tokens;
const int32_t n_seq_max = 1;
std::vector<llama_seq_id> flat_seq_ids;
flat_seq_ids.reserve(n_tokens * n_seq_max);
for (int32_t i = 0; i < n_tokens; ++i) {
for (int32_t j = 0; j < n_seq_max; ++j) {
flat_seq_ids.push_back(meta->seq_id[i][j]);
}
}
send_msgs.emplace_back("seq_id", strlen("seq_id"));
send_msgs.emplace_back(flat_seq_ids.data(), flat_seq_ids.size() * sizeof(llama_seq_id));
}
// batch.logits
if (meta->logits != nullptr) {
send_msgs.emplace_back("logits", strlen("logits"));
send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
}
}
if (!send_msgs.empty()) { if (!send_msgs.empty()) {
zmq::send_multipart(*send_socket, send_msgs); zmq::send_multipart(*send_socket, send_msgs);
@ -17899,12 +18001,16 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
} }
} }
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) { int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
ctx->recv_socket->set(zmq::sockopt::rcvtimeo, 1000); zmq::socket_t * recv_socket = reverse ? ctx->reverse_recv_socket : ctx->recv_socket;
GGML_ASSERT(recv_socket != nullptr);
recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
std::vector<zmq::message_t> recv_msgs; std::vector<zmq::message_t> recv_msgs;
if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) { if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
recv_socket->set(zmq::sockopt::rcvtimeo, -1); // Reset timeout to blocking mode before returning error
return -1; return -1;
} }
@ -17913,6 +18019,12 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
const std::string cmd = recv_msgs[0].to_string(); const std::string cmd = recv_msgs[0].to_string();
size_t idx = 1; size_t idx = 1;
// Handle chunk_done signal
if (cmd == "chunk_done") {
meta->chunk_done = true;
return 0;
}
if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) { if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
meta->clear_kv_cache = true; meta->clear_kv_cache = true;
return 0; return 0;
@ -17953,29 +18065,82 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
return 0; return 0;
} }
if (recv_msgs.size() % 2 != 0) {
LLAMA_LOG_ERROR("Invalid message format: odd number of messages\n");
return -1;
}
for (size_t i = 0; i < recv_msgs.size(); i += 2) { for (size_t i = 0; i < recv_msgs.size(); i += 2) {
std::string key = recv_msgs[i].to_string(); if (i + 1 >= recv_msgs.size()) break;
std::string key = recv_msgs[i].to_string();
zmq::message_t & data_msg = recv_msgs[i + 1]; zmq::message_t & data_msg = recv_msgs[i + 1];
if (key == "n_tokens") { if (key == "tokens_size") {
GGML_ASSERT(data_msg.size() == sizeof(meta->tokens_size));
std::memcpy(&(meta->tokens_size), data_msg.data(), sizeof(meta->tokens_size));
}
else if (key == "n_tokens") {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens)); GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens)); std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
} }
else if (key == "n_outputs") {
if (key == "pos") { GGML_ASSERT(data_msg.size() == sizeof(meta->n_outputs));
meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos)); std::memcpy(&(meta->n_outputs), data_msg.data(), sizeof(meta->n_outputs));
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
} }
// else if (key == "chunk_start_pos") {
if (key == "all_pos_0") { // GGML_ASSERT(data_msg.size() == sizeof(meta->chunk_start_pos));
// std::memcpy(&(meta->chunk_start_pos), data_msg.data(), sizeof(meta->chunk_start_pos));
// }
else if (key == "n_ctx") {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_ctx));
std::memcpy(&(meta->n_ctx), data_msg.data(), sizeof(meta->n_ctx));
}
else if (key == "all_pos_0") {
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0)); GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0)); std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
} }
else if (key == "all_pos_1") {
if (key == "all_pos_1") {
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1)); GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1)); std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
} }
else if (key == "pos") {
if (meta->n_tokens > 0) {
meta->pos = (llama_pos *) malloc(meta->n_tokens * sizeof(llama_pos));
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_pos));
std::memcpy(meta->pos, data_msg.data(), meta->n_tokens * sizeof(llama_pos));
}
}
else if (key == "n_seq_id") {
if (meta->n_tokens > 0) {
meta->n_seq_id = (int32_t *) malloc(data_msg.size());
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
}
}
// batch.seq_id
else if (key == "seq_id") {
if (meta->n_tokens > 0) {
const int32_t n_tokens = meta->n_tokens;
const int32_t n_seq_max = 1;
GGML_ASSERT(data_msg.size() == (size_t)n_tokens * n_seq_max * sizeof(llama_seq_id));
meta->seq_id = (llama_seq_id **) malloc(n_tokens * sizeof(llama_seq_id *));
const llama_seq_id * flat_data = (const llama_seq_id *)data_msg.data();
for (int32_t token_idx = 0; token_idx < n_tokens; ++token_idx) {
meta->seq_id[token_idx] = (llama_seq_id *) malloc(n_seq_max * sizeof(llama_seq_id));
std::memcpy(meta->seq_id[token_idx], flat_data + token_idx * n_seq_max, n_seq_max * sizeof(llama_seq_id));
}
}
}
else if (key == "logits") {
if (meta->n_tokens > 0) {
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
}
}
} }
return 0; return 0;
} }
@ -18190,6 +18355,8 @@ static int llama_decode_internal(
const uint32_t n_world = cparams.n_world; const uint32_t n_world = cparams.n_world;
const uint32_t my_rank = cparams.rank; const uint32_t my_rank = cparams.rank;
const bool is_perplexity_mode = cparams.is_perplexity_eval;
const uint32_t n_tokens_all = batch_all.n_tokens; const uint32_t n_tokens_all = batch_all.n_tokens;
const int64_t n_embd = hparams.n_embd; // used for reserving embeddings space size const int64_t n_embd = hparams.n_embd; // used for reserving embeddings space size
@ -18243,73 +18410,76 @@ static int llama_decode_internal(
} }
// prepare for send and receive of metadata // prepare for send and receive of metadata
sync_meta meta; if (!is_perplexity_mode) {
meta.n_ctx = cparams.n_ctx; sync_meta meta;
bool is_last_dev = (my_rank == n_world - 1); meta.n_ctx = cparams.n_ctx;
bool is_last_dev = (my_rank == n_world - 1);
if (my_rank != 0) { if (my_rank != 0) {
if (llama_recv_meta(&lctx, &meta) == -1) { if (llama_recv_meta(&lctx, &meta, false) == -1) {
return -1; return -1;
} }
if (meta.n_tokens > 0) { if (meta.n_tokens > 0) {
batch_all.n_tokens = meta.n_tokens; batch_all.n_tokens = meta.n_tokens;
if (meta.pos != nullptr) { if (meta.pos != nullptr) {
batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos)); batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos));
std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos)); std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos));
}
batch_all.all_pos_0 = meta.all_pos_0;
batch_all.all_pos_1 = meta.all_pos_1;
}
if (kv_cache_op(meta.clear_kv_cache,
[&]{ llama_kv_cache_clear (&lctx); },
[&]{ llama_send_kv_cache_clear (&lctx); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_clear\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_rm,
[&]{ llama_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); },
[&]{ llama_send_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_rm\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_add,
[&]{ llama_kv_cache_seq_add (&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
[&]{ llama_send_kv_cache_seq_add(&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_add\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_cp,
[&]{ llama_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
[&]{ llama_send_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_cp\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_div,
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_div\n", __func__);
return -1;
} }
batch_all.all_pos_0 = meta.all_pos_0;
batch_all.all_pos_1 = meta.all_pos_1;
} }
if (kv_cache_op(meta.clear_kv_cache,
[&]{ llama_kv_cache_clear (&lctx); },
[&]{ llama_send_kv_cache_clear (&lctx); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_clear\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_rm, if (!is_last_dev) {
[&]{ llama_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); }, meta.n_tokens = batch_all.n_tokens;
[&]{ llama_send_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); }, meta.pos = batch_all.pos;
is_last_dev)) { meta.all_pos_0 = batch_all.all_pos_0;
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_rm\n", __func__); meta.all_pos_1 = batch_all.all_pos_1;
return -1; llama_send_meta(&lctx, &meta, false);
} }
if (kv_cache_op(meta.kv_seq_add,
[&]{ llama_kv_cache_seq_add (&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
[&]{ llama_send_kv_cache_seq_add(&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_add\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_cp,
[&]{ llama_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
[&]{ llama_send_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_cp\n", __func__);
return -1;
}
if (kv_cache_op(meta.kv_seq_div,
[&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
[&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); },
is_last_dev)) {
LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_div\n", __func__);
return -1;
}
} }
if (!is_last_dev) {
meta.n_tokens = batch_all.n_tokens;
meta.pos = batch_all.pos;
meta.all_pos_0 = batch_all.all_pos_0;
meta.all_pos_1 = batch_all.all_pos_1;
llama_send_meta(&lctx, &meta);
}
lctx.sbatch.from_batch(batch_all, n_embd, lctx.sbatch.from_batch(batch_all, n_embd,
/* simple_split */ !kv_self.recurrent, /* simple_split */ !kv_self.recurrent,
@ -20281,6 +20451,7 @@ struct llama_context_params llama_context_default_params() {
/*.no_perf =*/ true, /*.no_perf =*/ true,
/*.abort_callback =*/ nullptr, /*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr, /*.abort_callback_data =*/ nullptr,
/*.is_perplexity_mode =*/ false
}; };
return result; return result;
@ -20444,10 +20615,15 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
} }
ctx->sock_context = new zmq::context_t(2); ctx->sock_context = new zmq::context_t(2);
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull); ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull); ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
// Reverse pipeline sockets (new - for barriers)
ctx->reverse_send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
ctx->reverse_recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
if (my_rank != 0 && my_rank != (n_world - 1)) { if (my_rank != 0 && my_rank != (n_world - 1)) {
ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
} else if (my_rank == (n_world - 1)) { } else if (my_rank == (n_world - 1)) {
@ -20455,18 +20631,38 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
} }
const uint32_t next_rank = (my_rank + 1) % n_world; const uint32_t next_rank = (my_rank + 1) % n_world;
const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port)); std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port));
std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port)); std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port)); std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port));
std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port)); std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port));
// Reverse pipeline endpoints (new)
// Use a different port offset for reverse communication to avoid conflicts
const uint32_t reverse_port_offset = 1000;
std::string reverse_recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port + reverse_port_offset));
std::string reverse_send_endp = "tcp://" + ctx->prev_node_ip + ":" + std::to_string(map_rank_to_port(prev_rank, ctx->data_port + reverse_port_offset));
try { try {
ctx->recv_socket->bind(recv_endp); ctx->recv_socket->bind(recv_endp);
ctx->signal_socket->bind(signal_endp); ctx->signal_socket->bind(signal_endp);
ctx->send_socket->connect(send_endp); ctx->send_socket->connect(send_endp);
if (ctx->master_socket && my_rank != (n_world - 1)) { if (ctx->master_socket && my_rank != (n_world - 1)) {
ctx->master_socket->connect(master_endp); ctx->master_socket->connect(master_endp);
} }
// Setup reverse pipeline sockets
if (my_rank > 0) {
// All ranks except rank 0 can send to previous rank
ctx->reverse_send_socket->connect(reverse_send_endp);
}
if (my_rank < n_world - 1) {
// All ranks except last rank can receive from next rank
ctx->reverse_recv_socket->bind(reverse_recv_endp);
}
} catch (const zmq::error_t &e) { } catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what()); LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
exit(1); exit(1);
@ -20740,6 +20936,7 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
const uint32_t my_rank = ctx->cparams.rank; const uint32_t my_rank = ctx->cparams.rank;
// to adapt to the new topology, use old next_rank // to adapt to the new topology, use old next_rank
const uint32_t next_rank = ctx->cparams.original_next_rank; const uint32_t next_rank = ctx->cparams.original_next_rank;
const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
if (n_world == 1) { if (n_world == 1) {
return; return;
@ -20763,6 +20960,89 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
*msg = new char[msg_str.size() + 1]; *msg = new char[msg_str.size() + 1];
std::strcpy(*msg, msg_str.c_str()); std::strcpy(*msg, msg_str.c_str());
} }
// Send shutdown signal through reverse pipeline as well
if (my_rank == n_world - 1) {
// Last rank initiates reverse shutdown
try {
sync_meta shutdown_meta;
shutdown_meta.chunk_done = true; // Reuse chunk_done as shutdown signal
llama_send_meta(ctx, &shutdown_meta, true); // reverse = true
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error sending reverse shutdown signal: %s", e.what());
}
} else if (my_rank > 0) {
// Intermediate ranks relay reverse shutdown signal
try {
sync_meta shutdown_meta;
// Set a short timeout for shutdown
ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, 500);
if (llama_recv_meta(ctx, &shutdown_meta, true) == 0) {
if (my_rank > 0) {
llama_send_meta(ctx, &shutdown_meta, true); // relay upstream
}
}
// Reset timeout
ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, -1);
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error handling reverse shutdown signal on rank %d: %s", my_rank, e.what());
}
}
try {
// Close signal sender (local socket created in this function)
signal_sender.close();
// Close reverse sockets first
if (ctx->reverse_send_socket) {
ctx->reverse_send_socket->close();
delete ctx->reverse_send_socket;
ctx->reverse_send_socket = nullptr;
}
if (ctx->reverse_recv_socket) {
ctx->reverse_recv_socket->close();
delete ctx->reverse_recv_socket;
ctx->reverse_recv_socket = nullptr;
}
// Close existing forward sockets
if (ctx->send_socket) {
ctx->send_socket->close();
delete ctx->send_socket;
ctx->send_socket = nullptr;
}
if (ctx->recv_socket) {
ctx->recv_socket->close();
delete ctx->recv_socket;
ctx->recv_socket = nullptr;
}
if (ctx->signal_socket) {
ctx->signal_socket->close();
delete ctx->signal_socket;
ctx->signal_socket = nullptr;
}
// Handle master_socket cleanup (be careful not to double-delete)
if (ctx->master_socket && my_rank != (n_world - 1) && ctx->master_socket != ctx->send_socket) {
ctx->master_socket->close();
delete ctx->master_socket;
ctx->master_socket = nullptr;
}
// Cleanup ZMQ context last
if (ctx->sock_context) {
delete ctx->sock_context;
ctx->sock_context = nullptr;
}
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error cleaning up sockets: %s", e.what());
}
} }
void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) { void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
@ -20789,6 +21069,7 @@ struct llama_context * llama_new_context_with_model(
ctx->cparams.rank = params.rank; ctx->cparams.rank = params.rank;
ctx->cparams.force = params.force; ctx->cparams.force = params.force;
ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world; ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;
ctx->cparams.is_perplexity_eval = params.is_perplexity_eval;
return ctx; return ctx;
} }