topo rebuild: add a 200 ms delay before llama_rebuild_topo to avoid packet interleaving

This commit is contained in:
Li, Zonghang 2025-06-26 14:47:34 +04:00
parent 50807fd4e1
commit 729870fcd7
3 changed files with 40 additions and 21 deletions

View file

@ -1788,6 +1788,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
return iparams; return iparams;
} }
llama_bcast_layer_setup(lctx, n_layer_window, n_gpu_layers); llama_bcast_layer_setup(lctx, n_layer_window, n_gpu_layers);
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // add a delay to avoid packet interleaving
llama_rebuild_topo(lctx, n_layer_window, dev_info_set.data(), &node_type, is_forwarder); llama_rebuild_topo(lctx, n_layer_window, dev_info_set.data(), &node_type, is_forwarder);
} else { } else {
// use the user-defined n_layer_window // use the user-defined n_layer_window
@ -1798,6 +1799,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
if (auto_schedule){ if (auto_schedule){
llama_send_device_info(lctx, &dev_info); llama_send_device_info(lctx, &dev_info);
llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers); llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // add a delay to avoid packet interleaving
llama_rebuild_topo (lctx, n_layer_window, nullptr, &node_type, is_forwarder); llama_rebuild_topo (lctx, n_layer_window, nullptr, &node_type, is_forwarder);
} else { } else {
llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers); llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);

View file

