feat(comm): Enhance robustness of communication

This commit is contained in:
leeetao 2025-07-19 07:57:57 +00:00
parent 50a916f123
commit 663ad2896d
2 changed files with 45 additions and 7 deletions

View file

@ -632,18 +632,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
{ {
// synvhronize the KV cache clear signal across all ranks // synchronize the KV cache clear signal across all ranks
if (n_world > 1) { if (n_world > 1) {
sync_meta clear_meta; sync_meta clear_meta;
clear_meta.clear_kv_cache = true; clear_meta.clear_kv_cache = true;
if (my_rank == 0) { if (my_rank == 0) {
// Rank 0 sends the signal
llama_send_meta(ctx, &clear_meta); llama_send_meta(ctx, &clear_meta);
} else { } else {
if (llama_recv_meta(ctx, &clear_meta) == -1) { // Non-rank 0 devices receive the signal with retry mechanism
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d\n", my_rank); int retry_count = 0;
const int max_retries = 5;
bool recv_success = false;
while (retry_count < max_retries && !recv_success) {
if (llama_recv_meta(ctx, &clear_meta) == 0) {
recv_success = true;
} else {
retry_count++;
LOG_WRN("Failed to recv clear_kv_cache signal on rank %d, retry %d/%d\n",
my_rank, retry_count, max_retries);
// Small delay before retry
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
}
if (!recv_success) {
LOG_ERR("Failed to recv clear_kv_cache signal on rank %d after %d retries\n",
my_rank, max_retries);
return {tokens, -1.0, {}, {}}; return {tokens, -1.0, {}, {}};
} }
// Forward the signal to next rank if not the last device
if (!is_last_dev) { if (!is_last_dev) {
llama_send_meta(ctx, &clear_meta); llama_send_meta(ctx, &clear_meta);
} }
@ -707,9 +728,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Chain forwarding pattern: consistent with tokens_size communication // Chain forwarding pattern: consistent with tokens_size communication
if (n_world > 1) { if (n_world > 1) {
if (my_rank != 0) { if (my_rank != 0) {
// Non-rank 0 devices receive batch meta data // Non-rank 0 devices receive batch meta data with retry mechanism
if (llama_recv_meta(ctx, &meta) == -1) { int retry_count = 0;
LOG_ERR("Failed to recv batch meta on rank %d\n", my_rank); const int max_retries = 3;
bool recv_success = false;
while (retry_count < max_retries && !recv_success) {
if (llama_recv_meta(ctx, &meta) == 0) {
recv_success = true;
} else {
retry_count++;
LOG_WRN("Failed to recv batch meta on rank %d, retry %d/%d\n",
my_rank, retry_count, max_retries);
// Small delay before retry
std::this_thread::sleep_for(std::chrono::milliseconds(50));
}
}
if (!recv_success) {
LOG_ERR("Failed to recv batch meta on rank %d after %d retries\n",
my_rank, max_retries);
return {tokens, -1.0, {}, {}}; return {tokens, -1.0, {}, {}};
} }

View file

@ -18001,7 +18001,7 @@ void llama_send_meta(llama_context * ctx, struct sync_meta * meta) {
int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) { int llama_recv_meta(llama_context * ctx, struct sync_meta * meta) {
zmq::socket_t * recv_socket = ctx->recv_socket; zmq::socket_t * recv_socket = ctx->recv_socket;
GGML_ASSERT(recv_socket != nullptr); GGML_ASSERT(recv_socket != nullptr);
recv_socket->set(zmq::sockopt::rcvtimeo, 1000); recv_socket->set(zmq::sockopt::rcvtimeo, 5000); // Increase timeout to 5 seconds
std::vector<zmq::message_t> recv_msgs; std::vector<zmq::message_t> recv_msgs;