mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-06 10:09:04 +00:00
feat(comm): Enhance robustness of communication
This commit is contained in:
parent
50a916f123
commit
663ad2896d
2 changed files with 45 additions and 7 deletions
|
@ -632,18 +632,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
{
|
{
|
||||||
// synvhronize the KV cache clear signal across all ranks
|
// synchronize the KV cache clear signal across all ranks
|
||||||
if (n_world > 1) {
|
if (n_world > 1) {
|
||||||
sync_meta clear_meta;
|
sync_meta clear_meta;
|
||||||
clear_meta.clear_kv_cache = true;
|
clear_meta.clear_kv_cache = true;
|
||||||
|
|
||||||
if (my_rank == 0) {
|
if (my_rank == 0) {
|
||||||
|
// Rank 0 sends the signal
|
||||||
llama_send_meta(ctx, &clear_meta);
|
llama_send_meta(ctx, &clear_meta);
|
||||||
} else {
|
} else {
|
||||||
if (llama_recv_meta(ctx, &clear_meta) == -1) {
|
// Non-rank 0 devices receive the signal with retry mechanism
|
||||||
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank);
|
int retry_count = 0;
|
||||||
|
const int max_retries = 5;
|
||||||
|
bool recv_success = false;
|
||||||
|
|
||||||
|
while (retry_count < max_retries && !recv_success) {
|
||||||
|
if (llama_recv_meta(ctx, &clear_meta) == 0) {
|
||||||
|
recv_success = true;
|
||||||
|
} else {
|
||||||
|
retry_count++;
|
||||||
|
LOG_WRN("Failed to recv clear_kv_cache signal on rank %d, retry %d/%d\n",
|
||||||
|
my_rank, retry_count, max_retries);
|
||||||
|
// Small delay before retry
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!recv_success) {
|
||||||
|
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d after %d retries\n",
|
||||||
|
my_rank, max_retries);
|
||||||
return {tokens, -1.0, {}, {}};
|
return {tokens, -1.0, {}, {}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forward the signal to next rank if not the last device
|
||||||
if (!is_last_dev) {
|
if (!is_last_dev) {
|
||||||
llama_send_meta(ctx, &clear_meta);
|
llama_send_meta(ctx, &clear_meta);
|
||||||
}
|
}
|
||||||
|
@ -707,9 +728,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
// Chain forwarding pattern: consistent with tokens_size communication
|
// Chain forwarding pattern: consistent with tokens_size communication
|
||||||
if (n_world > 1) {
|
if (n_world > 1) {
|
||||||
if (my_rank != 0) {
|
if (my_rank != 0) {
|
||||||
// Non-rank 0 devices receive batch meta data
|
// Non-rank 0 devices receive batch meta data with retry mechanism
|
||||||
if (llama_recv_meta(ctx, &meta) == -1) {
|
int retry_count = 0;
|
||||||
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank);
|
const int max_retries = 3;
|
||||||
|
bool recv_success = false;
|
||||||
|
|
||||||
|
while (retry_count < max_retries && !recv_success) {
|
||||||
|
if (llama_recv_meta(ctx, &meta) == 0) {
|
||||||
|
recv_success = true;
|
||||||
|
} else {
|
||||||
|
retry_count++;
|
||||||
|
LOG_WRN("Failed to recv batch meta on rank %d, retry %d/%d\n",
|
||||||
|
my_rank, retry_count, max_retries);
|
||||||
|
// Small delay before retry
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!recv_success) {
|
||||||
|
LOG_ERR("Failed to recv batch meta on rank %d after %d retries\n",
|
||||||
|
my_rank, max_retries);
|
||||||
return {tokens, -1.0, {}, {}};
|
return {tokens, -1.0, {}, {}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18001,7 +18001,7 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||||
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
|
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
|
||||||
zmq::socket_t * recv_socket = ctx->recv_socket;
|
zmq::socket_t * recv_socket = ctx->recv_socket;
|
||||||
GGML_ASSERT(recv_socket != nullptr);
|
GGML_ASSERT(recv_socket != nullptr);
|
||||||
recv_socket->set(zmq::sockopt::rcvtimeo, 1000);
|
recv_socket->set(zmq::sockopt::rcvtimeo, 5000); // Increase timeout to 5 seconds
|
||||||
|
|
||||||
std::vector<zmq::message_t> recv_msgs;
|
std::vector<zmq::message_t> recv_msgs;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue