diff --git a/common/common.cpp b/common/common.cpp index 021f1640..991b34a5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1660,6 +1660,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { // get device profile LOG_INF("\nstart profiling this device, this may take some seconds ...\n"); dev_info.rank = params.rank; + dev_info.next_ip = params.next_node_ip.c_str(); if (n_world > 1) { llama_profile_device(&dev_info, model, ml, params.gpu_mem, params.n_predict, params.n_ctx, params.cpuparams.n_threads, params.flash_attn); } @@ -1682,6 +1683,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { return iparams; } llama_bcast_layer_setup(lctx, n_layer_window, n_gpu_layers); + + //rebuild topo + llama_rebuild_topo(lctx, n_layer_window, dev_info_set.data()); } else { // use the user-defined n_layer_window std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), n_layer_window); @@ -1690,8 +1694,12 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { } else { if (auto_schedule){ llama_send_device_info(lctx, &dev_info); + llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers); + // rebuild topo + llama_rebuild_topo(lctx,n_layer_window, nullptr); + }else{ + llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers); } - llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers); } // update n_layer_window and n_gpu_layers diff --git a/common/profiler.h b/common/profiler.h index a685ff8c..a3110299 100644 --- a/common/profiler.h +++ b/common/profiler.h @@ -320,6 +320,7 @@ struct device_info { uint32_t rank; const char * device_name; const char * device_os; + const char * next_ip; struct disk_props disk; struct cpu_props cpu_props; struct memory_info memory; @@ -333,6 +334,7 @@ struct device_info { rank(0), device_name(""), device_os(""), + next_ip(""), disk(), cpu_props(), memory(), diff --git a/include/llama.h b/include/llama.h index 9f3da708..515dfd93 100644 --- a/include/llama.h +++ b/include/llama.h @@ -455,6 +455,7 @@ extern "C" { LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info); LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args); LLAMA_API int llama_bcast_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers); + LLAMA_API int llama_rebuild_topo (struct llama_context * ctx, uint32_t * n_layer_window, struct device_info * dev_info_set); LLAMA_API int llama_recv_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers); LLAMA_API int llm_load_tensors( diff --git a/src/llama.cpp b/src/llama.cpp index 87ae83ac..8cc37213 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20329,6 +20329,92 @@ int llama_bcast_layer_setup(struct llama_context * ctx, uint32_t * n_layer_windo return 0; } +LLAMA_API int llama_rebuild_topo(llama_context *ctx, + uint32_t *n_layer_window, + device_info *dev_info_set) { + uint32_t n_world = ctx->cparams.n_world; + uint32_t my_rank = ctx->cparams.rank; + std::vector msgs; + device_info* dev_info_ptr = nullptr; + if (dev_info_set == nullptr){ + // for rank!=0, recv all devices info + if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) { + return -1; + } + dev_info_ptr = new device_info[n_world]; + for (size_t i = 0; i < msgs.size(); i++) { + deserialize((const char *)msgs[i].data(), &dev_info_set[i]); + } + }else{ + char * buffer = nullptr; + for(size_t i = 0; i < n_world; i++) { + size_t buffer_size = serialize(&dev_info_set[i], &buffer); + msgs.emplace_back(buffer, buffer_size); + + free(buffer); + } + dev_info_ptr = dev_info_set; + } + + GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr); + GGML_ASSERT(msgs.size() == n_world); + + // notify next rank + auto next_rank = (my_rank + 1) % n_world; + if(n_layer_window[next_rank] <= 0){ + try { + ctx->send_socket->setsockopt(ZMQ_LINGER, 3500); + zmq::send_multipart(*ctx->send_socket, msgs); + } catch (const zmq::error_t& e) { + LLAMA_LOG_INFO("Failed to send data: %s\n", e.what()); + if(!dev_info_set){ + delete[] dev_info_ptr; + } + return -1; + } + } + + // check myself's layer + auto* socket_to_close = ctx->send_socket; + if(n_layer_window[my_rank] > 0) { + // reconstruct socket to the next valid rank + std::string next_ip; + auto current_rank = my_rank; + while(next_rank!=my_rank){ + if(n_layer_window[next_rank] > 0){ + next_ip = dev_info_ptr[next_rank].next_ip; + break; + } + next_rank = (next_rank + 1) % n_world; + current_rank = (current_rank + 1) % n_world; + } + if(!next_ip.empty()){ + ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push); + std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port)); + ctx->next_node_ip = next_ip; + try { + ctx->send_socket->connect(send_endp); + zmq::send_multipart(*ctx->send_socket, msgs); + } catch (const zmq::error_t &e) { + LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what()); + if(!dev_info_set){ + delete[] dev_info_ptr; + } + return -1; + } + } + } + if(!dev_info_set){ + delete[] dev_info_ptr; + } + socket_to_close->close(); + delete socket_to_close; + if(n_layer_window[my_rank]<=0){ + exit(0); + } + return true; +} + int llama_recv_layer_setup(struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers) { uint32_t n_world = ctx->cparams.n_world; uint32_t my_rank = ctx->cparams.rank;