mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-10 21:34:37 +00:00
Refactored the logic related to communication content and timing control
This commit is contained in:
parent
4b823775ec
commit
a3becb586a
5 changed files with 474 additions and 134 deletions
|
@ -1992,6 +1992,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||||
cparams.keep_out_in_metal = params.keep_out_in_metal;
|
cparams.keep_out_in_metal = params.keep_out_in_metal;
|
||||||
cparams.n_gpu_layers = params.n_gpu_layers;
|
cparams.n_gpu_layers = params.n_gpu_layers;
|
||||||
cparams.n_cycles = params.n_cycles;
|
cparams.n_cycles = params.n_cycles;
|
||||||
|
cparams.is_perplexity_eval= params.is_perplexity_eval;
|
||||||
std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), cparams.n_layer_window);
|
std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), cparams.n_layer_window);
|
||||||
|
|
||||||
if (cparams.master_ip != nullptr) {
|
if (cparams.master_ip != nullptr) {
|
||||||
|
|
|
@ -178,6 +178,8 @@ struct gpt_params {
|
||||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||||
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
||||||
|
|
||||||
|
bool is_perplexity_eval;
|
||||||
|
|
||||||
struct cpu_params cpuparams;
|
struct cpu_params cpuparams;
|
||||||
struct cpu_params cpuparams_batch;
|
struct cpu_params cpuparams_batch;
|
||||||
struct cpu_params draft_cpuparams;
|
struct cpu_params draft_cpuparams;
|
||||||
|
|
|
@ -524,9 +524,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
sync_meta meta;
|
sync_meta meta;
|
||||||
if (my_rank == 0) {
|
if (my_rank == 0) {
|
||||||
meta.tokens_size = tokens_size;
|
meta.tokens_size = tokens_size;
|
||||||
llama_send_meta(ctx, &meta);
|
llama_send_meta(ctx, &meta, false);
|
||||||
} else {
|
} else {
|
||||||
if (llama_recv_meta(ctx, &meta) == -1) {
|
if (llama_recv_meta(ctx, &meta, false) == -1) {
|
||||||
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
|
LOG_ERR("%s: failed to receive tokens_size on rank %d\n", __func__, my_rank);
|
||||||
return { {}, -1.0, {}, {} };
|
return { {}, -1.0, {}, {} };
|
||||||
}
|
}
|
||||||
|
@ -534,7 +534,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
|
GGML_ASSERT(tokens_size == meta.tokens_size && "Token size mismatch between rank 0 and last rank!");
|
||||||
} else {
|
} else {
|
||||||
tokens_size = meta.tokens_size;
|
tokens_size = meta.tokens_size;
|
||||||
llama_send_meta(ctx, &meta);
|
llama_send_meta(ctx, &meta, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -628,19 +628,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
{
|
{
|
||||||
|
// synvhronize the KV cache clear signal across all ranks
|
||||||
if (n_world > 1) {
|
if (n_world > 1) {
|
||||||
sync_meta clear_meta;
|
sync_meta clear_meta;
|
||||||
clear_meta.clear_kv_cache = true;
|
clear_meta.clear_kv_cache = true;
|
||||||
|
|
||||||
if (my_rank == 0) {
|
if (my_rank == 0) {
|
||||||
llama_send_meta(ctx, &clear_meta);
|
llama_send_meta(ctx, &clear_meta, false);
|
||||||
} else {
|
} else {
|
||||||
if (llama_recv_meta(ctx, &clear_meta) == -1) {
|
if (llama_recv_meta(ctx, &clear_meta, false) == -1) {
|
||||||
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
|
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
|
||||||
return {tokens, -1.0, {}, {}};
|
return {tokens, -1.0, {}, {}};
|
||||||
}
|
}
|
||||||
if (!is_last_dev) {
|
if (!is_last_dev) {
|
||||||
llama_send_meta(ctx, &clear_meta);
|
llama_send_meta(ctx, &clear_meta, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -648,11 +649,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
sync_meta meta;
|
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
for (int j = 0; j < num_batches; ++j) {
|
||||||
const int batch_start = start + j * n_batch;
|
const int batch_start = start + j * n_batch;
|
||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
|
// used for communication of the batch meta data
|
||||||
|
sync_meta meta;
|
||||||
|
|
||||||
int n_outputs = 0;
|
int n_outputs = 0;
|
||||||
|
|
||||||
|
@ -689,36 +690,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
// restore the original token in case it was set to BOS
|
// restore the original token in case it was set to BOS
|
||||||
tokens[seq_start] = token_org;
|
tokens[seq_start] = token_org;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (my_rank == 0) {
|
// comms: now rank 0 need to send the batch to other ranks
|
||||||
// Required batch info: Operation scale, KV cache location, Logits calculation location
|
|
||||||
meta.n_ctx = n_ctx;
|
|
||||||
meta.n_tokens = batch.n_tokens;
|
meta.n_tokens = batch.n_tokens;
|
||||||
meta.pos = batch.pos;
|
meta.pos = batch.pos;
|
||||||
|
meta.n_seq_id = batch.n_seq_id;
|
||||||
|
meta.seq_id = batch.seq_id;
|
||||||
meta.logits = batch.logits;
|
meta.logits = batch.logits;
|
||||||
meta.all_pos_0 = batch.all_pos_0;
|
|
||||||
meta.all_pos_1 = batch.all_pos_1;
|
|
||||||
meta.n_outputs = n_outputs;
|
meta.n_outputs = n_outputs;
|
||||||
meta.chunk_start_pos = start;
|
|
||||||
}
|
|
||||||
|
|
||||||
// other ranks need to know batch info
|
|
||||||
{
|
|
||||||
if (n_world > 1) {
|
if (n_world > 1) {
|
||||||
meta.n_ctx = n_ctx;
|
llama_send_meta(ctx, &meta, false); // reverse = false
|
||||||
|
}
|
||||||
if (my_rank == 0) {
|
|
||||||
llama_send_meta(ctx, &meta);
|
|
||||||
} else {
|
} else {
|
||||||
if (llama_recv_meta(ctx, &meta) == -1) {
|
if (n_world > 1) {
|
||||||
|
// comms: other ranks receive the batch meta data
|
||||||
|
if (llama_recv_meta(ctx, &meta, false) == -1) {
|
||||||
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
|
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
|
||||||
return {tokens, -1.0, {}, {}};
|
return {tokens, -1.0, {}, {}};
|
||||||
}
|
}
|
||||||
if (!is_last_dev) {
|
|
||||||
llama_send_meta(ctx, &meta);
|
// copy the batch meta data to the llama_batch
|
||||||
|
if (meta.n_tokens > 0) {
|
||||||
|
batch.n_tokens = meta.n_tokens;
|
||||||
|
if (meta.pos) { std::memcpy(batch.pos, meta.pos, meta.n_tokens * sizeof(llama_pos)); } // use n_tokens instead of n_batch, n_tokens is the actual number of tokens in the batch
|
||||||
|
if (meta.n_seq_id) { std::memcpy(batch.n_seq_id, meta.n_seq_id, meta.n_tokens * sizeof(int32_t)); }
|
||||||
|
if (meta.seq_id) {
|
||||||
|
const int32_t n_seq_max = 1;
|
||||||
|
for (int32_t i = 0; i < meta.n_tokens; ++i) {
|
||||||
|
std::memcpy(batch.seq_id[i], meta.seq_id[i], n_seq_max * sizeof(llama_seq_id));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (meta.logits) { std::memcpy(batch.logits, meta.logits, meta.n_tokens * sizeof(int8_t)); }
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -753,8 +758,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||||
|
|
||||||
int chunk_start_pos = meta.chunk_start_pos;
|
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
||||||
llama_token * tokens_data = tokens.data() + chunk_start_pos + seq*n_ctx + first;
|
|
||||||
if (!params.logits_file.empty()) {
|
if (!params.logits_file.empty()) {
|
||||||
process_logits(logits_stream, n_vocab, all_logits,
|
process_logits(logits_stream, n_vocab, all_logits,
|
||||||
tokens_data, n_ctx - 1 - first,
|
tokens_data, n_ctx - 1 - first,
|
||||||
|
@ -763,8 +767,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
process_logits(n_vocab, all_logits,
|
process_logits(n_vocab, all_logits,
|
||||||
tokens_data, n_ctx - 1 - first,
|
tokens_data, n_ctx - 1 - first,
|
||||||
workers, nll, nll2,
|
workers, nll, nll2,
|
||||||
logit_history.data() + chunk_start_pos + seq*n_ctx + first,
|
logit_history.data() + start + seq*n_ctx + first,
|
||||||
prob_history.data() + chunk_start_pos + seq*n_ctx + first);
|
prob_history.data() + start + seq*n_ctx + first);
|
||||||
}
|
}
|
||||||
count += n_ctx - first - 1;
|
count += n_ctx - first - 1;
|
||||||
|
|
||||||
|
@ -778,9 +782,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
logits.clear();
|
logits.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (n_world > 1) {
|
||||||
|
sync_meta done_meta;
|
||||||
|
done_meta.chunk_done = true;
|
||||||
|
|
||||||
|
if (is_last_dev) {
|
||||||
|
// Last device sends completion signal upstream (reverse direction)
|
||||||
|
LOG_INF("Rank %d: Sending chunk_done signal for chunk %d\n", my_rank, i);
|
||||||
|
llama_send_meta(ctx, &done_meta, true); // reverse = true
|
||||||
|
} else if (my_rank == 0) {
|
||||||
|
// Rank 0 waits for completion signal from downstream
|
||||||
|
LOG_INF("Rank 0: Waiting for chunk_done signal for chunk %d\n", i);
|
||||||
|
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
|
||||||
|
LOG_ERR("Failed to recv chunk_done signal on rank 0 for chunk %d\n", i);
|
||||||
|
return {tokens, -1.0, {}, {}};
|
||||||
|
}
|
||||||
|
LOG_INF("Rank 0: Received chunk_done signal for chunk %d\n", i);
|
||||||
|
} else {
|
||||||
|
// Intermediate ranks: receive from downstream, relay upstream
|
||||||
|
LOG_INF("Rank %d: Waiting for chunk_done signal for chunk %d\n", my_rank, i);
|
||||||
|
if (llama_recv_meta(ctx, &done_meta, true) == -1 || !done_meta.chunk_done) { // reverse = true
|
||||||
|
LOG_ERR("Failed to recv chunk_done signal on rank %d for chunk %d\n", my_rank, i);
|
||||||
|
return {tokens, -1.0, {}, {}};
|
||||||
|
}
|
||||||
|
LOG_INF("Rank %d: Relaying chunk_done signal for chunk %d\n", my_rank, i);
|
||||||
|
llama_send_meta(ctx, &done_meta, true); // reverse = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
|
|
||||||
if (is_last_dev) {
|
if (is_last_dev) {
|
||||||
|
@ -795,12 +827,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
} else {
|
} else {
|
||||||
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
|
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
|
||||||
}
|
}
|
||||||
|
llama_batch_free(batch);
|
||||||
return {tokens, ppl, logit_history, prob_history};
|
return {tokens, ppl, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2078,6 +2109,7 @@ int main(int argc, char ** argv) {
|
||||||
params.n_ctx = 512;
|
params.n_ctx = 512;
|
||||||
params.logits_all = true;
|
params.logits_all = true;
|
||||||
params.escape = false;
|
params.escape = false;
|
||||||
|
params.is_perplexity_eval = true;
|
||||||
|
|
||||||
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
|
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -2141,7 +2173,6 @@ int main(int argc, char ** argv) {
|
||||||
// load the model and apply lora adapter, if any
|
// load the model and apply lora adapter, if any
|
||||||
llama_init_result llama_init = llama_init_from_gpt_params(params);
|
llama_init_result llama_init = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
// update rank and world size if any devices removed
|
|
||||||
my_rank = params.rank;
|
my_rank = params.rank;
|
||||||
n_world = params.n_world;
|
n_world = params.n_world;
|
||||||
|
|
||||||
|
@ -2189,14 +2220,34 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
|
|
||||||
|
|
||||||
if (is_last_dev) {
|
if (is_last_dev) {
|
||||||
llama_perf_context_print(ctx);
|
|
||||||
write_logfile(ctx, params, model, results);
|
write_logfile(ctx, params, model, results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (my_rank == 0) {
|
||||||
|
llama_perf_context_print(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_world > 1) {
|
||||||
|
LOG_INF("Rank %d: Entering distributed shutdown protocol.\n", my_rank);
|
||||||
|
|
||||||
|
if (my_rank == 0) {
|
||||||
|
llama_free_sockets(ctx, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (my_rank != 0 && signal_thread.joinable()) {
|
||||||
|
signal_thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stop_signal) {
|
||||||
|
LOG_INF("Rank %d: Cleanup signal received: %s\n", my_rank, stop_signal);
|
||||||
|
delete[] stop_signal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
@ -58,12 +58,15 @@ struct sync_meta {
|
||||||
|
|
||||||
int8_t * logits = nullptr;
|
int8_t * logits = nullptr;
|
||||||
llama_pos * pos = nullptr;
|
llama_pos * pos = nullptr;
|
||||||
|
int32_t * n_seq_id = nullptr;
|
||||||
|
llama_seq_id ** seq_id = nullptr;
|
||||||
llama_pos all_pos_0;
|
llama_pos all_pos_0;
|
||||||
llama_pos all_pos_1;
|
llama_pos all_pos_1;
|
||||||
uint32_t n_ctx = 0;
|
uint32_t n_ctx = 0;
|
||||||
|
|
||||||
int chunk_start_pos;
|
// used for perplexity evaluation
|
||||||
int32_t n_outputs; // Used to pass the number of logits to be outputted
|
int32_t n_outputs;
|
||||||
|
bool chunk_done = false; // signal that the chunk is done
|
||||||
|
|
||||||
// signal to clear the kv cache
|
// signal to clear the kv cache
|
||||||
bool clear_kv_cache= false;
|
bool clear_kv_cache= false;
|
||||||
|
@ -389,6 +392,7 @@ extern "C" {
|
||||||
int32_t n_threads; // number of threads to use for generation
|
int32_t n_threads; // number of threads to use for generation
|
||||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||||
|
|
||||||
|
|
||||||
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||||
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
||||||
enum llama_attention_type attention_type; // attention type to use for embeddings
|
enum llama_attention_type attention_type; // attention type to use for embeddings
|
||||||
|
@ -422,6 +426,7 @@ extern "C" {
|
||||||
// currently works only with CPU execution
|
// currently works only with CPU execution
|
||||||
ggml_abort_callback abort_callback;
|
ggml_abort_callback abort_callback;
|
||||||
void * abort_callback_data;
|
void * abort_callback_data;
|
||||||
|
bool is_perplexity_eval; // whether to run in perplexity evaluation mode
|
||||||
};
|
};
|
||||||
|
|
||||||
// model quantization parameters
|
// model quantization parameters
|
||||||
|
@ -502,8 +507,8 @@ extern "C" {
|
||||||
|
|
||||||
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
|
LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
|
||||||
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
|
LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
|
||||||
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta);
|
LLAMA_API int llama_recv_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
|
||||||
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta);
|
LLAMA_API void llama_send_meta (struct llama_context * ctx, struct sync_meta * meta, bool reverse);
|
||||||
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
|
LLAMA_API int llama_gather_device_info(struct llama_context * ctx, struct device_info * dev_info_set);
|
||||||
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
|
LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
|
||||||
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
|
LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
|
||||||
|
|
321
src/llama.cpp
321
src/llama.cpp
|
@ -2607,6 +2607,8 @@ struct llama_cparams {
|
||||||
int n_threads; // number of threads to use for generation
|
int n_threads; // number of threads to use for generation
|
||||||
int n_threads_batch; // number of threads to use for batch processing
|
int n_threads_batch; // number of threads to use for batch processing
|
||||||
|
|
||||||
|
bool is_perplexity_eval;
|
||||||
|
|
||||||
float rope_freq_base;
|
float rope_freq_base;
|
||||||
float rope_freq_scale;
|
float rope_freq_scale;
|
||||||
|
|
||||||
|
@ -3490,6 +3492,7 @@ struct llama_context {
|
||||||
// sockets
|
// sockets
|
||||||
std::string master_ip = "localhost";
|
std::string master_ip = "localhost";
|
||||||
std::string next_node_ip = "localhost";
|
std::string next_node_ip = "localhost";
|
||||||
|
std::string prev_node_ip = "localhost";
|
||||||
uint32_t data_port = 9000;
|
uint32_t data_port = 9000;
|
||||||
uint32_t signal_port = 10000;
|
uint32_t signal_port = 10000;
|
||||||
zmq::context_t * sock_context = nullptr;
|
zmq::context_t * sock_context = nullptr;
|
||||||
|
@ -3497,6 +3500,9 @@ struct llama_context {
|
||||||
zmq::socket_t * recv_socket = nullptr;
|
zmq::socket_t * recv_socket = nullptr;
|
||||||
zmq::socket_t * master_socket = nullptr;
|
zmq::socket_t * master_socket = nullptr;
|
||||||
zmq::socket_t * signal_socket = nullptr;
|
zmq::socket_t * signal_socket = nullptr;
|
||||||
|
// Add these for reverse communication
|
||||||
|
zmq::socket_t * reverse_send_socket = nullptr; // Reverse: Rank i -> Rank i-1
|
||||||
|
zmq::socket_t * reverse_recv_socket = nullptr; // Reverse: Rank i <- Rank i+1
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_lora_weight {
|
struct llama_lora_weight {
|
||||||
|
@ -17866,24 +17872,86 @@ struct input_tensors {
|
||||||
ggml_tensor * inp_pos;
|
ggml_tensor * inp_pos;
|
||||||
};
|
};
|
||||||
|
|
||||||
void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
|
void llama_send_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
|
||||||
GGML_ASSERT(ctx != nullptr);
|
GGML_ASSERT(ctx != nullptr);
|
||||||
GGML_ASSERT(meta != nullptr);
|
GGML_ASSERT(meta != nullptr);
|
||||||
|
|
||||||
zmq::socket_t * send_socket = ctx->send_socket;
|
zmq::socket_t * send_socket = reverse ? ctx->reverse_send_socket : ctx->send_socket;
|
||||||
GGML_ASSERT(send_socket != nullptr);
|
GGML_ASSERT(send_socket != nullptr);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
std::vector<zmq::message_t> send_msgs;
|
std::vector<zmq::message_t> send_msgs;
|
||||||
|
|
||||||
GGML_ASSERT(meta->n_tokens != 0);
|
// Handle chunk_done signal
|
||||||
|
if (meta->chunk_done) {
|
||||||
|
send_msgs.emplace_back("chunk_done", strlen("chunk_done"));
|
||||||
|
send_msgs.emplace_back("1", 1);
|
||||||
|
zmq::send_multipart(*send_socket, send_msgs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->clear_kv_cache) {
|
||||||
|
send_msgs.emplace_back("clear_kv_cache", strlen("clear_kv_cache"));
|
||||||
|
send_msgs.emplace_back("1", 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->kv_seq_rm) {
|
||||||
|
send_msgs.emplace_back("kv_seq_rm", strlen("kv_seq_rm"));
|
||||||
|
send_msgs.emplace_back(&(meta->rm_seq_id), sizeof(meta->rm_seq_id));
|
||||||
|
send_msgs.emplace_back(&(meta->rm_p0), sizeof(meta->rm_p0));
|
||||||
|
send_msgs.emplace_back(&(meta->rm_p1), sizeof(meta->rm_p1));
|
||||||
|
zmq::send_multipart(*send_socket, send_msgs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->kv_seq_add) {
|
||||||
|
send_msgs.emplace_back("kv_seq_add", strlen("kv_seq_add"));
|
||||||
|
send_msgs.emplace_back(&(meta->add_seq_id), sizeof(meta->add_seq_id));
|
||||||
|
send_msgs.emplace_back(&(meta->add_p0), sizeof(meta->add_p0));
|
||||||
|
send_msgs.emplace_back(&(meta->add_p1), sizeof(meta->add_p1));
|
||||||
|
send_msgs.emplace_back(&(meta->add_delta), sizeof(meta->add_delta));
|
||||||
|
zmq::send_multipart(*send_socket, send_msgs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->kv_seq_cp) {
|
||||||
|
send_msgs.emplace_back("kv_seq_cp", strlen("kv_seq_cp"));
|
||||||
|
send_msgs.emplace_back(&(meta->cp_src_seq_id), sizeof(meta->cp_src_seq_id));
|
||||||
|
send_msgs.emplace_back(&(meta->cp_dst_seq_id), sizeof(meta->cp_dst_seq_id));
|
||||||
|
send_msgs.emplace_back(&(meta->cp_p0), sizeof(meta->cp_p0));
|
||||||
|
send_msgs.emplace_back(&(meta->cp_p1), sizeof(meta->cp_p1));
|
||||||
|
zmq::send_multipart(*send_socket, send_msgs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->kv_seq_div) {
|
||||||
|
send_msgs.emplace_back("kv_seq_div", strlen("kv_seq_div"));
|
||||||
|
send_msgs.emplace_back(&(meta->div_seq_id), sizeof(meta->div_seq_id));
|
||||||
|
send_msgs.emplace_back(&(meta->div_p0), sizeof(meta->div_p0));
|
||||||
|
send_msgs.emplace_back(&(meta->div_p1), sizeof(meta->div_p1));
|
||||||
|
send_msgs.emplace_back(&(meta->div_factor), sizeof(meta->div_factor));
|
||||||
|
zmq::send_multipart(*send_socket, send_msgs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->tokens_size > 0) {
|
||||||
|
send_msgs.emplace_back("tokens_size", strlen("tokens_size"));
|
||||||
|
send_msgs.emplace_back(&(meta->tokens_size), sizeof(meta->tokens_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta->n_tokens > 0) {
|
||||||
send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
|
send_msgs.emplace_back("n_tokens", strlen("n_tokens"));
|
||||||
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
|
send_msgs.emplace_back(&(meta->n_tokens), sizeof(meta->n_tokens));
|
||||||
|
|
||||||
if (meta->pos != nullptr) {
|
send_msgs.emplace_back("n_outputs", strlen("n_outputs"));
|
||||||
send_msgs.emplace_back("pos", strlen("pos"));
|
send_msgs.emplace_back(&(meta->n_outputs), sizeof(meta->n_outputs));
|
||||||
send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
|
|
||||||
}
|
// send_msgs.emplace_back("chunk_start_pos", strlen("chunk_start_pos"));
|
||||||
|
// send_msgs.emplace_back(&(meta->chunk_start_pos), sizeof(meta->chunk_start_pos));
|
||||||
|
|
||||||
|
send_msgs.emplace_back("n_ctx", strlen("n_ctx"));
|
||||||
|
send_msgs.emplace_back(&(meta->n_ctx), sizeof(meta->n_ctx));
|
||||||
|
|
||||||
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
|
send_msgs.emplace_back("all_pos_0", strlen("all_pos_0"));
|
||||||
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
|
send_msgs.emplace_back(&(meta->all_pos_0), sizeof(meta->all_pos_0));
|
||||||
|
@ -17891,6 +17959,40 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||||
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
send_msgs.emplace_back("all_pos_1", strlen("all_pos_1"));
|
||||||
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
send_msgs.emplace_back(&(meta->all_pos_1), sizeof(meta->all_pos_1));
|
||||||
|
|
||||||
|
// batch.pos
|
||||||
|
if (meta->pos != nullptr) {
|
||||||
|
send_msgs.emplace_back("pos", strlen("pos"));
|
||||||
|
send_msgs.emplace_back(meta->pos, meta->n_tokens * sizeof(llama_pos));
|
||||||
|
}
|
||||||
|
// batch.n_seq_id
|
||||||
|
if (meta->n_seq_id != nullptr) {
|
||||||
|
send_msgs.emplace_back("n_seq_id", strlen("n_seq_id"));
|
||||||
|
send_msgs.emplace_back(meta->n_seq_id, meta->n_tokens * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
// batch.seq_id
|
||||||
|
if (meta->seq_id != nullptr) {
|
||||||
|
const int32_t n_tokens = meta->n_tokens;
|
||||||
|
const int32_t n_seq_max = 1;
|
||||||
|
|
||||||
|
std::vector<llama_seq_id> flat_seq_ids;
|
||||||
|
flat_seq_ids.reserve(n_tokens * n_seq_max);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_tokens; ++i) {
|
||||||
|
for (int32_t j = 0; j < n_seq_max; ++j) {
|
||||||
|
flat_seq_ids.push_back(meta->seq_id[i][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
send_msgs.emplace_back("seq_id", strlen("seq_id"));
|
||||||
|
send_msgs.emplace_back(flat_seq_ids.data(), flat_seq_ids.size() * sizeof(llama_seq_id));
|
||||||
|
}
|
||||||
|
// batch.logits
|
||||||
|
if (meta->logits != nullptr) {
|
||||||
|
send_msgs.emplace_back("logits", strlen("logits"));
|
||||||
|
send_msgs.emplace_back(meta->logits, meta->n_tokens * sizeof(int8_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!send_msgs.empty()) {
|
if (!send_msgs.empty()) {
|
||||||
zmq::send_multipart(*send_socket, send_msgs);
|
zmq::send_multipart(*send_socket, send_msgs);
|
||||||
}
|
}
|
||||||
|
@ -17899,12 +18001,16 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
|
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta, bool reverse = false) {
|
||||||
ctx->recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
|
zmq::socket_t * recv_socket = reverse ? ctx->reverse_recv_socket : ctx->recv_socket;
|
||||||
|
GGML_ASSERT(recv_socket != nullptr);
|
||||||
|
|
||||||
|
recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
|
||||||
|
|
||||||
std::vector<zmq::message_t> recv_msgs;
|
std::vector<zmq::message_t> recv_msgs;
|
||||||
|
|
||||||
if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
|
if (!zmq::recv_multipart(*(ctx->recv_socket), std::back_inserter(recv_msgs))) {
|
||||||
|
recv_socket->set(zmq::sockopt::rcvtimeo, -1); // Reset timeout to blocking mode before returning error
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17913,6 +18019,12 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||||
const std::string cmd = recv_msgs[0].to_string();
|
const std::string cmd = recv_msgs[0].to_string();
|
||||||
size_t idx = 1;
|
size_t idx = 1;
|
||||||
|
|
||||||
|
// Handle chunk_done signal
|
||||||
|
if (cmd == "chunk_done") {
|
||||||
|
meta->chunk_done = true;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
|
if (cmd == "clear_kv_cache" && recv_msgs.size() == 1) {
|
||||||
meta->clear_kv_cache = true;
|
meta->clear_kv_cache = true;
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -17953,29 +18065,82 @@ int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (recv_msgs.size() % 2 != 0) {
|
||||||
|
LLAMA_LOG_ERROR("Invalid message format: odd number of messages\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
|
for (size_t i = 0; i < recv_msgs.size(); i += 2) {
|
||||||
|
if (i + 1 >= recv_msgs.size()) break;
|
||||||
|
|
||||||
std::string key = recv_msgs[i].to_string();
|
std::string key = recv_msgs[i].to_string();
|
||||||
zmq::message_t & data_msg = recv_msgs[i + 1];
|
zmq::message_t & data_msg = recv_msgs[i + 1];
|
||||||
|
|
||||||
if (key == "n_tokens") {
|
if (key == "tokens_size") {
|
||||||
|
GGML_ASSERT(data_msg.size() == sizeof(meta->tokens_size));
|
||||||
|
std::memcpy(&(meta->tokens_size), data_msg.data(), sizeof(meta->tokens_size));
|
||||||
|
}
|
||||||
|
else if (key == "n_tokens") {
|
||||||
GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
|
GGML_ASSERT(data_msg.size() == sizeof(meta->n_tokens));
|
||||||
std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
|
std::memcpy(&(meta->n_tokens), data_msg.data(), sizeof(meta->n_tokens));
|
||||||
}
|
}
|
||||||
|
else if (key == "n_outputs") {
|
||||||
if (key == "pos") {
|
GGML_ASSERT(data_msg.size() == sizeof(meta->n_outputs));
|
||||||
meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
|
std::memcpy(&(meta->n_outputs), data_msg.data(), sizeof(meta->n_outputs));
|
||||||
std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
|
|
||||||
}
|
}
|
||||||
|
// else if (key == "chunk_start_pos") {
|
||||||
if (key == "all_pos_0") {
|
// GGML_ASSERT(data_msg.size() == sizeof(meta->chunk_start_pos));
|
||||||
|
// std::memcpy(&(meta->chunk_start_pos), data_msg.data(), sizeof(meta->chunk_start_pos));
|
||||||
|
// }
|
||||||
|
else if (key == "n_ctx") {
|
||||||
|
GGML_ASSERT(data_msg.size() == sizeof(meta->n_ctx));
|
||||||
|
std::memcpy(&(meta->n_ctx), data_msg.data(), sizeof(meta->n_ctx));
|
||||||
|
}
|
||||||
|
else if (key == "all_pos_0") {
|
||||||
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
|
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_0));
|
||||||
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
|
std::memcpy(&(meta->all_pos_0), data_msg.data(), sizeof(meta->all_pos_0));
|
||||||
}
|
}
|
||||||
|
else if (key == "all_pos_1") {
|
||||||
if (key == "all_pos_1") {
|
|
||||||
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
|
GGML_ASSERT(data_msg.size() == sizeof(meta->all_pos_1));
|
||||||
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
|
std::memcpy(&(meta->all_pos_1), data_msg.data(), sizeof(meta->all_pos_1));
|
||||||
}
|
}
|
||||||
|
else if (key == "pos") {
|
||||||
|
if (meta->n_tokens > 0) {
|
||||||
|
meta->pos = (llama_pos *) malloc(meta->n_tokens * sizeof(llama_pos));
|
||||||
|
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(llama_pos));
|
||||||
|
std::memcpy(meta->pos, data_msg.data(), meta->n_tokens * sizeof(llama_pos));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (key == "n_seq_id") {
|
||||||
|
if (meta->n_tokens > 0) {
|
||||||
|
meta->n_seq_id = (int32_t *) malloc(data_msg.size());
|
||||||
|
std::memcpy(meta->n_seq_id, data_msg.data(), meta->n_tokens * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// batch.logits
|
||||||
|
else if (key == "seq_id") {
|
||||||
|
if (meta->n_tokens > 0) {
|
||||||
|
const int32_t n_tokens = meta->n_tokens;
|
||||||
|
const int32_t n_seq_max = 1;
|
||||||
|
|
||||||
|
GGML_ASSERT(data_msg.size() == (size_t)n_tokens * n_seq_max * sizeof(llama_seq_id));
|
||||||
|
|
||||||
|
meta->seq_id = (llama_seq_id **) malloc(n_tokens * sizeof(llama_seq_id *));
|
||||||
|
|
||||||
|
const llama_seq_id * flat_data = (const llama_seq_id *)data_msg.data();
|
||||||
|
for (int32_t token_idx = 0; token_idx < n_tokens; ++token_idx) {
|
||||||
|
meta->seq_id[token_idx] = (llama_seq_id *) malloc(n_seq_max * sizeof(llama_seq_id));
|
||||||
|
std::memcpy(meta->seq_id[token_idx], flat_data + token_idx * n_seq_max, n_seq_max * sizeof(llama_seq_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (key == "logits") {
|
||||||
|
if (meta->n_tokens > 0) {
|
||||||
|
GGML_ASSERT(data_msg.size() == meta->n_tokens * sizeof(int8_t));
|
||||||
|
meta->logits = (int8_t *) malloc(meta->n_tokens * sizeof(int8_t));
|
||||||
|
std::memcpy(meta->logits, data_msg.data(), meta->n_tokens * sizeof(int8_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -18191,6 +18356,8 @@ static int llama_decode_internal(
|
||||||
const uint32_t n_world = cparams.n_world;
|
const uint32_t n_world = cparams.n_world;
|
||||||
const uint32_t my_rank = cparams.rank;
|
const uint32_t my_rank = cparams.rank;
|
||||||
|
|
||||||
|
const bool is_perplexity_mode = cparams.is_perplexity_eval;
|
||||||
|
|
||||||
const uint32_t n_tokens_all = batch_all.n_tokens;
|
const uint32_t n_tokens_all = batch_all.n_tokens;
|
||||||
const int64_t n_embd = hparams.n_embd; // used for reserving embeddings space size
|
const int64_t n_embd = hparams.n_embd; // used for reserving embeddings space size
|
||||||
|
|
||||||
|
@ -18243,12 +18410,13 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare for send and receive of metadata
|
// prepare for send and receive of metadata
|
||||||
|
if (!is_perplexity_mode) {
|
||||||
sync_meta meta;
|
sync_meta meta;
|
||||||
meta.n_ctx = cparams.n_ctx;
|
meta.n_ctx = cparams.n_ctx;
|
||||||
bool is_last_dev = (my_rank == n_world - 1);
|
bool is_last_dev = (my_rank == n_world - 1);
|
||||||
|
|
||||||
if (my_rank != 0) {
|
if (my_rank != 0) {
|
||||||
if (llama_recv_meta(&lctx, &meta) == -1) {
|
if (llama_recv_meta(&lctx, &meta, false) == -1) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18303,12 +18471,14 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (!is_last_dev) {
|
if (!is_last_dev) {
|
||||||
meta.n_tokens = batch_all.n_tokens;
|
meta.n_tokens = batch_all.n_tokens;
|
||||||
meta.pos = batch_all.pos;
|
meta.pos = batch_all.pos;
|
||||||
meta.all_pos_0 = batch_all.all_pos_0;
|
meta.all_pos_0 = batch_all.all_pos_0;
|
||||||
meta.all_pos_1 = batch_all.all_pos_1;
|
meta.all_pos_1 = batch_all.all_pos_1;
|
||||||
llama_send_meta(&lctx, &meta);
|
llama_send_meta(&lctx, &meta, false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.sbatch.from_batch(batch_all, n_embd,
|
lctx.sbatch.from_batch(batch_all, n_embd,
|
||||||
|
@ -20281,6 +20451,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
/*.no_perf =*/ true,
|
/*.no_perf =*/ true,
|
||||||
/*.abort_callback =*/ nullptr,
|
/*.abort_callback =*/ nullptr,
|
||||||
/*.abort_callback_data =*/ nullptr,
|
/*.abort_callback_data =*/ nullptr,
|
||||||
|
/*.is_perplexity_mode =*/ false
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -20444,10 +20615,15 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx->sock_context = new zmq::context_t(2);
|
ctx->sock_context = new zmq::context_t(2);
|
||||||
|
|
||||||
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
|
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
|
||||||
ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
|
ctx->recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
|
||||||
ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
|
ctx->signal_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
|
||||||
|
|
||||||
|
// Reverse pipeline sockets (new - for barriers)
|
||||||
|
ctx->reverse_send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
|
||||||
|
ctx->reverse_recv_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::pull);
|
||||||
|
|
||||||
if (my_rank != 0 && my_rank != (n_world - 1)) {
|
if (my_rank != 0 && my_rank != (n_world - 1)) {
|
||||||
ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
|
ctx->master_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
|
||||||
} else if (my_rank == (n_world - 1)) {
|
} else if (my_rank == (n_world - 1)) {
|
||||||
|
@ -20455,18 +20631,38 @@ void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t m
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint32_t next_rank = (my_rank + 1) % n_world;
|
const uint32_t next_rank = (my_rank + 1) % n_world;
|
||||||
|
const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
|
||||||
|
|
||||||
std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port));
|
std::string recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port));
|
||||||
std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
|
std::string send_endp = "tcp://" + ctx->next_node_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
|
||||||
std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port));
|
std::string master_endp = "tcp://" + ctx->master_ip + ":" + std::to_string(map_rank_to_port(0, ctx->data_port));
|
||||||
std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port));
|
std::string signal_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->signal_port));
|
||||||
|
|
||||||
|
// Reverse pipeline endpoints (new)
|
||||||
|
// Use a different port offset for reverse communication to avoid conflicts
|
||||||
|
const uint32_t reverse_port_offset = 1000;
|
||||||
|
std::string reverse_recv_endp = "tcp://*:" + std::to_string(map_rank_to_port(my_rank, ctx->data_port + reverse_port_offset));
|
||||||
|
std::string reverse_send_endp = "tcp://" + ctx->prev_node_ip + ":" + std::to_string(map_rank_to_port(prev_rank, ctx->data_port + reverse_port_offset));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ctx->recv_socket->bind(recv_endp);
|
ctx->recv_socket->bind(recv_endp);
|
||||||
ctx->signal_socket->bind(signal_endp);
|
ctx->signal_socket->bind(signal_endp);
|
||||||
|
|
||||||
ctx->send_socket->connect(send_endp);
|
ctx->send_socket->connect(send_endp);
|
||||||
if (ctx->master_socket && my_rank != (n_world - 1)) {
|
if (ctx->master_socket && my_rank != (n_world - 1)) {
|
||||||
ctx->master_socket->connect(master_endp);
|
ctx->master_socket->connect(master_endp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup reverse pipeline sockets
|
||||||
|
if (my_rank > 0) {
|
||||||
|
// All ranks except rank 0 can send to previous rank
|
||||||
|
ctx->reverse_send_socket->connect(reverse_send_endp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (my_rank < n_world - 1) {
|
||||||
|
// All ranks except last rank can receive from next rank
|
||||||
|
ctx->reverse_recv_socket->bind(reverse_recv_endp);
|
||||||
|
}
|
||||||
} catch (const zmq::error_t &e) {
|
} catch (const zmq::error_t &e) {
|
||||||
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
|
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
|
||||||
exit(1);
|
exit(1);
|
||||||
|
@ -20740,6 +20936,7 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
|
||||||
const uint32_t my_rank = ctx->cparams.rank;
|
const uint32_t my_rank = ctx->cparams.rank;
|
||||||
// to adapt to the new topology, use old next_rank
|
// to adapt to the new topology, use old next_rank
|
||||||
const uint32_t next_rank = ctx->cparams.original_next_rank;
|
const uint32_t next_rank = ctx->cparams.original_next_rank;
|
||||||
|
const uint32_t prev_rank = (my_rank - 1 + n_world) % n_world;
|
||||||
|
|
||||||
if (n_world == 1) {
|
if (n_world == 1) {
|
||||||
return;
|
return;
|
||||||
|
@ -20763,6 +20960,89 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
|
||||||
*msg = new char[msg_str.size() + 1];
|
*msg = new char[msg_str.size() + 1];
|
||||||
std::strcpy(*msg, msg_str.c_str());
|
std::strcpy(*msg, msg_str.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send shutdown signal through reverse pipeline as well
|
||||||
|
if (my_rank == n_world - 1) {
|
||||||
|
// Last rank initiates reverse shutdown
|
||||||
|
try {
|
||||||
|
sync_meta shutdown_meta;
|
||||||
|
shutdown_meta.chunk_done = true; // Reuse chunk_done as shutdown signal
|
||||||
|
llama_send_meta(ctx, &shutdown_meta, true); // reverse = true
|
||||||
|
} catch (const zmq::error_t &e) {
|
||||||
|
LLAMA_LOG_INFO("Error sending reverse shutdown signal: %s", e.what());
|
||||||
|
}
|
||||||
|
} else if (my_rank > 0) {
|
||||||
|
// Intermediate ranks relay reverse shutdown signal
|
||||||
|
try {
|
||||||
|
sync_meta shutdown_meta;
|
||||||
|
// Set a short timeout for shutdown
|
||||||
|
ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, 500);
|
||||||
|
|
||||||
|
if (llama_recv_meta(ctx, &shutdown_meta, true) == 0) {
|
||||||
|
if (my_rank > 0) {
|
||||||
|
llama_send_meta(ctx, &shutdown_meta, true); // relay upstream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset timeout
|
||||||
|
ctx->reverse_recv_socket->set(zmq::sockopt::rcvtimeo, -1);
|
||||||
|
} catch (const zmq::error_t &e) {
|
||||||
|
LLAMA_LOG_INFO("Error handling reverse shutdown signal on rank %d: %s", my_rank, e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Close signal sender (local socket created in this function)
|
||||||
|
signal_sender.close();
|
||||||
|
|
||||||
|
// Close reverse sockets first
|
||||||
|
if (ctx->reverse_send_socket) {
|
||||||
|
ctx->reverse_send_socket->close();
|
||||||
|
delete ctx->reverse_send_socket;
|
||||||
|
ctx->reverse_send_socket = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->reverse_recv_socket) {
|
||||||
|
ctx->reverse_recv_socket->close();
|
||||||
|
delete ctx->reverse_recv_socket;
|
||||||
|
ctx->reverse_recv_socket = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close existing forward sockets
|
||||||
|
if (ctx->send_socket) {
|
||||||
|
ctx->send_socket->close();
|
||||||
|
delete ctx->send_socket;
|
||||||
|
ctx->send_socket = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->recv_socket) {
|
||||||
|
ctx->recv_socket->close();
|
||||||
|
delete ctx->recv_socket;
|
||||||
|
ctx->recv_socket = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->signal_socket) {
|
||||||
|
ctx->signal_socket->close();
|
||||||
|
delete ctx->signal_socket;
|
||||||
|
ctx->signal_socket = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle master_socket cleanup (be careful not to double-delete)
|
||||||
|
if (ctx->master_socket && my_rank != (n_world - 1) && ctx->master_socket != ctx->send_socket) {
|
||||||
|
ctx->master_socket->close();
|
||||||
|
delete ctx->master_socket;
|
||||||
|
ctx->master_socket = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup ZMQ context last
|
||||||
|
if (ctx->sock_context) {
|
||||||
|
delete ctx->sock_context;
|
||||||
|
ctx->sock_context = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (const zmq::error_t &e) {
|
||||||
|
LLAMA_LOG_INFO("Error cleaning up sockets: %s", e.what());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
|
void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
|
||||||
|
@ -20789,6 +21069,7 @@ struct llama_context * llama_new_context_with_model(
|
||||||
ctx->cparams.rank = params.rank;
|
ctx->cparams.rank = params.rank;
|
||||||
ctx->cparams.force = params.force;
|
ctx->cparams.force = params.force;
|
||||||
ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;
|
ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;
|
||||||
|
ctx->cparams.is_perplexity_eval = params.is_perplexity_eval;
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue