Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-05 23:19:18 +00:00)
Removed some unnecessary synchronization logic and added n_chunks to the metadata communicated between ranks
This commit is contained in:
parent a3becb586a
commit 48b7f53abb
4 changed files with 97 additions and 218 deletions
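The hunks below drop the reverse (tail-to-head) ZeroMQ pipeline and its chunk_done barrier, route the perplexity bookkeeping through rank 0 instead of the last device, and broadcast tokens_size together with the new n_chunks field down the forward pipeline. As a minimal sketch (not part of the commit), the following C++ program illustrates the key/value multipart framing that llama_send_meta()/llama_recv_meta() use for that broadcast; in prima.cpp the first frame also doubles as the command that the receiver dispatches on. The inproc endpoint, main() wrapper, and sample values here are invented for the demo, which assumes cppzmq (zmq.hpp plus zmq_addon.hpp) is available.

#include <zmq.hpp>
#include <zmq_addon.hpp>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

int main() {
    zmq::context_t zctx(1);
    zmq::socket_t push(zctx, zmq::socket_type::push); // stands in for ctx->send_socket
    zmq::socket_t pull(zctx, zmq::socket_type::pull); // stands in for ctx->recv_socket
    pull.bind("inproc://meta-demo");
    push.connect("inproc://meta-demo");

    // Rank-0 side: frame (key, raw value bytes) pairs, as in the "tokens_size" branch.
    size_t tokens_size = 335687; // sample value
    int    n_chunks    = 16;     // sample value (only sent when >= 0)
    std::vector<zmq::message_t> send_msgs;
    send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
    send_msgs.emplace_back(&tokens_size, sizeof(tokens_size));
    if (n_chunks >= 0) {
        send_msgs.emplace_back("n_chunks", strlen("n_chunks"));
        send_msgs.emplace_back(&n_chunks, sizeof(n_chunks));
    }
    zmq::send_multipart(push, send_msgs);

    // Worker side: read all frames, then walk the (key, value) pairs.
    std::vector<zmq::message_t> recv_msgs;
    if (!zmq::recv_multipart(pull, std::back_inserter(recv_msgs))) {
        return 1;
    }
    size_t got_tokens = 0;
    int    got_chunks = -1;
    for (size_t i = 0; i + 1 < recv_msgs.size(); i += 2) {
        const std::string key = recv_msgs[i].to_string();
        if (key == "tokens_size") {
            std::memcpy(&got_tokens, recv_msgs[i + 1].data(), sizeof(got_tokens));
        } else if (key == "n_chunks") {
            std::memcpy(&got_chunks, recv_msgs[i + 1].data(), sizeof(got_chunks));
        }
    }
    std::printf("tokens_size = %zu, n_chunks = %d\n", got_tokens, got_chunks);
    return 0;
}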
@@ -1879,7 +1879,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         params.sparams.ignore_eos = false;
     }

-    if (params.warmup) {
+    if (0) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ...\n", __func__);

         const uint32_t my_rank = cparams.rank;
@@ -2006,7 +2006,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     }
     cparams.next_node_ip = new char[params.next_node_ip.length() + 1];
     std::strcpy(cparams.next_node_ip, params.next_node_ip.c_str());
-
     cparams.n_ctx = params.n_ctx;
     cparams.n_predict = params.n_predict;
     cparams.n_seq_max = params.n_parallel;
@@ -489,9 +489,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
     GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

-    // only last device store logits file
     std::ofstream logits_stream;
-    if (my_rank == n_world - 1){
+    if (my_rank == 0) {
         if (!params.logits_file.empty()) {
             logits_stream.open(params.logits_file.c_str(), std::ios::binary);
             if (!logits_stream.is_open()) {
@@ -506,9 +505,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     std::vector<llama_token> tokens;
     size_t tokens_size = 0;
+    int n_chunks = params.n_chunks;

-    // maybe we need to try other solutions, such as direct communication of tokens between the head and tail nodes
-    if (my_rank == 0 || is_last_dev) {
+    if (my_rank == 0) {
         auto tim1 = std::chrono::high_resolution_clock::now();
         LOG_INF("%s: rank %d tokenizing the input ..\n", __func__, my_rank);

@@ -519,26 +518,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         LOG_INF("%s: rank %d tokenization took %g ms\n", __func__, my_rank, 1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
     }

-    {
-        if (n_world > 1) {
-            sync_meta meta;
-            if (my_rank == 0) {
-                meta.tokens_size = tokens_size;
-                llama_send_meta(ctx, &meta, false);
-            } else {
-                if (llama_recv_meta(ctx, &meta, false) == -1) {
-                    LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
-                    return { {}, -1.0, {}, {} };
-                }
-                if (is_last_dev) {
-                    GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
-                } else {
-                    tokens_size = meta.tokens_size;
-                    llama_send_meta(ctx, &meta, false);
-                }
-            }
-        }
-    }
+    if (my_rank != 0) {
+        LOG_INF("perplexity: rank %d waiting for rank 0 to be ready\n", my_rank);
+    }
+
+    if (n_world > 1) {
+        sync_meta meta;
+
+        if (my_rank == 0) {
+            meta.tokens_size = tokens_size;
+            meta.n_chunks = params.n_chunks;
+            LOG_INF("%s: rank 0 sending tokens_size = %zu\n", __func__, tokens_size);
+            llama_send_meta(ctx, &meta);
+            LOG_INF("%s: rank 0 tokens_size sent successfully\n", __func__);
+        } else {
+            LOG_INF("%s: rank %d waiting 5 seconds for rank 0 to complete tokenization\n", __func__, my_rank);
+            std::this_thread::sleep_for(std::chrono::milliseconds(5000));
+            LOG_INF("%s: rank %d delay completed, now receiving tokens_size\n", __func__, my_rank);
+            if (llama_recv_meta(ctx, &meta) == -1) {
+                return { {}, -1.0, {}, {} };
+            }
+            tokens_size = meta.tokens_size;
+            n_chunks = meta.n_chunks;
+            if (!is_last_dev) {
+                LOG_INF("%s: rank %d forwarding tokens_size to next rank\n", __func__, my_rank);
+                llama_send_meta(ctx, &meta);
+            }
+        }
+    }

     LOG_INF("%s: rank %d synchronized tokens_size = %zu\n", __func__, my_rank, tokens_size);

     if (my_rank == 0) {
@@ -553,14 +562,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     std::vector<float> logit_history;
     std::vector<float> prob_history;

-    if (is_last_dev) {
+    if (my_rank == 0) {
         logit_history.resize(tokens_size);
         prob_history.resize(tokens_size);
     }

-    const int n_chunk_max = tokens.size() / n_ctx;
+    const int n_chunk_max = tokens_size / n_ctx;

-    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+    const int n_chunk = n_chunks < 0 ? n_chunk_max : std::min(n_chunks, n_chunk_max);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

@@ -578,9 +587,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     std::vector<float> logits;

-    if(is_last_dev){
+    if((my_rank == 0 || is_last_dev)){
         if (num_batches > 1) {
             logits.reserve((size_t)n_ctx * n_vocab);
+            LOG_INF("%s: rank %d reserved logits space for %zu elements\n",
+                    __func__, my_rank, (size_t)n_ctx * n_vocab);
         }
     }

@@ -590,9 +601,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     std::vector<uint16_t> log_probs; // the log probabilities of logits

-    // rank 0 and last rank store log_probs
-    // only rank 0 or last device stores logits/log_probs
-    if (!params.logits_file.empty() && (is_last_dev || my_rank == 0)) {
+    // only rank 0 stores logits/log_probs
+    if (!params.logits_file.empty() && (my_rank == 0)) {
         const int nv = 2*((n_vocab + 1)/2) + 4;
         log_probs.resize(n_ctx * nv);

@@ -634,14 +644,14 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         clear_meta.clear_kv_cache = true;

         if (my_rank == 0) {
-            llama_send_meta(ctx, &clear_meta, false);
+            llama_send_meta(ctx, &clear_meta);
         } else {
-            if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
+            if (llama_recv_meta(ctx, &clear_meta) == -1) {
                 LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
                 return {tokens, -1.0, {}, {}};
             }
             if (!is_last_dev) {
-                llama_send_meta(ctx, &clear_meta, false);
+                llama_send_meta(ctx, &clear_meta);
             }
         }
     }
@@ -700,12 +710,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             meta.n_outputs = n_outputs;

             if (n_world > 1) {
-                llama_send_meta(ctx, &meta, false); // reverse = false
+                llama_send_meta(ctx, &meta);
             }
         } else {
             if (n_world > 1) {
                 // comms: other ranks receive the batch meta data
-                if (llama_recv_meta(ctx, &meta, false) == -1) {
+                if (llama_recv_meta(ctx, &meta) == -1) {
                     LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
                     return {tokens, -1.0, {}, {}};
                 }
@@ -732,7 +742,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 return {tokens, -1, logit_history, prob_history};
             }

-            if (is_last_dev && num_batches > 1 ) {
+            if (my_rank == 0 && num_batches > 1 && n_outputs > 0) {
                 const int n_outputs_synced = meta.n_outputs;
                 if (n_outputs_synced > 0) {
                     const auto * batch_logits = llama_get_logits(ctx);
@@ -754,7 +764,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             LOG("%.2f minutes\n", total_seconds / 60.0);
         }

-        if (is_last_dev) {
+        if (my_rank == 0) {
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);

@@ -785,37 +795,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             logits.clear();
         }

-        if (n_world > 1) {
-            sync_meta done_meta;
-            done_meta.chunk_done = true;
-
-            if (is_last_dev) {
-                // Last device sends completion signal upstream (reverse direction)
-                LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
-                llama_send_meta(ctx, &done_meta, true); // reverse = true
-            } else if (my_rank == 0) {
-                // Rank 0 waits for completion signal from downstream
-                LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
-                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
-                    LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
-                    return {tokens, -1.0, {}, {}};
-                }
-                LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
-            } else {
-                // Intermediate ranks: receive from downstream, relay upstream
-                LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
-                if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
-                    LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
-                    return {tokens, -1.0, {}, {}};
-                }
-                LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
-                llama_send_meta(ctx, &done_meta, true); // reverse = true
-            }
+        if (my_rank == 0) {
+            double current_ppl = std::exp(nll / count);
+            double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
+
+            LOG_INF("Rank 0: Chunk %d/%d (%.1f%%) completed, current_ppl = %.4lf\n",
+                    i + n_seq_batch, n_chunk, progress, current_ppl);
+        } else {
+            double progress = ((double)(i + n_seq_batch)) / n_chunk * 100.0;
+            LOG_INF("Rank %d: Chunk %d/%d (%.1f%%) completed\n",
+                    my_rank, i + n_seq_batch, n_chunk, progress);
         }
     }
     LOG("\n");

-    if (is_last_dev) {
+    if (my_rank == 0) {
         nll2 /= count;
         nll /= count;
         const double ppl = exp(nll);
@@ -2221,19 +2216,22 @@ int main(int argc, char ** argv) {
    LOG("\n");


-   if (is_last_dev) {
-       write_logfile(ctx, params, model, results);
-   }

    if (my_rank == 0) {
+       write_logfile(ctx, params, model, results);
        llama_perf_context_print(ctx);
    }

    if (n_world > 1) {
        LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);

        if (my_rank == 0) {
-           llama_free_sockets(ctx, nullptr);
+           char * rank0_stop_signal = nullptr;
+           llama_free_sockets(ctx, &rank0_stop_signal);
+
+           if (rank0_stop_signal) {
+               LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, rank0_stop_signal);
+               delete[] rank0_stop_signal;
+           }
        }

        if (my_rank != 0 && signal_thread.joinable()) {
@@ -62,11 +62,11 @@ struct sync_meta {
    llama_seq_id ** seq_id = nullptr;
    llama_pos all_pos_0;
    llama_pos all_pos_1;

    uint32_t n_ctx = 0;

    // used for perplexity evaluation
    int32_t n_outputs;
-   bool chunk_done = false; // signal that the chunk is done

    // signal to clear the kv cache
    bool clear_kv_cache= false;
@@ -98,8 +98,9 @@ struct sync_meta {
    llama_pos div_p1 = 0;
    int div_factor = 1;

-   // signal to transfer tokens_size
+   // perplexity evaluation
    size_t tokens_size = 0;
+   int n_chunks = -1;
};

#ifdef __cplusplus
@@ -507,8 +508,8 @@ extern "C" {

    LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
    LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
-   LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
-   LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
+   LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta);
+   LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta);
    LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
    LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
    LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);

src/llama.cpp (177 changed lines)
@@ -3492,8 +3492,7 @@ struct llama_context {
    // sockets
    std::string master_ip = "localhost";
    std::string next_node_ip = "localhost";
-   std::string prev_node_ip = "localhost";
-   uint32_t data_port = 9000;
+   uint32_t data_port = 9043;
    uint32_t signal_port = 10000;
    zmq::context_t * sock_context = nullptr;
    zmq::socket_t * send_socket = nullptr;
@@ -17872,27 +17871,33 @@ struct input_tensors {
    ggml_tensor * inp_pos;
};

-void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
+void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
    GGML_ASSERT(ctx != nullptr);
    GGML_ASSERT(meta != nullptr);

-   zmq::socket_t * send_socket = reverse ? ctx->reverse_send_socket : ctx->send_socket;
+   zmq::socket_t * send_socket = ctx->send_socket;
    GGML_ASSERT(send_socket != nullptr);

    try {
        std::vector<zmq::message_t> send_msgs;

-       // Handle chunk_done signal
-       if (meta->chunk_done) {
-           send_msgs.emplace_back("chunk_done", strlen("chunk_done"));
-           send_msgs.emplace_back("1", 1);
-           zmq::send_multipart(*send_socket, send_msgs);
-           return;
-       }
-
        if (meta->clear_kv_cache) {
            send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
            send_msgs.emplace_back("1", 1);
+           zmq::send_multipart(*send_socket, send_msgs);
            return;
        }
+
+       if (meta->tokens_size > 0) {
+           send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
+           send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
+
+           if (meta->n_chunks >= 0) {
+               send_msgs.emplace_back("n_chunks", strlen("n_chunks"));
+               send_msgs.emplace_back(&(meta->n_chunks), sizeof(meta->n_chunks));
+           }
+
+           zmq::send_multipart(*send_socket, send_msgs);
+           return;
+       }

@@ -17935,11 +17940,6 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse
        return;
    }

-   if (meta->tokens_size > 0) {
-       send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
-       send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
-   }
-
    if (meta->n_tokens > 0) {
        send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
        send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
@@ -18001,35 +18001,30 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse
    }
}

-int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
-   zmq::socket_t * recv_socket = reverse ? ctx->reverse_recv_socket : ctx->recv_socket;
+int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
+   zmq::socket_t * recv_socket = ctx->recv_socket;
    GGML_ASSERT(recv_socket != nullptr);

    recv_socket->set(zmq::sockopt::rcvtimeo, 1000);

    std::vector<zmq::message_t> recv_msgs;

-   if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
+   if (!zmq::recv_multipart(*recv_socket, std::back_inserter(recv_msgs))) {
        recv_socket->set(zmq::sockopt::rcvtimeo, -1); // Reset timeout to blocking mode before returning error
        return -1;
    }

-   ctx->recv_socket->set(zmq::sockopt::rcvtimeo, -1);
+   recv_socket->set(zmq::sockopt::rcvtimeo, -1);

    const std::string cmd = recv_msgs[0].to_string();
    size_t idx = 1;

-   // Handle chunk_done signal
-   if (cmd == "chunk_done") {
-       meta->chunk_done = true;
-       return 0;
-   }
-
-   if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
+   if (cmd == "clear_kv_cache" && recv_msgs.size() == 2) {
        meta->clear_kv_cache = true;
        return 0;
    }

    if (cmd == "kv_seq_rm" && recv_msgs.size() == 4) {
        meta->kv_seq_rm = true;
        std::memcpy(&meta->rm_seq_id, recv_msgs[idx++].data(), sizeof(meta->rm_seq_id));
@@ -18076,22 +18071,17 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse =
        std::string key = recv_msgs[i].to_string();
        zmq::message_t & data_msg = recv_msgs[i + 1];

-       if (key == "tokens_size") {
-           GGML_ASSERT(data_msg.size() == sizeof(meta->tokens_size));
-           std::memcpy(&(meta->tokens_size), data_msg.data(), sizeof(meta->tokens_size));
-       }
-       else if (key == "n_tokens") {
+       if (key == "n_tokens") {
            GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
            std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
+       } else if (key == "n_chunks") {
+           GGML_ASSERT(data_msg.size() == sizeof(meta->n_chunks));
+           std::memcpy(&(meta->n_chunks), data_msg.data(), sizeof(meta->n_chunks));
        }
        else if (key == "n_outputs") {
            GGML_ASSERT(data_msg.size() == sizeof(meta->n_outputs));
            std::memcpy(&(meta->n_outputs), data_msg.data(), sizeof(meta->n_outputs));
        }
-       // else if (key == "chunk_start_pos") {
-       //     GGML_ASSERT(data_msg.size() == sizeof(meta->chunk_start_pos));
-       //     std::memcpy(&(meta->chunk_start_pos), data_msg.data(), sizeof(meta->chunk_start_pos));
-       // }
        else if (key == "n_ctx") {
            GGML_ASSERT(data_msg.size() == sizeof(meta->n_ctx));
            std::memcpy(&(meta->n_ctx), data_msg.data(), sizeof(meta->n_ctx));
@@ -18416,7 +18406,7 @@ static int llama_decode_internal(
    bool is_last_dev = (my_rank == n_world - 1);

    if (my_rank != 0) {
-       if (llama_recv_meta(&lctx, &meta, false) == -1) {
+       if (llama_recv_meta(&lctx, &meta) == -1) {
            return -1;
        }

@@ -18477,7 +18467,7 @@ static int llama_decode_internal(
            meta.pos = batch_all.pos;
            meta.all_pos_0 = batch_all.all_pos_0;
            meta.all_pos_1 = batch_all.all_pos_1;
-           llama_send_meta(&lctx, &meta, false);
+           llama_send_meta(&lctx, &meta);
        }
    }

@@ -20615,15 +20605,10 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
    }

    ctx->sock_context = new zmq::context_t(2);

    ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
    ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
    ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);

-   // Reverse pipeline sockets (new - for barriers)
-   ctx->reverse_send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
-   ctx->reverse_recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
-
    if (my_rank != 0 && my_rank != (n_world - 1)) {
        ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
    } else if (my_rank == (n_world - 1)) {
@@ -20631,38 +20616,18 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
    }

    const uint32_t next_rank = (my_rank + 1) % n_world;
-   const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;

    std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port));
    std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
    std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port));
    std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port));

-   // Reverse pipeline endpoints (new)
-   // Use a different port offset for reverse communication to avoid conflicts
-   const uint32_t reverse_port_offset = 1000;
-   std::string reverse_recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port + reverse_port_offset));
-   std::string reverse_send_endp = "tcp://" + ctx->prev_node_ip + ":" + std::to_string(map_rank_to_port(prev_rank, ctx->data_port + reverse_port_offset));
-
    try {
        ctx->recv_socket->bind(recv_endp);
        ctx->signal_socket->bind(signal_endp);

        ctx->send_socket->connect(send_endp);
        if (ctx->master_socket && my_rank != (n_world - 1)) {
            ctx->master_socket->connect(master_endp);
        }
-
-       // Setup reverse pipeline sockets
-       if (my_rank > 0) {
-           // All ranks except rank 0 can send to previous rank
-           ctx->reverse_send_socket->connect(reverse_send_endp);
-       }
-
-       if (my_rank < n_world - 1) {
-           // All ranks except last rank can receive from next rank
-           ctx->reverse_recv_socket->bind(reverse_recv_endp);
-       }
    } catch (const zmq::error_t &e) {
        LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
        exit(1);
@@ -20936,7 +20901,6 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
    const uint32_t my_rank = ctx->cparams.rank;
    // to adapt to the new topology, use old next_rank
    const uint32_t next_rank = ctx->cparams.original_next_rank;
-   const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;

    if (n_world == 1) {
        return;
@@ -20960,89 +20924,6 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
        *msg = new char[msg_str.size() + 1];
        std::strcpy(*msg, msg_str.c_str());
    }
-
-   // Send shutdown signal through reverse pipeline as well
-   if (my_rank == n_world - 1) {
-       // Last rank initiates reverse shutdown
-       try {
-           sync_meta shutdown_meta;
-           shutdown_meta.chunk_done = true; // Reuse chunk_done as shutdown signal
-           llama_send_meta(ctx, &shutdown_meta, true); // reverse = true
-       } catch (const zmq::error_t &e) {
-           LLAMA_LOG_INFO("Error sending reverse shutdown signal: %s", e.what());
-       }
-   } else if (my_rank > 0) {
-       // Intermediate ranks relay reverse shutdown signal
-       try {
-           sync_meta shutdown_meta;
-           // Set a short timeout for shutdown
-           ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, 500);
-
-           if (llama_recv_meta(ctx, &shutdown_meta, true) == 0) {
-               if (my_rank > 0) {
-                   llama_send_meta(ctx, &shutdown_meta, true); // relay upstream
-               }
-           }
-
-           // Reset timeout
-           ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, -1);
-       } catch (const zmq::error_t &e) {
-           LLAMA_LOG_INFO("Error handling reverse shutdown signal on rank %d: %s", my_rank, e.what());
-       }
-   }
-
-   try {
-       // Close signal sender (local socket created in this function)
-       signal_sender.close();
-
-       // Close reverse sockets first
-       if (ctx->reverse_send_socket) {
-           ctx->reverse_send_socket->close();
-           delete ctx->reverse_send_socket;
-           ctx->reverse_send_socket = nullptr;
-       }
-
-       if (ctx->reverse_recv_socket) {
-           ctx->reverse_recv_socket->close();
-           delete ctx->reverse_recv_socket;
-           ctx->reverse_recv_socket = nullptr;
-       }
-
-       // Close existing forward sockets
-       if (ctx->send_socket) {
-           ctx->send_socket->close();
-           delete ctx->send_socket;
-           ctx->send_socket = nullptr;
-       }
-
-       if (ctx->recv_socket) {
-           ctx->recv_socket->close();
-           delete ctx->recv_socket;
-           ctx->recv_socket = nullptr;
-       }
-
-       if (ctx->signal_socket) {
-           ctx->signal_socket->close();
-           delete ctx->signal_socket;
-           ctx->signal_socket = nullptr;
-       }
-
-       // Handle master_socket cleanup (be careful not to double-delete)
-       if (ctx->master_socket && my_rank != (n_world - 1) && ctx->master_socket != ctx->send_socket) {
-           ctx->master_socket->close();
-           delete ctx->master_socket;
-           ctx->master_socket = nullptr;
-       }
-
-       // Cleanup ZMQ context last
-       if (ctx->sock_context) {
-           delete ctx->sock_context;
-           ctx->sock_context = nullptr;
-       }
-
-   } catch (const zmq::error_t &e) {
-       LLAMA_LOG_INFO("Error cleaning up sockets: %s", e.what());
-   }
}

void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
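One caller-side note on the shutdown change in the main() hunk above: rank 0 now passes llama_free_sockets() a message slot instead of nullptr and takes ownership of the returned buffer. Below is a minimal self-contained sketch of that out-parameter contract; free_sockets_demo is a hypothetical stand-in for llama_free_sockets(), which (as its hunk shows) allocates the message with new char[], so the caller releases it with delete[].

#include <cstdio>
#include <cstring>

// Hypothetical stand-in for llama_free_sockets(ctx, &msg): allocates the
// stop-signal text with new char[] and hands ownership to the caller.
static void free_sockets_demo(char ** msg) {
    const char * msg_str = "received stop signal"; // sample text
    if (msg) {
        *msg = new char[std::strlen(msg_str) + 1];
        std::strcpy(*msg, msg_str);
    }
}

int main() {
    char * rank0_stop_signal = nullptr;
    free_sockets_demo(&rank0_stop_signal); // non-rank-0 callers may pass nullptr instead
    if (rank0_stop_signal) {
        std::printf("Cleanup signal received: %s\n", rank0_stop_signal);
        delete[] rank0_stop_signal;        // matches the new char[] allocation
    }
    return 0;
}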