diff --git a/common/common.cpp b/common/common.cpp
index 4860b6b9..54fcaee9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1879,7 +1879,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         params.sparams.ignore_eos = false;
     }
 
-    if (params.warmup) {
+    if (0) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ...\n", __func__);
 
         const uint32_t my_rank = cparams.rank;
@@ -2006,7 +2006,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     }
     cparams.next_node_ip = new char[params.next_node_ip.length() + 1];
     std::strcpy(cparams.next_node_ip, params.next_node_ip.c_str());
 
-    cparams.n_ctx     = params.n_ctx;
     cparams.n_predict = params.n_predict;
     cparams.n_seq_max = params.n_parallel;
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 3c32c819..9d04cfc0 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -489,9 +489,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
     GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
 
-    // only last device store logits file
     std::ofstream logits_stream;
-    if (my_rank == n_world - 1){
+    if (my_rank == 0) {
         if (!params.logits_file.empty()) {
             logits_stream.open(params.logits_file.c_str(), std::ios::binary);
             if (!logits_stream.is_open()) {
@@ -506,9 +505,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     std::vector<llama_token> tokens;
     size_t tokens_size = 0;
+    int n_chunks = params.n_chunks;
 
-    // maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
-    if (my_rank == 0 || is_last_dev) {
+    if (my_rank == 0) {
         auto tim1 = std::chrono::high_resolution_clock::now();
         LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);
@@ -519,26 +518,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
     }
 
-    {
-        if (n_world > 1) {
-            sync_meta meta;
-            if (my_rank == 0) {
+    if (my_rank != 0) {
+        LOG_INF("perplexity: rank %d waiting for rank 0 to be ready\n", my_rank);
+    }
+
+    if (n_world > 1) {
+        sync_meta meta;
+
+        if (my_rank == 0) {
             meta.tokens_size = tokens_size;
-            llama_send_meta(ctx, &meta, false);
-        } else {
-            if (llama_recv_meta(ctx, &meta, false) == -1) {
-                LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
+            meta.n_chunks = params.n_chunks;
+
+            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
+            llama_send_meta(ctx, &meta);
+            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
+        } else {
+            LOG_INF("%s: rank %d waiting 5 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
+            std::this_thread::sleep_for(std::chrono::milliseconds(5000));
+            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+            if (llama_recv_meta(ctx, &meta) == -1) {
                 return { {}, -1.0, {}, {} };
             }
-            if (is_last_dev) {
-                GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
-            } else {
-                tokens_size = meta.tokens_size;
-                llama_send_meta(ctx, &meta, false);
-            }
+            tokens_size = meta.tokens_size;
+            n_chunks = meta.n_chunks;
+            if (!is_last_dev) {
+                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
+                llama_send_meta(ctx, &meta);
+            }
         }
     }
+ LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size); if (my_rank == 0) { @@ -553,14 +562,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector logit_history; std::vector prob_history; - if (is_last_dev) { + if (my_rank == 0) { logit_history.resize(tokens_size); prob_history.resize(tokens_size); } - const int n_chunk_max = tokens.size() / n_ctx; + const int n_chunk_max = tokens_size / n_ctx; - const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); + const int n_chunk = n_chunks < 0 ? n_chunk_max : std::min(n_chunks, n_chunk_max); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_batch = params.n_batch; @@ -578,9 +587,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector logits; - if(is_last_dev){ + if((my_rank == 0 || is_last_dev)){ if (num_batches > 1) { logits.reserve((size_t)n_ctx * n_vocab); + LOG_INF("%s: rank %d reserved logits space for %zu elements\n", + __func__, my_rank, (size_t)n_ctx * n_vocab); } } @@ -590,9 +601,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector log_probs; // the log probabilities of logits - // rank 0 and last rank store log_probs - // only rank 0 or last device stores logits/log_probs - if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) { + // only rank 0 stores logits/log_probs + if (!params.logits_file.empty() && (my_rank == 0)) { const int nv = 2*((n_vocab + 1)/2) + 4; log_probs.resize(n_ctx * nv); @@ -634,14 +644,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par clear_meta.clear_kv_cache = true; if (my_rank == 0) { - llama_send_meta(ctx, &clear_meta, false); + llama_send_meta(ctx, &clear_meta); } else { - if (llama_recv_meta(ctx, &clear_meta, false) == -1) { + if (llama_recv_meta(ctx, &clear_meta) == -1) { LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank); return {tokens, -1.0, {}, {}}; } if (!is_last_dev) { - llama_send_meta(ctx, &clear_meta, false); + llama_send_meta(ctx, &clear_meta); } } } @@ -700,12 +710,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par meta.n_outputs = n_outputs; if (n_world > 1) { - llama_send_meta(ctx, &meta, false); // reverse = false + llama_send_meta(ctx, &meta); } } else { if (n_world > 1) { // comms: other ranks receive the batch meta data - if (llama_recv_meta(ctx, &meta, false) == -1) { + if (llama_recv_meta(ctx, &meta) == -1) { LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank); return {tokens, -1.0, {}, {}}; } @@ -732,7 +742,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return {tokens, -1, logit_history, prob_history}; } - if (is_last_dev && num_batches > 1 ) { + if (my_rank == 0 && num_batches > 1 && n_outputs > 0) { const int n_outputs_synced = meta.n_outputs; if (n_outputs_synced > 0) { const auto * batch_logits = llama_get_logits(ctx); @@ -754,7 +764,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par LOG("%.2f minutes\n", total_seconds / 60.0); } - if (is_last_dev) { + if (my_rank == 0) { for (int seq = 0; seq < n_seq_batch; seq++) { const float * all_logits = num_batches > 1 ? 
@@ -785,37 +795,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             logits.clear();
         }
 
-        if (n_world > 1) {
-            sync_meta done_meta;
-            done_meta.chunk_done = true;
+        if (my_rank == 0) {
+            double current_ppl = std::exp(nll / count);
+            double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
 
-            if (is_last_dev) {
-                // Last device sends completion signal upstream (reverse direction)
-                LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
-                llama_send_meta(ctx, &done_meta, true); // reverse = true
-            } else if (my_rank == 0) {
-                // Rank 0 waits for completion signal from downstream
-                LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
-                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
-                    LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
-                    return {tokens, -1.0, {}, {}};
-                }
-                LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
-            } else {
-                // Intermediate ranks: receive from downstream, relay upstream
-                LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
-                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
-                    LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
-                    return {tokens, -1.0, {}, {}};
-                }
-                LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
-                llama_send_meta(ctx, &done_meta, true); // reverse = true
-            }
+            LOG_INF("Rank 0: Chunk %d/%d (%.1f%%) completed, current_ppl = %.4lf\n",
+                i + n_seq_batch, n_chunk, progress, current_ppl);
+        } else {
+            double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
+
+            LOG_INF("Rank %d: Chunk %d/%d (%.1f%%) completed\n",
+                my_rank, i + n_seq_batch, n_chunk, progress);
         }
     }
 
     LOG("\n");
 
-    if (is_last_dev) {
+    if (my_rank == 0) {
         nll2 /= count;
         nll /= count;
         const double ppl = exp(nll);
@@ -2221,19 +2216,22 @@ int main(int argc, char ** argv) {
     LOG("\n");
 
-    if (is_last_dev) {
-        write_logfile(ctx, params, model, results);
-    }
-
     if (my_rank == 0) {
+        write_logfile(ctx, params, model, results);
         llama_perf_context_print(ctx);
-    }
+    }
 
     if (n_world > 1) {
         LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
 
         if (my_rank == 0) {
-            llama_free_sockets(ctx, nullptr);
+            char * rank0_stop_signal = nullptr;
+            llama_free_sockets(ctx, &rank0_stop_signal);
+
+            if (rank0_stop_signal) {
+                LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, rank0_stop_signal);
+                delete[] rank0_stop_signal;
+            }
         }
 
         if (my_rank != 0 && signal_thread.joinable()) {
diff --git a/include/llama.h b/include/llama.h
index 6c906873..136c44e0 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -62,11 +62,11 @@ struct sync_meta {
     llama_seq_id ** seq_id = nullptr;
     llama_pos all_pos_0;
    llama_pos all_pos_1;
+    uint32_t n_ctx = 0;
 
     // used for perplexity evaluation
     int32_t n_outputs;
-    bool chunk_done = false; // signal that the chunk is done
 
     // signal to clear the kv cache
     bool clear_kv_cache= false;
@@ -98,8 +98,9 @@ struct sync_meta {
     llama_pos div_p1 = 0;
     int div_factor = 1;
 
-    // signal to transfer tokens_size
+    // perplexity evaluation
     size_t tokens_size = 0;
+    int n_chunks = -1;
 };
 
 #ifdef __cplusplus
@@ -507,8 +508,8 @@ extern "C" {
     LLAMA_API void llama_init_sockets    (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
     LLAMA_API void llama_free_sockets    (struct llama_context * ctx, char ** msg);
-    LLAMA_API int  llama_recv_meta       (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
-    LLAMA_API void llama_send_meta       (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
+    LLAMA_API int  llama_recv_meta       (struct llama_context * ctx, struct sync_meta * meta);
+    LLAMA_API void llama_send_meta       (struct llama_context * ctx, struct sync_meta * meta);
     LLAMA_API int  llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
     LLAMA_API int  llama_send_device_info  (struct llama_context * ctx, struct device_info * dev_info);
     LLAMA_API int  llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
diff --git a/src/llama.cpp b/src/llama.cpp
index cdeb2f8f..30645576 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3492,8 +3492,7 @@ struct llama_context {
     // sockets
     std::string master_ip    = "localhost";
     std::string next_node_ip = "localhost";
-    std::string prev_node_ip = "localhost";
-    uint32_t data_port   = 9000;
+    uint32_t data_port   = 9043;
     uint32_t signal_port = 10000;
     zmq::context_t * sock_context = nullptr;
     zmq::socket_t  * send_socket  = nullptr;
@@ -17872,27 +17871,33 @@ struct input_tensors {
     ggml_tensor * inp_pos;
 };
 
-void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
+void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
     GGML_ASSERT(ctx != nullptr);
     GGML_ASSERT(meta != nullptr);
 
-    zmq::socket_t * send_socket = reverse ? ctx->reverse_send_socket : ctx->send_socket;
+    zmq::socket_t * send_socket = ctx->send_socket;
     GGML_ASSERT(send_socket != nullptr);
 
     try {
         std::vector<zmq::message_t> send_msgs;
 
-        // Handle chunk_done signal
-        if (meta->chunk_done) {
-            send_msgs.emplace_back("chunk_done", strlen("chunk_done"));
+        if (meta->clear_kv_cache) {
+            send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
             send_msgs.emplace_back("1", 1);
             zmq::send_multipart(*send_socket, send_msgs);
             return;
         }
-
-        if (meta->clear_kv_cache) {
-            send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
-            send_msgs.emplace_back("1", 1);
+
+        if (meta->tokens_size > 0) {
+            send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
+            send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
+
+            if (meta->n_chunks >= 0) {
+                send_msgs.emplace_back("n_chunks", strlen("n_chunks"));
+                send_msgs.emplace_back(&(meta->n_chunks), sizeof(meta->n_chunks));
+            }
+
+            zmq::send_multipart(*send_socket, send_msgs);
             return;
         }
@@ -17935,11 +17940,6 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse
             return;
         }
 
-        if (meta->tokens_size > 0) {
-            send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
-            send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
-        }
-
         if (meta->n_tokens > 0) {
             send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
             send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
@@ -18001,35 +18001,30 @@
     }
 }
 
-int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
-    zmq::socket_t * recv_socket = reverse ? ctx->reverse_recv_socket : ctx->recv_socket;
+int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
+    zmq::socket_t * recv_socket = ctx->recv_socket;
     GGML_ASSERT(recv_socket != nullptr);
-    recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
 
     std::vector<zmq::message_t> recv_msgs;
-    if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
+    if (!zmq::recv_multipart(*recv_socket, std::back_inserter(recv_msgs))) {
         recv_socket->set(zmq::sockopt::rcvtimeo, -1); // Reset timeout to blocking mode before returning error
         return -1;
     }
-    ctx->recv_socket->set(zmq::sockopt::rcvtimeo, -1);
+    recv_socket->set(zmq::sockopt::rcvtimeo, -1);
 
     const std::string cmd = recv_msgs[0].to_string();
     size_t idx = 1;
 
-    // Handle chunk_done signal
-    if (cmd == "chunk_done") {
-        meta->chunk_done = true;
-        return 0;
-    }
-
-    if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
+    if (cmd == "clear_kv_cache" && recv_msgs.size() == 2) {
         meta->clear_kv_cache = true;
         return 0;
     }
+
+
     if (cmd == "kv_seq_rm" && recv_msgs.size() == 4) {
         meta->kv_seq_rm = true;
         std::memcpy(&meta->rm_seq_id, recv_msgs[idx++].data(), sizeof(meta->rm_seq_id));
@@ -18076,22 +18071,17 @@
         std::string key = recv_msgs[i].to_string();
         zmq::message_t & data_msg = recv_msgs[i + 1];
 
-        if (key == "tokens_size") {
-            GGML_ASSERT(data_msg.size() == sizeof(meta->tokens_size));
-            std::memcpy(&(meta->tokens_size), data_msg.data(), sizeof(meta->tokens_size));
-        }
-        else if (key == "n_tokens") {
+        if (key == "n_tokens") {
             GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
             std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
+        } else if (key == "n_chunks") {
+            GGML_ASSERT(data_msg.size() == sizeof(meta->n_chunks));
+            std::memcpy(&(meta->n_chunks), data_msg.data(), sizeof(meta->n_chunks));
         } else if (key == "n_outputs") {
             GGML_ASSERT(data_msg.size() == sizeof(meta->n_outputs));
             std::memcpy(&(meta->n_outputs), data_msg.data(), sizeof(meta->n_outputs));
         }
-        // else if (key == "chunk_start_pos") {
-        //     GGML_ASSERT(data_msg.size() == sizeof(meta->chunk_start_pos));
-        //     std::memcpy(&(meta->chunk_start_pos), data_msg.data(), sizeof(meta->chunk_start_pos));
-        // }
         else if (key == "n_ctx") {
             GGML_ASSERT(data_msg.size() == sizeof(meta->n_ctx));
             std::memcpy(&(meta->n_ctx), data_msg.data(), sizeof(meta->n_ctx));
         }
@@ -18416,7 +18406,7 @@ static int llama_decode_internal(
     bool is_last_dev = (my_rank == n_world - 1);
 
     if (my_rank != 0) {
-        if (llama_recv_meta(&lctx, &meta, false) == -1) {
+        if (llama_recv_meta(&lctx, &meta) == -1) {
             return -1;
         }
@@ -18477,7 +18467,7 @@ static int llama_decode_internal(
             meta.pos       = batch_all.pos;
             meta.all_pos_0 = batch_all.all_pos_0;
             meta.all_pos_1 = batch_all.all_pos_1;
-            llama_send_meta(&lctx, &meta, false);
+            llama_send_meta(&lctx, &meta);
         }
     }
@@ -20615,15 +20605,10 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
     }
 
     ctx->sock_context = new zmq::context_t(2);
-    ctx->send_socket   = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
     ctx->recv_socket   = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
     ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
 
-    // Reverse pipeline sockets (new - for barriers)
-    ctx->reverse_send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
-    ctx->reverse_recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
-
     if (my_rank != 0 && my_rank != (n_world - 1)) {
         ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
     } else if (my_rank == (n_world - 1)) {
@@ -20631,38 +20616,18 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
     }
 
     const uint32_t next_rank = (my_rank + 1) % n_world;
-    const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
-
     std::string recv_endp   = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port));
     std::string send_endp   = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
     std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port));
     std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port));
 
-    // Reverse pipeline endpoints (new)
-    // Use a different port offset for reverse communication to avoid conflicts
-    const uint32_t reverse_port_offset = 1000;
-    std::string reverse_recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port + reverse_port_offset));
-    std::string reverse_send_endp = "tcp://" + ctx->prev_node_ip + ":" + std::to_string(map_rank_to_port(prev_rank, ctx->data_port + reverse_port_offset));
-
     try {
         ctx->recv_socket->bind(recv_endp);
         ctx->signal_socket->bind(signal_endp);
-        ctx->send_socket->connect(send_endp);
         if (ctx->master_socket && my_rank != (n_world - 1)) {
             ctx->master_socket->connect(master_endp);
         }
-
-        // Setup reverse pipeline sockets
-        if (my_rank > 0) {
-            // All ranks except rank 0 can send to previous rank
-            ctx->reverse_send_socket->connect(reverse_send_endp);
-        }
-
-        if (my_rank < n_world - 1) {
-            // All ranks except last rank can receive from next rank
-            ctx->reverse_recv_socket->bind(reverse_recv_endp);
-        }
     } catch (const zmq::error_t &e) {
         LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
         exit(1);
     }
@@ -20936,7 +20901,6 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
     const uint32_t my_rank = ctx->cparams.rank;
     // to adapt to the new topology, use old next_rank
     const uint32_t next_rank = ctx->cparams.original_next_rank;
-    const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
 
     if (n_world == 1) {
         return;
     }
@@ -20960,89 +20924,6 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
         *msg = new char[msg_str.size() + 1];
         std::strcpy(*msg, msg_str.c_str());
     }
-
-    // Send shutdown signal through reverse pipeline as well
-    if (my_rank == n_world - 1) {
-        // Last rank initiates reverse shutdown
-        try {
-            sync_meta shutdown_meta;
-            shutdown_meta.chunk_done = true; // Reuse chunk_done as shutdown signal
-            llama_send_meta(ctx, &shutdown_meta, true); // reverse = true
-        } catch (const zmq::error_t &e) {
-            LLAMA_LOG_INFO("Error sending reverse shutdown signal: %s", e.what());
-        }
-    } else if (my_rank > 0) {
-        // Intermediate ranks relay reverse shutdown signal
-        try {
-            sync_meta shutdown_meta;
-            // Set a short timeout for shutdown
-            ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, 500);
-
-            if (llama_recv_meta(ctx, &shutdown_meta, true) == 0) {
-                if (my_rank > 0) {
-                    llama_send_meta(ctx, &shutdown_meta, true); // relay upstream
-                }
-            }
-
-            // Reset timeout
-            ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, -1);
-        } catch (const zmq::error_t &e) {
-            LLAMA_LOG_INFO("Error handling reverse shutdown signal on rank %d: %s", my_rank, e.what());
-        }
-    }
-
-    try {
-        // Close signal sender (local socket created in this function)
-        signal_sender.close();
-
-        // Close reverse sockets first
-        if (ctx->reverse_send_socket) {
-            ctx->reverse_send_socket->close();
-            delete ctx->reverse_send_socket;
-            ctx->reverse_send_socket = nullptr;
-        }
-
-        if (ctx->reverse_recv_socket) {
-            ctx->reverse_recv_socket->close();
-            delete ctx->reverse_recv_socket;
-            ctx->reverse_recv_socket = nullptr;
-        }
-
-        // Close existing forward sockets
-        if (ctx->send_socket) {
-            ctx->send_socket->close();
-            delete ctx->send_socket;
-            ctx->send_socket = nullptr;
-        }
-
-        if (ctx->recv_socket) {
-            ctx->recv_socket->close();
-            delete ctx->recv_socket;
-            ctx->recv_socket = nullptr;
-        }
-
-        if (ctx->signal_socket) {
-            ctx->signal_socket->close();
-            delete ctx->signal_socket;
-            ctx->signal_socket = nullptr;
-        }
-
-        // Handle master_socket cleanup (be careful not to double-delete)
-        if (ctx->master_socket && my_rank != (n_world - 1) && ctx->master_socket != ctx->send_socket) {
-            ctx->master_socket->close();
-            delete ctx->master_socket;
-            ctx->master_socket = nullptr;
-        }
-
-        // Cleanup ZMQ context last
-        if (ctx->sock_context) {
-            delete ctx->sock_context;
-            ctx->sock_context = nullptr;
-        }
-
-    } catch (const zmq::error_t &e) {
-        LLAMA_LOG_INFO("Error cleaning up sockets: %s", e.what());
-    }
 }
 
 void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
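
Note on the synchronization pattern this patch converges on (editor sketch, not part of the patch): with the reverse pipeline and the chunk_done barrier removed, every piece of sync_meta now travels strictly forward through the ring. Rank 0 originates it, each intermediate rank receives it and re-sends it to the next rank, and the last device only receives. The sketch below shows that relay in isolation, using only the two-argument llama_send_meta/llama_recv_meta declared in include/llama.h above; the helper name relay_meta_forward is made up for illustration and does not exist in the patch.

// Sketch only: forward-only relay of sync_meta through the pipeline,
// mirroring the tokens_size/n_chunks, clear_kv_cache and per-batch meta
// handling in examples/perplexity/perplexity.cpp after this patch.
#include <cstdint>
#include "llama.h"

// Hypothetical helper: returns 0 on success, -1 if the receive fails.
static int relay_meta_forward(llama_context * ctx, sync_meta & meta,
                              uint32_t my_rank, uint32_t n_world) {
    const bool is_last_dev = (my_rank == n_world - 1);

    if (n_world == 1) {
        return 0; // single device, nothing to synchronize
    }

    if (my_rank == 0) {
        // head of the pipeline: originate the metadata
        llama_send_meta(ctx, &meta);
    } else {
        // everyone else: block until the previous rank delivers it
        if (llama_recv_meta(ctx, &meta) == -1) {
            return -1;
        }
        if (!is_last_dev) {
            // intermediate ranks forward the same message downstream
            llama_send_meta(ctx, &meta);
        }
    }
    return 0;
}

This is the same shape used three times in perplexity(): once for tokens_size/n_chunks before evaluation, once per chunk for the clear_kv_cache signal, and once per batch for the decode metadata.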