add topo rebuild

Signed-off-by: DeEMO <yzzxrx@gmail.com>
This commit is contained in:
DeEMO 2025-05-15 04:22:12 +00:00
parent 26bb86c09b
commit fdd6694633
4 changed files with 98 additions and 1 deletions

View file

@ -20329,6 +20329,92 @@ int llama_bcast_layer_setup(struct llama_context * ctx, uint32_t * n_layer_windo
return 0;
}
LLAMA_API int llama_rebuild_topo(llama_context *ctx,
uint32_t *n_layer_window,
device_info *dev_info_set) {
uint32_t n_world = ctx->cparams.n_world;
uint32_t my_rank = ctx->cparams.rank;
std::vector<zmq::message_t> msgs;
device_info* dev_info_ptr = nullptr;
if (dev_info_set == nullptr){
// for rank!=0, recv all devices info
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
return -1;
}
dev_info_ptr = new device_info[n_world];
for (size_t i = 0; i < msgs.size(); i++) {
deserialize((const char *)msgs[i].data(), &dev_info_set[i]);
}
}else{
char * buffer = nullptr;
for(size_t i = 0; i < n_world; i++) {
size_t buffer_size = serialize(&dev_info_set[i], &buffer);
msgs.emplace_back(buffer, buffer_size);
free(buffer);
}
dev_info_ptr = dev_info_set;
}
GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
GGML_ASSERT(msgs.size() == n_world);
// notify next rank
auto next_rank = (my_rank + 1) % n_world;
if(n_layer_window[next_rank] <= 0){
try {
ctx->send_socket->setsockopt(ZMQ_LINGER, 3500);
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t& e) {
LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
if(!dev_info_set){
delete[] dev_info_ptr;
}
return -1;
}
}
// check myself's layer
auto* socket_to_close = ctx->send_socket;
if(n_layer_window[my_rank] > 0) {
// reconstruct socket to the next valid rank
std::string next_ip;
auto current_rank = my_rank;
while(next_rank!=my_rank){
if(n_layer_window[next_rank] > 0){
next_ip = dev_info_ptr[next_rank].next_ip;
break;
}
next_rank = (next_rank + 1) % n_world;
current_rank = (current_rank + 1) % n_world;
}
if(!next_ip.empty()){
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
ctx->next_node_ip = next_ip;
try {
ctx->send_socket->connect(send_endp);
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
if(!dev_info_set){
delete[] dev_info_ptr;
}
return -1;
}
}
}
if(!dev_info_set){
delete[] dev_info_ptr;
}
socket_to_close->close();
delete socket_to_close;
if(n_layer_window[my_rank]<=0){
exit(0);
}
return true;
}
int llama_recv_layer_setup(struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers) {
uint32_t n_world = ctx->cparams.n_world;
uint32_t my_rank = ctx->cparams.rank;