diff --git a/src/llama.cpp b/src/llama.cpp index f083e8e5..121a00b6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20383,7 +20383,7 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx, } // check myself's layer - auto* socket_to_close = ctx->send_socket; + zmq::socket_t* socket_to_close = nullptr; if(n_layer_window[my_rank] > 0) { // reconstruct socket to the next valid rank std::string next_ip; @@ -20397,20 +20397,25 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx, current_rank = (current_rank + 1) % n_world; } if(!next_ip.empty()){ - ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); - std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port)); - ctx->next_node_ip = next_ip; - ctx->cparams.original_next_rank = next_rank; - try { + if((my_rank+1)%n_world != next_rank){ + socket_to_close = ctx->send_socket; + ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); + std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port)); ctx->send_socket->connect(send_endp); - auto msgs = dev_infos_to_messages(dev_info_ptr, n_world); - zmq::send_multipart(*ctx->send_socket, msgs); - } catch (const zmq::error_t &e) { - LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what()); - if(!dev_info_set){ - delete[] dev_info_ptr; + ctx->next_node_ip = next_ip; + ctx->cparams.original_next_rank = next_rank; + } + if(next_rank != 0){ + try { + auto msgs = dev_infos_to_messages(dev_info_ptr, n_world); + zmq::send_multipart(*ctx->send_socket, msgs); + } catch (const zmq::error_t &e) { + LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what()); + if(!dev_info_set){ + delete[] dev_info_ptr; + } + return -1; } - return -1; } }else{ // only one node @@ -20420,8 +20425,10 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx, if(!dev_info_set){ delete[] dev_info_ptr; } - socket_to_close->close(); - delete socket_to_close; + if(socket_to_close != nullptr){ + socket_to_close->close(); + delete socket_to_close; + } return 0; }