@ -2837,7 +2837,7 @@ size_t deserialize(const char * buffer, struct device_info * dev_info) {
return ptr - buffer; return ptr - buffer;
} }
void TopoRebuildHelperInfo::deserialize(const char *buffer) { void TopoRebuildHelperInfo::deserialize(const char * buffer) {
size_t buffer_size = ::deserialize(buffer, &dev_info); size_t buffer_size = ::deserialize(buffer, &dev_info);
if (buffer_size == 0) { if (buffer_size == 0) {
LOG_ERR("%s: failed to deserialize device info\n", __func__); LOG_ERR("%s: failed to deserialize device info\n", __func__);
@ -2846,14 +2846,16 @@ void TopoRebuildHelperInfo::deserialize(const char *buffer) {
memcpy(&is_forwarder, buffer + buffer_size, 1); memcpy(&is_forwarder, buffer + buffer_size, 1);
} }
size_t TopoRebuildHelperInfo::serialize(char **buffer) const{ size_t TopoRebuildHelperInfo::serialize(char ** buffer) const{
size_t buffer_size = ::serialize(&dev_info, buffer); size_t buffer_size = ::serialize(&dev_info, buffer);
char* buffer_ = (char*)malloc(buffer_size+1); char * buffer_ = (char *)malloc(buffer_size + 1);
if (buffer_ == NULL) { if (buffer_ == NULL) {
LOG_ERR("%s: failed to allocate %zu bytes for device info serialization\n", LOG_ERR("%s: failed to allocate %zu bytes for device info serialization\n",
__func__, buffer_size); __func__, buffer_size);
return 0; return 0;
} }
memcpy(buffer_, *buffer, buffer_size); memcpy(buffer_, *buffer, buffer_size);
memcpy(buffer_ + buffer_size, &is_forwarder, 1); memcpy(buffer_ + buffer_size, &is_forwarder, 1);
free(*buffer); free(*buffer);

View file

@ -3633,6 +3633,10 @@ void llama_profile_device(
dev_info->memory.total_physical = round(device_physical_memory(false) / (double)(1 << 30) * 100) / 100; dev_info->memory.total_physical = round(device_physical_memory(false) / (double)(1 << 30) * 100) / 100;
dev_info->memory.available_physical = round(device_physical_memory(true) / (double)(1 << 30) * 100) / 100; dev_info->memory.available_physical = round(device_physical_memory(true) / (double)(1 << 30) * 100) / 100;
GGML_ASSERT(dev_info->memory.total_physical > 0, "Failed to parse total physical memory\n");
GGML_ASSERT(dev_info->memory.available_physical > 0, "Failed to parse available physical memory\n");
dev_info->memory.used_can_swap = round(device_swappable_memory() / (double)(1 << 30) * 100) / 100; dev_info->memory.used_can_swap = round(device_swappable_memory() / (double)(1 << 30) * 100) / 100;
dev_info->memory.total_swap = round(device_swap_memory(false) / (double)(1 << 30) * 100) / 100; dev_info->memory.total_swap = round(device_swap_memory(false) / (double)(1 << 30) * 100) / 100;
dev_info->memory.available_swap = round(device_swap_memory(true) / (double)(1 << 30) * 100) / 100; dev_info->memory.available_swap = round(device_swap_memory(true) / (double)(1 << 30) * 100) / 100;
@ -20658,8 +20662,8 @@ int llama_rebuild_topo(llama_context * ctx,
uint32_t my_rank = ctx->cparams.rank; uint32_t my_rank = ctx->cparams.rank;
TopoRebuildHelperInfo* topo_helper = new TopoRebuildHelperInfo[n_world]; TopoRebuildHelperInfo* topo_helper = new TopoRebuildHelperInfo[n_world];
if (dev_info_set == nullptr){ if (dev_info_set == nullptr) {
// for rank!=0, recv all devices info // for rank != 0, recv all devices info
std::vector<zmq::message_t> msgs; std::vector<zmq::message_t> msgs;
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) { if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
return -1; return -1;
@ -20671,7 +20675,7 @@ int llama_rebuild_topo(llama_context * ctx,
} else { } else {
for (size_t i = 0; i < n_world; i++) { for (size_t i = 0; i < n_world; i++) {
topo_helper[i].dev_info = dev_info_set[i]; topo_helper[i].dev_info = dev_info_set[i];
topo_helper[i].is_forwarder = 0; topo_helper[i].is_forwarder = 0;
} }
} }
@ -20679,29 +20683,32 @@ int llama_rebuild_topo(llama_context * ctx,
auto next_rank = (my_rank + 1) % n_world; auto next_rank = (my_rank + 1) % n_world;
auto next_connect_rank = (my_rank + 1) % n_world; auto next_connect_rank = (my_rank + 1) % n_world;
zmq::socket_t* socket_to_close = nullptr; zmq::socket_t * socket_to_close = nullptr;
bool is_not_exit = n_layer_window[my_rank] > 0 || topo_helper[my_rank].is_forwarder == 1; bool is_not_exit = n_layer_window[my_rank] > 0 || topo_helper[my_rank].is_forwarder == 1;
if (is_not_exit) { if (is_not_exit) {
// reconstruct socket to the next valid rank // reconstruct socket to the next valid rank
auto current_rank = my_rank; auto current_rank = my_rank;
std::vector<uint32_t> nodes; std::vector<uint32_t> nodes;
auto next_rank_ = next_rank; auto next_rank_ = next_rank;
while (next_rank_ != my_rank) { while (next_rank_ != my_rank) {
nodes.push_back(next_rank_); nodes.push_back(next_rank_);
if (n_layer_window[next_rank_] > 0) { if (n_layer_window[next_rank_] > 0) {
break; break;
} }
next_rank_ = (next_rank_ + 1) % n_world; next_rank_ = (next_rank_ + 1) % n_world;
current_rank = (current_rank + 1) % n_world; current_rank = (current_rank + 1) % n_world;
} }
if (next_rank_ == my_rank) { if (next_rank_ == my_rank) {
// only one node // only one node
ctx->next_node_ip = ""; ctx->next_node_ip = "";
socket_to_close = ctx->send_socket; socket_to_close = ctx->send_socket;
ctx->send_socket = nullptr; ctx->send_socket = nullptr;
} else { } else {
// iterate node reverse // iterate node reverse
zmq::socket_t* socket = nullptr; zmq::socket_t * socket = nullptr;
std::string ip; std::string ip;
for (int i = nodes.size() - 1; i > 0; --i) { for (int i = nodes.size() - 1; i > 0; --i) {
auto rank = nodes[i]; auto rank = nodes[i];
@ -20716,13 +20723,13 @@ int llama_rebuild_topo(llama_context * ctx,
// reset socket // reset socket
GGML_ASSERT(socket != nullptr); GGML_ASSERT(socket != nullptr);
GGML_ASSERT(!ip.empty()); GGML_ASSERT(!ip.empty());
socket_to_close = ctx->send_socket; socket_to_close = ctx->send_socket;
ctx->send_socket = socket; ctx->send_socket = socket;
ctx->next_node_ip = ip; ctx->next_node_ip = ip;
ctx->cparams.original_next_rank = next_connect_rank; ctx->cparams.original_next_rank = next_connect_rank;
} }
} }
}else if (n_layer_window[next_rank] <= 0 && topo_helper[next_rank].is_forwarder == 0) { } else if (n_layer_window[next_rank] <= 0 && topo_helper[next_rank].is_forwarder == 0) {
socket_to_close = ctx->send_socket; socket_to_close = ctx->send_socket;
} }
@ -20733,7 +20740,7 @@ int llama_rebuild_topo(llama_context * ctx,
auto msgs = topohelper_to_messages(topo_helper, n_world); auto msgs = topohelper_to_messages(topo_helper, n_world);
socket_to_close->set(zmq::sockopt::linger, 3500); socket_to_close->set(zmq::sockopt::linger, 3500);
zmq::send_multipart(*socket_to_close, msgs); zmq::send_multipart(*socket_to_close, msgs);
} catch (const zmq::error_t& e) { } catch (const zmq::error_t & e) {
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what()); LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
return -1; return -1;
} }
@ -20745,7 +20752,7 @@ int llama_rebuild_topo(llama_context * ctx,
try { try {
auto msgs = topohelper_to_messages(topo_helper, n_world); auto msgs = topohelper_to_messages(topo_helper, n_world);
zmq::send_multipart(*ctx->send_socket, msgs); zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t& e) { } catch (const zmq::error_t & e) {
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what()); LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
return -1; return -1;
} }
@ -20770,18 +20777,20 @@ int llama_rebuild_topo(llama_context * ctx,
topo_helper[i].deserialize((char *)msgs[i].data()); topo_helper[i].deserialize((char *)msgs[i].data());
} }
// broadcast the whole view // broadcast the whole view
if (next_connect_rank!=0) { if (next_connect_rank != 0) {
try { try {
zmq::send_multipart(*ctx->send_socket, msgs); zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t& e) { } catch (const zmq::error_t & e) {
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what()); LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
return -1; return -1;
} }
} }
} }
for (size_t i = 0; i < n_world; i++) { for (size_t i = 0; i < n_world; i++) {
is_forwarder[i] = topo_helper[i].is_forwarder; is_forwarder[i] = topo_helper[i].is_forwarder;
} }
ctx->cparams.node_type = *node_type; ctx->cparams.node_type = *node_type;
if (socket_to_close != nullptr) { if (socket_to_close != nullptr) {
@ -20816,8 +20825,14 @@ int llama_recv_layer_setup(struct llama_context * ctx, uint32_t * n_layer_window
uint32_t my_rank = ctx->cparams.rank; uint32_t my_rank = ctx->cparams.rank;
std::vector<zmq::message_t> recv_msgs; std::vector<zmq::message_t> recv_msgs;
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(recv_msgs))) { while (true) {
return -1; recv_msgs.clear();
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(recv_msgs))) {
return -1;
}
if (!recv_msgs.empty() && recv_msgs[0].to_string() == "n_layer_window") {
break;
}
} }
GGML_ASSERT(recv_msgs[0].to_string() == "n_layer_window"); GGML_ASSERT(recv_msgs[0].to_string() == "n_layer_window");
@ -20827,7 +20842,7 @@ int llama_recv_layer_setup(struct llama_context * ctx, uint32_t * n_layer_window
if (recv_msgs.size() > 2) { if (recv_msgs.size() > 2) {
GGML_ASSERT(recv_msgs[2].to_string() == "n_gpu_layers"); GGML_ASSERT(recv_msgs[2].to_string() == "n_gpu_layers");
GGML_ASSERT(recv_msgs[3].size() == sizeof(uint32_t) * 32); GGML_ASSERT(recv_msgs[3].size() == sizeof(uint32_t) * 32);
memcpy(n_gpu_layers, recv_msgs[3].data(), sizeof(uint32_t) * 32); memcpy(n_gpu_layers, recv_msgs[3].data(), sizeof(uint32_t) * 32);
} }
if (my_rank != n_world - 1) { if (my_rank != n_world - 1) {