Removed unnecessary synchronization logic and added n_chunks to the meta communication

leeetao 2025-06-27 07:04:10 +00:00
parent a3becb586a
commit 48b7f53abb
4 changed files with 97 additions and 218 deletions
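
For reference, a minimal sketch of the forward-chain broadcast this commit settles on for the perplexity run: rank 0 sends tokens_size and n_chunks once via llama_send_meta, every other rank receives them with llama_recv_meta and, unless it is the last device, relays the same meta to the next rank. The reverse chunk_done barrier and its dedicated sockets are dropped; only the forward chain remains. The helper name broadcast_perplexity_meta below is hypothetical and only illustrates the pattern; it is not part of this commit.

static bool broadcast_perplexity_meta(llama_context * ctx, uint32_t my_rank, uint32_t n_world,
                                      bool is_last_dev, size_t & tokens_size, int & n_chunks) {
    if (n_world <= 1) {
        return true; // single device: nothing to synchronize
    }
    sync_meta meta;
    if (my_rank == 0) {
        // the head node owns the tokenized input and the chunk count
        meta.tokens_size = tokens_size;
        meta.n_chunks    = n_chunks;
        llama_send_meta(ctx, &meta); // push to the next rank only
    } else {
        if (llama_recv_meta(ctx, &meta) == -1) {
            return false; // upstream meta never arrived
        }
        // adopt the head node's values
        tokens_size = meta.tokens_size;
        n_chunks    = meta.n_chunks;
        if (!is_last_dev) {
            llama_send_meta(ctx, &meta); // relay down the chain
        }
    }
    return true;
}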

View file

@@ -1879,7 +1879,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
params.sparams.ignore_eos = false;
}
if (params.warmup) {
if (0) {
LOG_WRN("%s: warming up the model with an empty run - please wait ...\n", __func__);
const uint32_t my_rank = cparams.rank;
@@ -2006,7 +2006,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
}
cparams.next_node_ip = new char[params.next_node_ip.length() + 1];
std::strcpy(cparams.next_node_ip, params.next_node_ip.c_str());
cparams.n_ctx = params.n_ctx;
cparams.n_predict = params.n_predict;
cparams.n_seq_max = params.n_parallel;

View file

@@ -489,9 +489,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
// only last device store logits file
std::ofstream logits_stream;
if (my_rank == n_world - 1){
if (my_rank == 0) {
if (!params.logits_file.empty()) {
logits_stream.open(params.logits_file.c_str(), std::ios::binary);
if (!logits_stream.is_open()) {
@@ -506,9 +505,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<llama_token> tokens;
size_t tokens_size = 0;
int n_chunks = params.n_chunks;
// maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
if (my_rank == 0 || is_last_dev) {
if (my_rank == 0) {
auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);
@@ -519,26 +518,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
}
{
if (n_world > 1) {
sync_meta meta;
if (my_rank == 0) {
if (my_rank != 0) {
LOG_INF("perplexity: rank %d waiting for rank 0 to be ready\n", my_rank);
}
if (n_world > 1) {
sync_meta meta;
if (my_rank == 0) {
meta.tokens_size = tokens_size;
llama_send_meta(ctx, &meta, false);
} else {
if (llama_recv_meta(ctx, &meta, false) == -1) {
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
meta.n_chunks = params.n_chunks;
LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
llama_send_meta(ctx, &meta);
LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
} else {
LOG_INF("%s: rank %d waiting 5 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
std::this_thread::sleep_for(std::chrono::milliseconds(5000));
LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
if (llama_recv_meta(ctx, &meta) == -1) {
return { {}, -1.0, {}, {} };
}
if (is_last_dev) {
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
} else {
tokens_size = meta.tokens_size;
llama_send_meta(ctx, &meta, false);
}
tokens_size = meta.tokens_size;
n_chunks = meta.n_chunks;
if (!is_last_dev) {
LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
llama_send_meta(ctx, &meta);
}
}
}
LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size);
if (my_rank == 0) {
@@ -553,14 +562,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> logit_history;
std::vector<float> prob_history;
if (is_last_dev) {
if (my_rank == 0) {
logit_history.resize(tokens_size);
prob_history.resize(tokens_size);
}
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk_max = tokens_size / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_chunk = n_chunks < 0 ? n_chunk_max : std::min(n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
@@ -578,9 +587,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> logits;
if(is_last_dev){
if (my_rank == 0 || is_last_dev) {
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
LOG_INF("%s: rank %d reserved logits space for %zu elements\n",
__func__, my_rank, (size_t)n_ctx * n_vocab);
}
}
@@ -590,9 +601,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<uint16_t> log_probs; // the log probabilities of logits
// rank 0 and last rank store log_probs
// only rank 0 or last device stores logits/log_probs
if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) {
// only rank 0 stores logits/log_probs
if (!params.logits_file.empty() && (my_rank == 0)) {
const int nv = 2*((n_vocab + 1)/2) + 4;
log_probs.resize(n_ctx * nv);
@@ -634,14 +644,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
clear_meta.clear_kv_cache = true;
if (my_rank == 0) {
llama_send_meta(ctx, &clear_meta, false);
llama_send_meta(ctx, &clear_meta);
} else {
if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
if (llama_recv_meta(ctx, &clear_meta) == -1) {
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
if (!is_last_dev) {
llama_send_meta(ctx, &clear_meta, false);
llama_send_meta(ctx, &clear_meta);
}
}
}
@@ -700,12 +710,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
meta.n_outputs = n_outputs;
if (n_world > 1) {
llama_send_meta(ctx, &meta, false); // reverse = false
llama_send_meta(ctx, &meta);
}
} else {
if (n_world > 1) {
// comms: other ranks receive the batch meta data
if (llama_recv_meta(ctx, &meta, false) == -1) {
if (llama_recv_meta(ctx, &meta) == -1) {
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
return {tokens, -1.0, {}, {}};
}
@@ -732,7 +742,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
return {tokens, -1, logit_history, prob_history};
}
if (is_last_dev && num_batches > 1 ) {
if (my_rank == 0 && num_batches > 1 && n_outputs > 0) {
const int n_outputs_synced = meta.n_outputs;
if (n_outputs_synced > 0) {
const auto * batch_logits = llama_get_logits(ctx);
@@ -754,7 +764,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
LOG("%.2f minutes\n", total_seconds / 60.0);
}
if (is_last_dev) {
if (my_rank == 0) {
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
@@ -785,37 +795,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
logits.clear();
}
if (n_world > 1) {
sync_meta done_meta;
done_meta.chunk_done = true;
if (my_rank == 0) {
double current_ppl = std::exp(nll / count);
double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
if (is_last_dev) {
// Last device sends completion signal upstream (reverse direction)
LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
} else if (my_rank == 0) {
// Rank 0 waits for completion signal from downstream
LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
} else {
// Intermediate ranks: receive from downstream, relay upstream
LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
return {tokens, -1.0, {}, {}};
}
LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
llama_send_meta(ctx, &done_meta, true); // reverse = true
}
LOG_INF("Rank 0: Chunk %d/%d (%.1f%%) completed, current_ppl = %.4lf\n",
i + n_seq_batch, n_chunk, progress, current_ppl);
} else {
double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
LOG_INF("Rank %d: Chunk %d/%d (%.1f%%) completed\n",
my_rank, i + n_seq_batch, n_chunk, progress);
}
}
LOG("\n");
if (is_last_dev) {
if (my_rank == 0) {
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
@@ -2221,19 +2216,22 @@ int main(int argc, char ** argv) {
LOG("\n");
if (is_last_dev) {
write_logfile(ctx, params, model, results);
}
if (my_rank == 0) {
write_logfile(ctx, params, model, results);
llama_perf_context_print(ctx);
}
}
if (n_world > 1) {
LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
if (my_rank == 0) {
llama_free_sockets(ctx, nullptr);
char * rank0_stop_signal = nullptr;
llama_free_sockets(ctx, &rank0_stop_signal);
if (rank0_stop_signal) {
LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, rank0_stop_signal);
delete[] rank0_stop_signal;
}
}
if (my_rank != 0 && signal_thread.joinable()) {

View file

@@ -62,11 +62,11 @@ struct sync_meta {
llama_seq_id ** seq_id = nullptr;
llama_pos all_pos_0;
llama_pos all_pos_1;
uint32_t n_ctx = 0;
// used for perplexity evaluation
int32_t n_outputs;
bool chunk_done = false; // signal that the chunk is done
// signal to clear the kv cache
bool clear_kv_cache = false;
@@ -98,8 +98,9 @@ struct sync_meta {
llama_pos div_p1 = 0;
int div_factor = 1;
// signal to transfer tokens_size
// perplexity evaluation
size_t tokens_size = 0;
int n_chunks = -1;
};
#ifdef __cplusplus
@@ -507,8 +508,8 @@ extern "C" {
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta);
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta);
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);

View file

@@ -3492,8 +3492,7 @@ struct llama_context {
// sockets
std::string master_ip = "localhost";
std::string next_node_ip = "localhost";
std::string prev_node_ip = "localhost";
uint32_t data_port = 9000;
uint32_t data_port = 9043;
uint32_t signal_port = 10000;
zmq::context_t * sock_context = nullptr;
zmq::socket_t * send_socket = nullptr;
@@ -17872,27 +17871,33 @@ struct input_tensors {
ggml_tensor * inp_pos;
};
void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(meta != nullptr);
zmq::socket_t * send_socket = reverse ? ctx->reverse_send_socket : ctx->send_socket;
zmq::socket_t * send_socket = ctx->send_socket;
GGML_ASSERT(send_socket != nullptr);
try {
std::vector<zmq::message_t> send_msgs;
// Handle chunk_done signal
if (meta->chunk_done) {
send_msgs.emplace_back("chunk_done", strlen("chunk_done"));
if (meta->clear_kv_cache) {
send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
send_msgs.emplace_back("1", 1);
zmq::send_multipart(*send_socket, send_msgs);
return;
}
if (meta->clear_kv_cache) {
send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
send_msgs.emplace_back("1", 1);
if (meta->tokens_size > 0) {
send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
if (meta->n_chunks >= 0) {
send_msgs.emplace_back("n_chunks", strlen("n_chunks"));
send_msgs.emplace_back(&(meta->n_chunks), sizeof(meta->n_chunks));
}
zmq::send_multipart(*send_socket, send_msgs);
return;
}
@@ -17935,11 +17940,6 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse
return;
}
if (meta->tokens_size > 0) {
send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
}
if (meta->n_tokens > 0) {
send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
@@ -18001,35 +18001,30 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse
}
}
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
zmq::socket_t * recv_socket = reverse ? ctx->reverse_recv_socket : ctx->recv_socket;
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
zmq::socket_t * recv_socket = ctx->recv_socket;
GGML_ASSERT(recv_socket != nullptr);
recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
std::vector<zmq::message_t> recv_msgs;
if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
if (!zmq::recv_multipart(*recv_socket, std::back_inserter(recv_msgs))) {
recv_socket->set(zmq::sockopt::rcvtimeo, -1); // Reset timeout to blocking mode before returning error
return -1;
}
ctx->recv_socket->set(zmq::sockopt::rcvtimeo, -1);
recv_socket->set(zmq::sockopt::rcvtimeo, -1);
const std::string cmd = recv_msgs[0].to_string();
size_t idx = 1;
// Handle chunk_done signal
if (cmd == "chunk_done") {
meta->chunk_done = true;
return 0;
}
if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
if (cmd == "clear_kv_cache" && recv_msgs.size() == 2) {
meta->clear_kv_cache = true;
return 0;
}
if (cmd == "kv_seq_rm" && recv_msgs.size() == 4) {
meta->kv_seq_rm = true;
std::memcpy(&meta->rm_seq_id, recv_msgs[idx++].data(), sizeof(meta->rm_seq_id));
@@ -18076,22 +18071,17 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse =
std::string key = recv_msgs[i].to_string();
zmq::message_t & data_msg = recv_msgs[i + 1];
if (key == "tokens_size") {
GGML_ASSERT(data_msg.size() == sizeof(meta->tokens_size));
std::memcpy(&(meta->tokens_size), data_msg.data(), sizeof(meta->tokens_size));
}
else if (key == "n_tokens") {
if (key == "n_tokens") {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
} else if (key == "n_chunks") {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_chunks));
std::memcpy(&(meta->n_chunks), data_msg.data(), sizeof(meta->n_chunks));
}
else if (key == "n_outputs") {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_outputs));
std::memcpy(&(meta->n_outputs), data_msg.data(), sizeof(meta->n_outputs));
}
// else if (key == "chunk_start_pos") {
// GGML_ASSERT(data_msg.size() == sizeof(meta->chunk_start_pos));
// std::memcpy(&(meta->chunk_start_pos), data_msg.data(), sizeof(meta->chunk_start_pos));
// }
else if (key == "n_ctx") {
GGML_ASSERT(data_msg.size() == sizeof(meta->n_ctx));
std::memcpy(&(meta->n_ctx), data_msg.data(), sizeof(meta->n_ctx));
@@ -18416,7 +18406,7 @@ static int llama_decode_internal(
bool is_last_dev = (my_rank == n_world - 1);
if (my_rank != 0) {
if (llama_recv_meta(&lctx, &meta, false) == -1) {
if (llama_recv_meta(&lctx, &meta) == -1) {
return -1;
}
@@ -18477,7 +18467,7 @@ static int llama_decode_internal(
meta.pos = batch_all.pos;
meta.all_pos_0 = batch_all.all_pos_0;
meta.all_pos_1 = batch_all.all_pos_1;
llama_send_meta(&lctx, &meta, false);
llama_send_meta(&lctx, &meta);
}
}
@@ -20615,15 +20605,10 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
}
ctx->sock_context = new zmq::context_t(2);
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
// Reverse pipeline sockets (new - for barriers)
ctx->reverse_send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
ctx->reverse_recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
if (my_rank != 0 && my_rank != (n_world - 1)) {
ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
} else if (my_rank == (n_world - 1)) {
@@ -20631,38 +20616,18 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
}
const uint32_t next_rank = (my_rank + 1) % n_world;
const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port));
std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port));
std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port));
// Reverse pipeline endpoints (new)
// Use a different port offset for reverse communication to avoid conflicts
const uint32_t reverse_port_offset = 1000;
std::string reverse_recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port + reverse_port_offset));
std::string reverse_send_endp = "tcp://" + ctx->prev_node_ip + ":" + std::to_string(map_rank_to_port(prev_rank, ctx->data_port + reverse_port_offset));
try {
ctx->recv_socket->bind(recv_endp);
ctx->signal_socket->bind(signal_endp);
ctx->send_socket->connect(send_endp);
if (ctx->master_socket && my_rank != (n_world - 1)) {
ctx->master_socket->connect(master_endp);
}
// Setup reverse pipeline sockets
if (my_rank > 0) {
// All ranks except rank 0 can send to previous rank
ctx->reverse_send_socket->connect(reverse_send_endp);
}
if (my_rank < n_world - 1) {
// All ranks except last rank can receive from next rank
ctx->reverse_recv_socket->bind(reverse_recv_endp);
}
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
exit(1);
@@ -20936,7 +20901,6 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
const uint32_t my_rank = ctx->cparams.rank;
// to adapt to the new topology, use old next_rank
const uint32_t next_rank = ctx->cparams.original_next_rank;
const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
if (n_world == 1) {
return;
@@ -20960,89 +20924,6 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
*msg = new char[msg_str.size() + 1];
std::strcpy(*msg, msg_str.c_str());
}
// Send shutdown signal through reverse pipeline as well
if (my_rank == n_world - 1) {
// Last rank initiates reverse shutdown
try {
sync_meta shutdown_meta;
shutdown_meta.chunk_done = true; // Reuse chunk_done as shutdown signal
llama_send_meta(ctx, &shutdown_meta, true); // reverse = true
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error sending reverse shutdown signal: %s", e.what());
}
} else if (my_rank > 0) {
// Intermediate ranks relay reverse shutdown signal
try {
sync_meta shutdown_meta;
// Set a short timeout for shutdown
ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, 500);
if (llama_recv_meta(ctx, &shutdown_meta, true) == 0) {
if (my_rank > 0) {
llama_send_meta(ctx, &shutdown_meta, true); // relay upstream
}
}
// Reset timeout
ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, -1);
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error handling reverse shutdown signal on rank %d: %s", my_rank, e.what());
}
}
try {
// Close signal sender (local socket created in this function)
signal_sender.close();
// Close reverse sockets first
if (ctx->reverse_send_socket) {
ctx->reverse_send_socket->close();
delete ctx->reverse_send_socket;
ctx->reverse_send_socket = nullptr;
}
if (ctx->reverse_recv_socket) {
ctx->reverse_recv_socket->close();
delete ctx->reverse_recv_socket;
ctx->reverse_recv_socket = nullptr;
}
// Close existing forward sockets
if (ctx->send_socket) {
ctx->send_socket->close();
delete ctx->send_socket;
ctx->send_socket = nullptr;
}
if (ctx->recv_socket) {
ctx->recv_socket->close();
delete ctx->recv_socket;
ctx->recv_socket = nullptr;
}
if (ctx->signal_socket) {
ctx->signal_socket->close();
delete ctx->signal_socket;
ctx->signal_socket = nullptr;
}
// Handle master_socket cleanup (be careful not to double-delete)
if (ctx->master_socket && my_rank != (n_world - 1) && ctx->master_socket != ctx->send_socket) {
ctx->master_socket->close();
delete ctx->master_socket;
ctx->master_socket = nullptr;
}
// Cleanup ZMQ context last
if (ctx->sock_context) {
delete ctx->sock_context;
ctx->sock_context = nullptr;
}
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error cleaning up sockets: %s", e.what());
}
}
void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {