diff --git a/common/common.cpp b/common/common.cpp index dff98506..4860b6b9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1992,6 +1992,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.keep_out_in_metal = params.keep_out_in_metal; cparams.n_gpu_layers = params.n_gpu_layers; cparams.n_cycles = params.n_cycles; + cparams.is_perplexity_eval= params.is_perplexity_eval; std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), cparams.n_layer_window); if (cparams.master_ip != nullptr) { diff --git a/common/common.h b/common/common.h index 71c6e435..e1cfec21 100644 --- a/common/common.h +++ b/common/common.h @@ -178,6 +178,8 @@ struct gpt_params { int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold + bool is_perplexity_eval; + struct cpu_params cpuparams; struct cpu_params cpuparams_batch; struct cpu_params draft_cpuparams; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 7d188b56..3c32c819 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -524,9 +524,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par sync_meta meta; if (my_rank == 0) { meta.tokens_size = tokens_size; - llama_send_meta(ctx, &meta); + llama_send_meta(ctx, &meta, false); } else { - if (llama_recv_meta(ctx, &meta) == -1) { + if (llama_recv_meta(ctx, &meta, false) == -1) { LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank); return { {}, -1.0, {}, {} }; } @@ -534,7 +534,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!"); } else { tokens_size = meta.tokens_size; - llama_send_meta(ctx, &meta); + llama_send_meta(ctx, &meta, false); } } } @@ -628,19 +628,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); { + // synvhronize the KV cache clear signal across all ranks if (n_world > 1) { sync_meta clear_meta; clear_meta.clear_kv_cache = true; if (my_rank == 0) { - llama_send_meta(ctx, &clear_meta); + llama_send_meta(ctx, &clear_meta, false); } else { - if (llama_recv_meta(ctx, &clear_meta) == -1) { + if (llama_recv_meta(ctx, &clear_meta, false) == -1) { LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank); return {tokens, -1.0, {}, {}}; } if (!is_last_dev) { - llama_send_meta(ctx, &clear_meta); + llama_send_meta(ctx, &clear_meta, false); } } } @@ -648,11 +649,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // clear the KV cache llama_kv_cache_clear(ctx); - sync_meta meta; - for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); + // used for communication of the batch meta data + sync_meta meta; int n_outputs = 0; @@ -689,38 +690,42 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // restore the original token in case it was set to BOS tokens[seq_start] = token_org; } - } - - if (my_rank == 0) { - // Required batch info: Operation scale, KV cache location, Logits calculation location - meta.n_ctx = n_ctx; + + // comms: now rank 0 need to send the batch to other ranks meta.n_tokens = batch.n_tokens; meta.pos = batch.pos; + meta.n_seq_id = batch.n_seq_id; + meta.seq_id = batch.seq_id; meta.logits = batch.logits; - meta.all_pos_0 = batch.all_pos_0; - meta.all_pos_1 = batch.all_pos_1; meta.n_outputs = n_outputs; - meta.chunk_start_pos = start; - } - // other ranks need to know batch info - { if (n_world > 1) { - meta.n_ctx = n_ctx; - - if (my_rank == 0) { - llama_send_meta(ctx, &meta); - } else { - if (llama_recv_meta(ctx, &meta) == -1) { - LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank); - return {tokens, -1.0, {}, {}}; - } - if (!is_last_dev) { - llama_send_meta(ctx, &meta); - } - } + llama_send_meta(ctx, &meta, false); // reverse = false } - } + } else { + if (n_world > 1) { + // comms: other ranks receive the batch meta data + if (llama_recv_meta(ctx, &meta, false) == -1) { + LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank); + return {tokens, -1.0, {}, {}}; + } + + // copy the batch meta data to the llama_batch + if (meta.n_tokens > 0) { + batch.n_tokens = meta.n_tokens; + if (meta.pos) { std::memcpy(batch.pos, meta.pos, meta.n_tokens * sizeof(llama_pos)); } // use n_tokens instead of n_batch, n_tokens is the actual number of tokens in the batch + if (meta.n_seq_id) { std::memcpy(batch.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t)); } + if (meta.seq_id) { + const int32_t n_seq_max = 1; + for (int32_t i = 0; i < meta.n_tokens; ++i) { + std::memcpy(batch.seq_id[i], meta.seq_id[i], n_seq_max * sizeof(llama_seq_id)); + } + } + if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); } + } + + } + } if (llama_decode(ctx, batch)) { LOG_INF("%s : failed to eval\n", __func__); @@ -753,8 +758,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par for (int seq = 0; seq < n_seq_batch; seq++) { const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); - int chunk_start_pos = meta.chunk_start_pos; - llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first; + llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; if (!params.logits_file.empty()) { process_logits(logits_stream, n_vocab, all_logits, tokens_data, n_ctx - 1 - first, @@ -763,8 +767,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par process_logits(n_vocab, all_logits, tokens_data, n_ctx - 1 - first, workers, nll, nll2, - logit_history.data() + chunk_start_pos + seq*n_ctx + first, - prob_history.data() + chunk_start_pos + seq*n_ctx + first); + logit_history.data() + start + seq*n_ctx + first, + prob_history.data() + start + seq*n_ctx + first); } count += n_ctx - first - 1; @@ -778,8 +782,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } } + logits.clear(); + } + + if (n_world > 1) { + sync_meta done_meta; + done_meta.chunk_done = true; + + if (is_last_dev) { + // Last device sends completion signal upstream (reverse direction) + LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i); + llama_send_meta(ctx, &done_meta, true); // reverse = true + } else if (my_rank == 0) { + // Rank 0 waits for completion signal from downstream + LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i); + if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true + LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i); + return {tokens, -1.0, {}, {}}; + } + LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i); + } else { + // Intermediate ranks: receive from downstream, relay upstream + LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i); + if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true + LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i); + return {tokens, -1.0, {}, {}}; + } + LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i); + llama_send_meta(ctx, &done_meta, true); // reverse = true + } } - logits.clear(); } LOG("\n"); @@ -795,12 +827,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } else { LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); } - + llama_batch_free(batch); return {tokens, ppl, logit_history, prob_history}; } llama_batch_free(batch); - return {}; } @@ -2078,10 +2109,11 @@ int main(int argc, char ** argv) { params.n_ctx = 512; params.logits_all = true; params.escape = false; + params.is_perplexity_eval = true; if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { return 1; - } + } uint32_t n_world = params.n_world; uint32_t my_rank = params.rank; @@ -2141,7 +2173,6 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any llama_init_result llama_init = llama_init_from_gpt_params(params); - // update rank and world size if any devices removed my_rank = params.rank; n_world = params.n_world; @@ -2189,14 +2220,34 @@ int main(int argc, char ** argv) { LOG("\n"); + if (is_last_dev) { - llama_perf_context_print(ctx); write_logfile(ctx, params, model, results); + } + + if (my_rank == 0) { + llama_perf_context_print(ctx); + } + + if (n_world > 1) { + LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank); + + if (my_rank == 0) { + llama_free_sockets(ctx, nullptr); + } + + if (my_rank != 0 && signal_thread.joinable()) { + signal_thread.join(); + } + + if (stop_signal) { + LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, stop_signal); + delete[] stop_signal; + } } llama_free(ctx); llama_free_model(model); - llama_backend_free(); return 0; diff --git a/include/llama.h b/include/llama.h index c8706488..6c906873 100644 --- a/include/llama.h +++ b/include/llama.h @@ -58,12 +58,15 @@ struct sync_meta { int8_t * logits = nullptr; llama_pos * pos = nullptr; + int32_t * n_seq_id = nullptr; + llama_seq_id ** seq_id = nullptr; llama_pos all_pos_0; llama_pos all_pos_1; uint32_t n_ctx = 0; - int chunk_start_pos; - int32_t n_outputs; // Used to pass the number of logits to be outputted + // used for perplexity evaluation + int32_t n_outputs; + bool chunk_done = false; // signal that the chunk is done // signal to clear the kv cache bool clear_kv_cache= false; @@ -389,6 +392,7 @@ extern "C" { int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -422,6 +426,7 @@ extern "C" { // currently works only with CPU execution ggml_abort_callback abort_callback; void * abort_callback_data; + bool is_perplexity_eval; // whether to run in perplexity evaluation mode }; // model quantization parameters @@ -502,8 +507,8 @@ extern "C" { LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank); LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg); - LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta); - LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta); + LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse); + LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse); LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set); LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info); LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args); diff --git a/src/llama.cpp b/src/llama.cpp index cc4e8ab2..cdeb2f8f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2607,6 +2607,8 @@ struct llama_cparams { int n_threads; // number of threads to use for generation int n_threads_batch; // number of threads to use for batch processing + bool is_perplexity_eval; + float rope_freq_base; float rope_freq_scale; @@ -3490,6 +3492,7 @@ struct llama_context { // sockets std::string master_ip = "localhost"; std::string next_node_ip = "localhost"; + std::string prev_node_ip = "localhost"; uint32_t data_port = 9000; uint32_t signal_port = 10000; zmq::context_t * sock_context = nullptr; @@ -3497,6 +3500,9 @@ struct llama_context { zmq::socket_t * recv_socket = nullptr; zmq::socket_t * master_socket = nullptr; zmq::socket_t * signal_socket = nullptr; + // Add these for reverse communication + zmq::socket_t * reverse_send_socket = nullptr; // Reverse: Rank i -> Rank i-1 + zmq::socket_t * reverse_recv_socket = nullptr; // Reverse: Rank i <- Rank i+1 }; struct llama_lora_weight { @@ -17866,30 +17872,126 @@ struct input_tensors { ggml_tensor * inp_pos; }; -void llama_send_meta(llama_context * ctx, struct sync_meta * meta) { +void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) { GGML_ASSERT(ctx != nullptr); GGML_ASSERT(meta != nullptr); - zmq::socket_t * send_socket = ctx->send_socket; + zmq::socket_t * send_socket = reverse ? ctx->reverse_send_socket : ctx->send_socket; GGML_ASSERT(send_socket != nullptr); try { std::vector send_msgs; - GGML_ASSERT(meta->n_tokens != 0); - send_msgs.emplace_back("n_tokens", strlen("n_tokens")); - send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens)); - - if (meta->pos != nullptr) { - send_msgs.emplace_back("pos", strlen("pos")); - send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos)); + // Handle chunk_done signal + if (meta->chunk_done) { + send_msgs.emplace_back("chunk_done", strlen("chunk_done")); + send_msgs.emplace_back("1", 1); + zmq::send_multipart(*send_socket, send_msgs); + return; } - send_msgs.emplace_back("all_pos_0", strlen("all_pos_0")); - send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0)); + if (meta->clear_kv_cache) { + send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache")); + send_msgs.emplace_back("1", 1); + return; + } - send_msgs.emplace_back("all_pos_1", strlen("all_pos_1")); - send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1)); + if (meta->kv_seq_rm) { + send_msgs.emplace_back("kv_seq_rm", strlen("kv_seq_rm")); + send_msgs.emplace_back(&(meta->rm_seq_id), sizeof(meta->rm_seq_id)); + send_msgs.emplace_back(&(meta->rm_p0), sizeof(meta->rm_p0)); + send_msgs.emplace_back(&(meta->rm_p1), sizeof(meta->rm_p1)); + zmq::send_multipart(*send_socket, send_msgs); + return; + } + + if (meta->kv_seq_add) { + send_msgs.emplace_back("kv_seq_add", strlen("kv_seq_add")); + send_msgs.emplace_back(&(meta->add_seq_id), sizeof(meta->add_seq_id)); + send_msgs.emplace_back(&(meta->add_p0), sizeof(meta->add_p0)); + send_msgs.emplace_back(&(meta->add_p1), sizeof(meta->add_p1)); + send_msgs.emplace_back(&(meta->add_delta), sizeof(meta->add_delta)); + zmq::send_multipart(*send_socket, send_msgs); + return; + } + + if (meta->kv_seq_cp) { + send_msgs.emplace_back("kv_seq_cp", strlen("kv_seq_cp")); + send_msgs.emplace_back(&(meta->cp_src_seq_id), sizeof(meta->cp_src_seq_id)); + send_msgs.emplace_back(&(meta->cp_dst_seq_id), sizeof(meta->cp_dst_seq_id)); + send_msgs.emplace_back(&(meta->cp_p0), sizeof(meta->cp_p0)); + send_msgs.emplace_back(&(meta->cp_p1), sizeof(meta->cp_p1)); + zmq::send_multipart(*send_socket, send_msgs); + return; + } + + if (meta->kv_seq_div) { + send_msgs.emplace_back("kv_seq_div", strlen("kv_seq_div")); + send_msgs.emplace_back(&(meta->div_seq_id), sizeof(meta->div_seq_id)); + send_msgs.emplace_back(&(meta->div_p0), sizeof(meta->div_p0)); + send_msgs.emplace_back(&(meta->div_p1), sizeof(meta->div_p1)); + send_msgs.emplace_back(&(meta->div_factor), sizeof(meta->div_factor)); + zmq::send_multipart(*send_socket, send_msgs); + return; + } + + if (meta->tokens_size > 0) { + send_msgs.emplace_back("tokens_size", strlen("tokens_size")); + send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size)); + } + + if (meta->n_tokens > 0) { + send_msgs.emplace_back("n_tokens", strlen("n_tokens")); + send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens)); + + send_msgs.emplace_back("n_outputs", strlen("n_outputs")); + send_msgs.emplace_back(&(meta->n_outputs), sizeof(meta->n_outputs)); + + // send_msgs.emplace_back("chunk_start_pos", strlen("chunk_start_pos")); + // send_msgs.emplace_back(&(meta->chunk_start_pos), sizeof(meta->chunk_start_pos)); + + send_msgs.emplace_back("n_ctx", strlen("n_ctx")); + send_msgs.emplace_back(&(meta->n_ctx), sizeof(meta->n_ctx)); + + send_msgs.emplace_back("all_pos_0", strlen("all_pos_0")); + send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0)); + + send_msgs.emplace_back("all_pos_1", strlen("all_pos_1")); + send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1)); + + // batch.pos + if (meta->pos != nullptr) { + send_msgs.emplace_back("pos", strlen("pos")); + send_msgs.emplace_back(meta->pos, meta->n_tokens * sizeof(llama_pos)); + } + // batch.n_seq_id + if (meta->n_seq_id != nullptr) { + send_msgs.emplace_back("n_seq_id", strlen("n_seq_id")); + send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t)); + } + // batch.seq_id + if (meta->seq_id != nullptr) { + const int32_t n_tokens = meta->n_tokens; + const int32_t n_seq_max = 1; + + std::vector flat_seq_ids; + flat_seq_ids.reserve(n_tokens * n_seq_max); + + for (int32_t i = 0; i < n_tokens; ++i) { + for (int32_t j = 0; j < n_seq_max; ++j) { + flat_seq_ids.push_back(meta->seq_id[i][j]); + } + } + + send_msgs.emplace_back("seq_id", strlen("seq_id")); + send_msgs.emplace_back(flat_seq_ids.data(), flat_seq_ids.size() * sizeof(llama_seq_id)); + } + // batch.logits + if (meta->logits != nullptr) { + send_msgs.emplace_back("logits", strlen("logits")); + send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t)); + } + } if (!send_msgs.empty()) { zmq::send_multipart(*send_socket, send_msgs); @@ -17899,12 +18001,16 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) { } } -int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) { - ctx->recv_socket->set(zmq::sockopt::rcvtimeo, 1000); +int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) { + zmq::socket_t * recv_socket = reverse ? ctx->reverse_recv_socket : ctx->recv_socket; + GGML_ASSERT(recv_socket != nullptr); + + recv_socket->set(zmq::sockopt::rcvtimeo, 1000); std::vector recv_msgs; if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) { + recv_socket->set(zmq::sockopt::rcvtimeo, -1); // Reset timeout to blocking mode before returning error return -1; } @@ -17913,6 +18019,12 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) { const std::string cmd = recv_msgs[0].to_string(); size_t idx = 1; + // Handle chunk_done signal + if (cmd == "chunk_done") { + meta->chunk_done = true; + return 0; + } + if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) { meta->clear_kv_cache = true; return 0; @@ -17953,29 +18065,82 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) { return 0; } + if (recv_msgs.size() % 2 != 0) { + LLAMA_LOG_ERROR("Invalid message format: odd number of messages\n"); + return -1; + } + for (size_t i = 0; i < recv_msgs.size(); i += 2) { - std::string key = recv_msgs[i].to_string(); + if (i + 1 >= recv_msgs.size()) break; + + std::string key = recv_msgs[i].to_string(); zmq::message_t & data_msg = recv_msgs[i + 1]; - if (key == "n_tokens") { + if (key == "tokens_size") { + GGML_ASSERT(data_msg.size() == sizeof(meta->tokens_size)); + std::memcpy(&(meta->tokens_size), data_msg.data(), sizeof(meta->tokens_size)); + } + else if (key == "n_tokens") { GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens)); std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens)); } - - if (key == "pos") { - meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos)); - std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos)); + else if (key == "n_outputs") { + GGML_ASSERT(data_msg.size() == sizeof(meta->n_outputs)); + std::memcpy(&(meta->n_outputs), data_msg.data(), sizeof(meta->n_outputs)); } - - if (key == "all_pos_0") { + // else if (key == "chunk_start_pos") { + // GGML_ASSERT(data_msg.size() == sizeof(meta->chunk_start_pos)); + // std::memcpy(&(meta->chunk_start_pos), data_msg.data(), sizeof(meta->chunk_start_pos)); + // } + else if (key == "n_ctx") { + GGML_ASSERT(data_msg.size() == sizeof(meta->n_ctx)); + std::memcpy(&(meta->n_ctx), data_msg.data(), sizeof(meta->n_ctx)); + } + else if (key == "all_pos_0") { GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0)); std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0)); } - - if (key == "all_pos_1") { + else if (key == "all_pos_1") { GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1)); std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1)); } + else if (key == "pos") { + if (meta->n_tokens > 0) { + meta->pos = (llama_pos *) malloc(meta->n_tokens * sizeof(llama_pos)); + GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_pos)); + std::memcpy(meta->pos, data_msg.data(), meta->n_tokens * sizeof(llama_pos)); + } + } + else if (key == "n_seq_id") { + if (meta->n_tokens > 0) { + meta->n_seq_id = (int32_t *) malloc(data_msg.size()); + std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t)); + } + } + // batch.logits + else if (key == "seq_id") { + if (meta->n_tokens > 0) { + const int32_t n_tokens = meta->n_tokens; + const int32_t n_seq_max = 1; + + GGML_ASSERT(data_msg.size() == (size_t)n_tokens * n_seq_max * sizeof(llama_seq_id)); + + meta->seq_id = (llama_seq_id **) malloc(n_tokens * sizeof(llama_seq_id *)); + + const llama_seq_id * flat_data = (const llama_seq_id *)data_msg.data(); + for (int32_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + meta->seq_id[token_idx] = (llama_seq_id *) malloc(n_seq_max * sizeof(llama_seq_id)); + std::memcpy(meta->seq_id[token_idx], flat_data + token_idx * n_seq_max, n_seq_max * sizeof(llama_seq_id)); + } + } + } + else if (key == "logits") { + if (meta->n_tokens > 0) { + GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t)); + meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t)); + std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t)); + } + } } return 0; } @@ -18190,6 +18355,8 @@ static int llama_decode_internal( const uint32_t n_world = cparams.n_world; const uint32_t my_rank = cparams.rank; + + const bool is_perplexity_mode = cparams.is_perplexity_eval; const uint32_t n_tokens_all = batch_all.n_tokens; const int64_t n_embd = hparams.n_embd; // used for reserving embeddings space size @@ -18243,73 +18410,76 @@ static int llama_decode_internal( } // prepare for send and receive of metadata - sync_meta meta; - meta.n_ctx = cparams.n_ctx; - bool is_last_dev = (my_rank == n_world - 1); + if (!is_perplexity_mode) { + sync_meta meta; + meta.n_ctx = cparams.n_ctx; + bool is_last_dev = (my_rank == n_world - 1); - if (my_rank != 0) { - if (llama_recv_meta(&lctx, &meta) == -1) { - return -1; - } - - if (meta.n_tokens > 0) { - batch_all.n_tokens = meta.n_tokens; - if (meta.pos != nullptr) { - batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos)); - std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos)); + if (my_rank != 0) { + if (llama_recv_meta(&lctx, &meta, false) == -1) { + return -1; + } + + if (meta.n_tokens > 0) { + batch_all.n_tokens = meta.n_tokens; + if (meta.pos != nullptr) { + batch_all.pos = (llama_pos *) malloc(cparams.n_ctx * sizeof(llama_pos)); + std::memcpy(batch_all.pos, meta.pos, cparams.n_ctx * sizeof(llama_pos)); + } + batch_all.all_pos_0 = meta.all_pos_0; + batch_all.all_pos_1 = meta.all_pos_1; + } + + if (kv_cache_op(meta.clear_kv_cache, + [&]{ llama_kv_cache_clear (&lctx); }, + [&]{ llama_send_kv_cache_clear (&lctx); }, + is_last_dev)) { + LLAMA_LOG_DEBUG("%s: received signal kv_cache_clear\n", __func__); + return -1; + } + + if (kv_cache_op(meta.kv_seq_rm, + [&]{ llama_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); }, + [&]{ llama_send_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); }, + is_last_dev)) { + LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_rm\n", __func__); + return -1; + } + + if (kv_cache_op(meta.kv_seq_add, + [&]{ llama_kv_cache_seq_add (&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); }, + [&]{ llama_send_kv_cache_seq_add(&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); }, + is_last_dev)) { + LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_add\n", __func__); + return -1; + } + + if (kv_cache_op(meta.kv_seq_cp, + [&]{ llama_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); }, + [&]{ llama_send_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); }, + is_last_dev)) { + LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_cp\n", __func__); + return -1; + } + + if (kv_cache_op(meta.kv_seq_div, + [&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); }, + [&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); }, + is_last_dev)) { + LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_div\n", __func__); + return -1; } - batch_all.all_pos_0 = meta.all_pos_0; - batch_all.all_pos_1 = meta.all_pos_1; } - if (kv_cache_op(meta.clear_kv_cache, - [&]{ llama_kv_cache_clear (&lctx); }, - [&]{ llama_send_kv_cache_clear (&lctx); }, - is_last_dev)) { - LLAMA_LOG_DEBUG("%s: received signal kv_cache_clear\n", __func__); - return -1; - } - if (kv_cache_op(meta.kv_seq_rm, - [&]{ llama_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); }, - [&]{ llama_send_kv_cache_seq_rm (&lctx, meta.rm_seq_id, meta.rm_p0, meta.rm_p1); }, - is_last_dev)) { - LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_rm\n", __func__); - return -1; - } - - if (kv_cache_op(meta.kv_seq_add, - [&]{ llama_kv_cache_seq_add (&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); }, - [&]{ llama_send_kv_cache_seq_add(&lctx, meta.add_seq_id, meta.add_p0, meta.add_p1, meta.add_delta); }, - is_last_dev)) { - LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_add\n", __func__); - return -1; - } - - if (kv_cache_op(meta.kv_seq_cp, - [&]{ llama_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); }, - [&]{ llama_send_kv_cache_seq_cp (&lctx, meta.cp_src_seq_id, meta.cp_dst_seq_id, meta.cp_p0, meta.cp_p1); }, - is_last_dev)) { - LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_cp\n", __func__); - return -1; - } - - if (kv_cache_op(meta.kv_seq_div, - [&]{ llama_kv_cache_seq_div (&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); }, - [&]{ llama_send_kv_cache_seq_div(&lctx, meta.div_seq_id, meta.div_p0, meta.div_p1, meta.div_factor); }, - is_last_dev)) { - LLAMA_LOG_DEBUG("%s: received signal kv_cache_seq_div\n", __func__); - return -1; - } + if (!is_last_dev) { + meta.n_tokens = batch_all.n_tokens; + meta.pos = batch_all.pos; + meta.all_pos_0 = batch_all.all_pos_0; + meta.all_pos_1 = batch_all.all_pos_1; + llama_send_meta(&lctx, &meta, false); + } } - - if (!is_last_dev) { - meta.n_tokens = batch_all.n_tokens; - meta.pos = batch_all.pos; - meta.all_pos_0 = batch_all.all_pos_0; - meta.all_pos_1 = batch_all.all_pos_1; - llama_send_meta(&lctx, &meta); - } lctx.sbatch.from_batch(batch_all, n_embd, /* simple_split */ !kv_self.recurrent, @@ -20281,6 +20451,7 @@ struct llama_context_params llama_context_default_params() { /*.no_perf =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, + /*.is_perplexity_mode =*/ false }; return result; @@ -20444,10 +20615,15 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m } ctx->sock_context = new zmq::context_t(2); + ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull); ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull); + // Reverse pipeline sockets (new - for barriers) + ctx->reverse_send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); + ctx->reverse_recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull); + if (my_rank != 0 && my_rank != (n_world - 1)) { ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); } else if (my_rank == (n_world - 1)) { @@ -20455,18 +20631,38 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m } const uint32_t next_rank = (my_rank + 1) % n_world; + const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world; + std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port)); std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port)); std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port)); std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port)); + // Reverse pipeline endpoints (new) + // Use a different port offset for reverse communication to avoid conflicts + const uint32_t reverse_port_offset = 1000; + std::string reverse_recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port + reverse_port_offset)); + std::string reverse_send_endp = "tcp://" + ctx->prev_node_ip + ":" + std::to_string(map_rank_to_port(prev_rank, ctx->data_port + reverse_port_offset)); + try { ctx->recv_socket->bind(recv_endp); ctx->signal_socket->bind(signal_endp); + ctx->send_socket->connect(send_endp); if (ctx->master_socket && my_rank != (n_world - 1)) { ctx->master_socket->connect(master_endp); } + + // Setup reverse pipeline sockets + if (my_rank > 0) { + // All ranks except rank 0 can send to previous rank + ctx->reverse_send_socket->connect(reverse_send_endp); + } + + if (my_rank < n_world - 1) { + // All ranks except last rank can receive from next rank + ctx->reverse_recv_socket->bind(reverse_recv_endp); + } } catch (const zmq::error_t &e) { LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what()); exit(1); @@ -20740,6 +20936,7 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) { const uint32_t my_rank = ctx->cparams.rank; // to adapt to the new topology, use old next_rank const uint32_t next_rank = ctx->cparams.original_next_rank; + const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world; if (n_world == 1) { return; @@ -20763,6 +20960,89 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) { *msg = new char[msg_str.size() + 1]; std::strcpy(*msg, msg_str.c_str()); } + + // Send shutdown signal through reverse pipeline as well + if (my_rank == n_world - 1) { + // Last rank initiates reverse shutdown + try { + sync_meta shutdown_meta; + shutdown_meta.chunk_done = true; // Reuse chunk_done as shutdown signal + llama_send_meta(ctx, &shutdown_meta, true); // reverse = true + } catch (const zmq::error_t &e) { + LLAMA_LOG_INFO("Error sending reverse shutdown signal: %s", e.what()); + } + } else if (my_rank > 0) { + // Intermediate ranks relay reverse shutdown signal + try { + sync_meta shutdown_meta; + // Set a short timeout for shutdown + ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, 500); + + if (llama_recv_meta(ctx, &shutdown_meta, true) == 0) { + if (my_rank > 0) { + llama_send_meta(ctx, &shutdown_meta, true); // relay upstream + } + } + + // Reset timeout + ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, -1); + } catch (const zmq::error_t &e) { + LLAMA_LOG_INFO("Error handling reverse shutdown signal on rank %d: %s", my_rank, e.what()); + } + } + + try { + // Close signal sender (local socket created in this function) + signal_sender.close(); + + // Close reverse sockets first + if (ctx->reverse_send_socket) { + ctx->reverse_send_socket->close(); + delete ctx->reverse_send_socket; + ctx->reverse_send_socket = nullptr; + } + + if (ctx->reverse_recv_socket) { + ctx->reverse_recv_socket->close(); + delete ctx->reverse_recv_socket; + ctx->reverse_recv_socket = nullptr; + } + + // Close existing forward sockets + if (ctx->send_socket) { + ctx->send_socket->close(); + delete ctx->send_socket; + ctx->send_socket = nullptr; + } + + if (ctx->recv_socket) { + ctx->recv_socket->close(); + delete ctx->recv_socket; + ctx->recv_socket = nullptr; + } + + if (ctx->signal_socket) { + ctx->signal_socket->close(); + delete ctx->signal_socket; + ctx->signal_socket = nullptr; + } + + // Handle master_socket cleanup (be careful not to double-delete) + if (ctx->master_socket && my_rank != (n_world - 1) && ctx->master_socket != ctx->send_socket) { + ctx->master_socket->close(); + delete ctx->master_socket; + ctx->master_socket = nullptr; + } + + // Cleanup ZMQ context last + if (ctx->sock_context) { + delete ctx->sock_context; + ctx->sock_context = nullptr; + } + + } catch (const zmq::error_t &e) { + LLAMA_LOG_INFO("Error cleaning up sockets: %s", e.what()); + } } void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) { @@ -20789,6 +21069,7 @@ struct llama_context * llama_new_context_with_model( ctx->cparams.rank = params.rank; ctx->cparams.force = params.force; ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world; + ctx->cparams.is_perplexity_eval = params.is_perplexity_eval; return ctx; }