fix: handle socket closure and connection in llama_rebuild_topo

Signed-off-by: DeEMO <yzzxrx@gmail.com>
This commit is contained in:
DeEMO 2025-05-16 20:48:51 +08:00
parent 8b61cb2fa4
commit 34eaa8224d

View file

@ -20383,7 +20383,7 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
}
// check myself's layer
auto* socket_to_close = ctx->send_socket;
zmq::socket_t* socket_to_close = nullptr;
if(n_layer_window[my_rank] > 0) {
// reconstruct socket to the next valid rank
std::string next_ip;
@ -20397,20 +20397,25 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
current_rank = (current_rank + 1) % n_world;
}
if(!next_ip.empty()){
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
ctx->next_node_ip = next_ip;
ctx->cparams.original_next_rank = next_rank;
try {
if((my_rank+1)%n_world != next_rank){
socket_to_close = ctx->send_socket;
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
ctx->send_socket->connect(send_endp);
auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
if(!dev_info_set){
delete[] dev_info_ptr;
ctx->next_node_ip = next_ip;
ctx->cparams.original_next_rank = next_rank;
}
if(next_rank != 0){
try {
auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
if(!dev_info_set){
delete[] dev_info_ptr;
}
return -1;
}
return -1;
}
}else{
// only one node
@ -20420,8 +20425,10 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
if(!dev_info_set){
delete[] dev_info_ptr;
}
socket_to_close->close();
delete socket_to_close;
if(socket_to_close != nullptr){
socket_to_close->close();
delete socket_to_close;
}
return 0;
}