diff --git a/common/common.cpp b/common/common.cpp
index dff98506..18e0804f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1717,6 +1717,8 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     }
 
     // sychronize device profile to the master node
+    NodeType node_type = NodeType::NODE_TYPE_WORKER; // default for ranks that never run llama_rebuild_topo
+    char is_fowarder[32] = {0};
     if (my_rank == 0) {
         if (auto_schedule) {
             std::vector<device_info> dev_info_set(n_world);
@@ -1743,14 +1745,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (auto_schedule){
             llama_send_device_info(lctx, &dev_info);
             llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
-            llama_rebuild_topo    (lctx, n_layer_window, nullptr);
+            llama_rebuild_topo    (lctx, n_layer_window, nullptr, &node_type, is_fowarder);
         } else {
             llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
         }
     }
 
     // if this is a weak device, then exit
-    if (n_layer_window[my_rank] <= 0) {
+    if (node_type == NodeType::NODE_TYPE_EXIT) {
         LOG_INF("No layer is assigned to me, exit.\n");
         llama_free(lctx);
         llama_free_model(model);
@@ -1762,7 +1764,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     std::vector<uint32_t> n_layer_window_temp = {n_layer_window[0]}, n_gpu_layers_temp = {n_gpu_layers[0]};
 
     for (uint32_t i = 1; i < n_world; i++) {
-        if (n_layer_window[i] <= 0) {
+        if (n_layer_window[i] <= 0 && is_fowarder[i] == 0) {
             continue;
         }
         if (i <= my_rank) {
@@ -1794,7 +1796,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     n_world = update_n_world;
     llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
-
+
+    if (node_type == NodeType::NODE_TYPE_FOWARDER) {
+        // just forward messages between neighbors; forwarder nodes never return from here
+        while (true) {
+            llama_foward_messages(lctx);
+        }
+    }
+
     // update n_layer_window and n_gpu_layers
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), cparams.n_layer_window);
diff --git a/common/profiler.cpp b/common/profiler.cpp
index 0da13824..fa1a56b5 100644
--- a/common/profiler.cpp
+++ b/common/profiler.cpp
@@ -2621,7 +2621,7 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
     return total_size;
 }
 
-void deserialize(const char * buffer, struct device_info * dev_info) {
+size_t deserialize(const char * buffer, struct device_info * dev_info) {
     const char * ptr = buffer;
 
     // rank
@@ -2821,6 +2821,32 @@ void deserialize(const char * buffer, struct device_info * dev_info) {
     ptr += sizeof(float);
 
     memcpy(&dev_info->gpu_props.cuda_mem_cpy_delay, ptr, sizeof(float));
+    ptr += sizeof(float);
 
     // no need to synchronize model flops and model params
-}
\ No newline at end of file
+    return ptr - buffer;
+}
+
+void TopoRebuildHelperInfo::deserialize(const char * buffer) {
+    size_t buffer_size = ::deserialize(buffer, &dev_info);
+    if (buffer_size == 0) {
+        LOG_ERR("%s: failed to deserialize device info\n", __func__);
+        return;
+    }
+    memcpy(&is_fowarder, buffer + buffer_size, 1);
+}
+
+size_t TopoRebuildHelperInfo::serialize(char ** buffer) const {
+    size_t buffer_size = ::serialize(&dev_info, buffer);
+    char * buffer_ = (char *) malloc(buffer_size + 1);
+    if (buffer_ == NULL) {
+        LOG_ERR("%s: failed to allocate %zu bytes for device info serialization\n",
+            __func__, buffer_size);
+        return 0;
+    }
+    memcpy(buffer_, *buffer, buffer_size);
+    memcpy(buffer_ + buffer_size, &is_fowarder, 1);
+    free(*buffer);
+    *buffer = buffer_;
+    return buffer_size + 1;
+}
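The helper methods above piggyback a single `is_fowarder` flag byte onto the end of the serialized `device_info` payload, which is why `serialize` reallocates a buffer one byte larger and `deserialize` reads the flag at offset `buffer_size`. Below is a minimal standalone sketch of that framing convention; `Payload`, `pack`, and `unpack` are illustrative stand-ins, not names from this patch.

```cpp
// Sketch: append one trailing flag byte to a fixed-size payload, then
// recover both on the receiving side (mirrors TopoRebuildHelperInfo's
// serialize/deserialize). Payload/pack/unpack are hypothetical names.
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct Payload { int rank; float bandwidth; };

static size_t pack(const Payload & p, char flag, char ** out) {
    size_t n = sizeof(Payload);
    char * buf = (char *) malloc(n + 1);   // payload + 1 flag byte
    if (buf == NULL) return 0;
    memcpy(buf, &p, n);
    buf[n] = flag;                         // trailing flag, as with is_fowarder
    *out = buf;
    return n + 1;
}

static void unpack(const char * buf, Payload & p, char & flag) {
    memcpy(&p, buf, sizeof(Payload));      // body first ...
    flag = buf[sizeof(Payload)];           // ... then the flag byte after it
}

int main() {
    Payload p = {2, 3.5f};
    char * buf = nullptr;
    size_t n = pack(p, 1, &buf);
    Payload q; char flag = 0;
    unpack(buf, q, flag);
    printf("rank=%d flag=%d total=%zu bytes\n", q.rank, (int) flag, n);
    free(buf);
    return 0;
}
```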
diff --git a/common/profiler.h b/common/profiler.h
index 06741d6c..5ac73a8c 100644
--- a/common/profiler.h
+++ b/common/profiler.h
@@ -346,6 +346,18 @@ struct device_info {
         model_bytes() {}
 };
 
+struct TopoRebuildHelperInfo {
+    struct device_info dev_info;
+    char is_fowarder;
+
+    TopoRebuildHelperInfo():
+        dev_info(),
+        is_fowarder(0) {}
+
+    void   deserialize(const char * buffer);
+    size_t serialize  (char ** buffer) const;
+};
+
 enum profiler_backend_type {
     PROFILER_BACKEND_TYPE_CPU   = 0,
     PROFILER_BACKEND_TYPE_METAL = 1,
@@ -389,6 +401,6 @@ int device_has_blas (void);
 int device_has_sycl (void);
 
 size_t serialize  (const struct device_info * dev_info, char ** buffer);
-void deserialize(const char * buffer, struct device_info * dev_info);
+size_t deserialize(const char * buffer, struct device_info * dev_info);
 
 #endif // PROFILER_H
diff --git a/include/llama.h b/include/llama.h
index 86da593c..c2e4d43c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -448,6 +448,12 @@ extern "C" {
                              struct llama_model_params   params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
+
+    enum NodeType {
+        NODE_TYPE_WORKER,
+        NODE_TYPE_FOWARDER,
+        NODE_TYPE_EXIT,
+    };
 
     LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
     LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
@@ -455,7 +461,12 @@ extern "C" {
     LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
     LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
     LLAMA_API int llama_bcast_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers);
-    LLAMA_API int llama_rebuild_topo      (struct llama_context * ctx, uint32_t * n_layer_window, struct device_info * dev_info_set);
+    LLAMA_API int llama_rebuild_topo      (struct llama_context * ctx,
+                                           uint32_t * n_layer_window,
+                                           struct device_info * dev_info_set,
+                                           NodeType * node_type,
+                                           char * is_fowarder);
+    LLAMA_API int llama_foward_messages   (struct llama_context * ctx);
     LLAMA_API int llama_recv_layer_setup  (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers);
 
     LLAMA_API int llm_load_tensors(
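The new `NodeType` enum encodes the three roles a rank can end up with after the topology rebuild: a rank that keeps layers stays a worker, a rank kept alive only to relay traffic becomes a forwarder, and everything else exits. A standalone sketch of that classification rule as the patch applies it; `classify_rank` is a hypothetical helper for illustration only:

```cpp
// Sketch of the role-classification rule applied at the end of
// llama_rebuild_topo. classify_rank is illustrative, not part of the patch.
#include <cstdint>
#include <cstdio>

enum NodeType {
    NODE_TYPE_WORKER,
    NODE_TYPE_FOWARDER,  // spelling follows the patch
    NODE_TYPE_EXIT,
};

static NodeType classify_rank(uint32_t rank, const uint32_t * n_layer_window, const char * is_fowarder) {
    if (n_layer_window[rank] > 0) return NODE_TYPE_WORKER;   // has layers to serve
    if (is_fowarder[rank] == 1)   return NODE_TYPE_FOWARDER; // relays traffic only
    return NODE_TYPE_EXIT;                                   // neither role: leave the ring
}

int main() {
    // rank 1 has no layers but relays; rank 3 has neither role
    uint32_t n_layer_window[4] = {16, 0, 16, 0};
    char     is_fowarder[4]    = {0, 1, 0, 0};
    const char * names[] = {"worker", "fowarder", "exit"};
    for (uint32_t r = 0; r < 4; r++) {
        printf("rank %u -> %s\n", r, names[classify_rank(r, n_layer_window, is_fowarder)]);
    }
    return 0;
}
```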
diff --git a/src/llama.cpp b/src/llama.cpp
index af42f79d..dd7b0c82 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -173,12 +173,12 @@ static void zeros(std::ofstream & file, size_t n) {
 }
 
 // zmq helpers
-static std::vector<zmq::message_t> dev_infos_to_messages(const device_info * infos,
-                                                         uint32_t n_world) {
+static std::vector<zmq::message_t> topohelper_to_messages(const TopoRebuildHelperInfo * infos,
+                                                          uint32_t n_world) {
     std::vector<zmq::message_t> res;
     for (uint32_t i = 0; i < n_world; ++i) {
         char * buffer = nullptr;
-        size_t buffer_size = serialize(&infos[i], &buffer);
+        size_t buffer_size = infos[i].serialize(&buffer);
         res.emplace_back(buffer, buffer_size);
         free(buffer);
     }
@@ -20428,6 +20428,39 @@ static uint32_t map_rank_to_port(uint32_t rank, uint32_t data_port) {
     return data_port + rank;
 }
 
+static std::string try_connect(llama_context * ctx, uint32_t rank, TopoRebuildHelperInfo * infos, uint32_t n_world, zmq::socket_t ** socket) {
+    auto prv_rank = (rank - 1 + n_world) % n_world;
+    std::string ip = infos[prv_rank].dev_info.next_ip;
+    std::string send_endp = "tcp://" + ip + ":" + std::to_string(map_rank_to_port(rank, ctx->data_port));
+    *socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
+    int events = 0;
+    try {
+        (*socket)->set(zmq::sockopt::linger, 0);
+        (*socket)->set(zmq::sockopt::sndtimeo, 500);
+
+        (*socket)->connect(send_endp);
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+        size_t events_size = sizeof(events);
+        (*socket)->getsockopt(ZMQ_EVENTS, &events, &events_size);
+    } catch (const zmq::error_t & e) {
+        delete *socket;
+        *socket = nullptr;
+        return "";
+    }
+
+    if ((events & ZMQ_POLLOUT) != 0) {
+        return ip;
+    } else {
+        delete *socket;
+        *socket = nullptr;
+        return "";
+    }
+}
+
 void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t my_rank) {
     if (n_world == 1) {
         return;
     }
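`try_connect` probes whether a candidate successor is reachable before the ring is rewired: connect a PUSH socket with a short send timeout, give the asynchronous TCP handshake a moment, then check whether the socket reports `ZMQ_POLLOUT`. A minimal self-contained sketch of the same probe with cppzmq, assuming nothing from the llama context; the endpoint below is hypothetical:

```cpp
// Sketch of the reachability probe: connect, wait briefly, then test
// writability via ZMQ_EVENTS. Nothing is expected to listen on this port.
#include <zmq.hpp>
#include <chrono>
#include <cstdio>
#include <thread>

int main() {
    zmq::context_t ctx(1);
    zmq::socket_t sock(ctx, zmq::socket_type::push);
    sock.set(zmq::sockopt::linger, 0);      // drop unsent messages on close
    sock.set(zmq::sockopt::sndtimeo, 500);  // never block forever on send

    sock.connect("tcp://127.0.0.1:9337");   // hypothetical peer endpoint

    // connect() is asynchronous in ZeroMQ; give the handshake a moment
    std::this_thread::sleep_for(std::chrono::milliseconds(500));

    // a PUSH socket only reports ZMQ_POLLOUT once a peer is attached
    int events = sock.get(zmq::sockopt::events);
    bool reachable = (events & ZMQ_POLLOUT) != 0;
    printf("peer %s\n", reachable ? "reachable" : "unreachable");
    return 0;
}
```

Because a PUSH socket only becomes writable once a peer is attached, the absence of `ZMQ_POLLOUT` after the grace period is taken as "unreachable".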
@@ -20602,95 +20635,161 @@ int llama_bcast_layer_setup(struct llama_context * ctx, uint32_t * n_layer_windo
     return 0;
 }
 
-int llama_rebuild_topo(llama_context * ctx, uint32_t * n_layer_window, device_info * dev_info_set) {
+int llama_rebuild_topo(llama_context * ctx,
+                       uint32_t * n_layer_window,
+                       device_info * dev_info_set,
+                       NodeType * node_type,
+                       char * is_fowarder) {
     uint32_t n_world = ctx->cparams.n_world;
     uint32_t my_rank = ctx->cparams.rank;
-    device_info * dev_info_ptr = nullptr;
+    TopoRebuildHelperInfo * topo_helper = new TopoRebuildHelperInfo[n_world];
 
     if (dev_info_set == nullptr) {
+        // for rank != 0, recv all devices' info
        std::vector<zmq::message_t> msgs;
         if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
             return -1;
         }
-        dev_info_ptr = new device_info[n_world];
         for (size_t i = 0; i < msgs.size(); i++) {
-            deserialize((const char *)msgs[i].data(), &dev_info_ptr[i]);
+            topo_helper[i].deserialize((char *)msgs[i].data());
         }
         GGML_ASSERT(msgs.size() == n_world);
     } else {
-        dev_info_ptr = dev_info_set;
+        for (size_t i = 0; i < n_world; i++) {
+            topo_helper[i].dev_info = dev_info_set[i];
+            topo_helper[i].is_fowarder = 0;
+        }
     }
     GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
 
-    // notify next rank
     auto next_rank = (my_rank + 1) % n_world;
-    if (n_layer_window[next_rank] <= 0 && next_rank != 0) {
+    auto next_connect_rank = (my_rank + 1) % n_world;
+    zmq::socket_t * socket_to_close = nullptr;
+    bool is_not_exit = n_layer_window[my_rank] > 0 || topo_helper[my_rank].is_fowarder == 1;
+
+    if (is_not_exit) {
+        // reconstruct socket to the next valid rank
+        auto current_rank = my_rank;
+        std::vector<uint32_t> nodes;
+        auto next_rank_ = next_rank;
+        while (next_rank_ != my_rank) {
+            nodes.push_back(next_rank_);
+            if (n_layer_window[next_rank_] > 0) {
+                break;
+            }
+            next_rank_ = (next_rank_ + 1) % n_world;
+            current_rank = (current_rank + 1) % n_world;
+        }
+        if (next_rank_ == my_rank) {
+            // only one node
+            ctx->next_node_ip = "";
+            socket_to_close = ctx->send_socket;
+            ctx->send_socket = nullptr;
+        } else {
+            // iterate over candidate nodes in reverse order
+            zmq::socket_t * socket = nullptr;
+            std::string ip;
+            for (int i = nodes.size() - 1; i > 0; --i) {
+                auto rank = nodes[i];
+                ip = try_connect(ctx, rank, topo_helper, n_world, &socket);
+                if (!ip.empty()) {
+                    topo_helper[rank].is_fowarder = 1;
+                    next_connect_rank = rank;
+                    break;
+                }
+            }
+            if (next_connect_rank != next_rank) {
+                // reset socket
+                GGML_ASSERT(socket != nullptr);
+                GGML_ASSERT(!ip.empty());
+                socket_to_close = ctx->send_socket;
+                ctx->send_socket = socket;
+                ctx->next_node_ip = ip;
+                ctx->cparams.original_next_rank = next_connect_rank;
+            }
+        }
+    } else if (n_layer_window[next_rank] <= 0 && topo_helper[my_rank].is_fowarder == 0) {
+        socket_to_close = ctx->send_socket;
+    }
+
+    // notify the next exiting node
+    if (socket_to_close != nullptr) {
+        GGML_ASSERT(n_layer_window[next_rank] <= 0 && topo_helper[next_rank].is_fowarder == 0);
         try {
-            auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
-            ctx->send_socket->set(zmq::sockopt::linger, 3500);
-            zmq::send_multipart(*ctx->send_socket, msgs);
+            auto msgs = topohelper_to_messages(topo_helper, n_world);
+            socket_to_close->set(zmq::sockopt::linger, 3500);
+            zmq::send_multipart(*socket_to_close, msgs);
         } catch (const zmq::error_t & e) {
             LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
-            if (!dev_info_set) {
-                delete[] dev_info_ptr;
-            }
             return -1;
         }
     }
-
-    zmq::socket_t * socket_to_close = nullptr;
-    if (n_layer_window[my_rank] > 0) {
-        // reconstruct socket to the next valid rank
-        std::string next_ip;
-        auto current_rank = my_rank;
-
-        while (next_rank != my_rank) {
-            if (n_layer_window[next_rank] > 0) {
-                next_ip = dev_info_ptr[current_rank].next_ip;
-                break;
-            }
-            next_rank = (next_rank + 1) % n_world;
-            current_rank = (current_rank + 1) % n_world;
-        }
-
-        if (!next_ip.empty()) {
-            if ((my_rank + 1) % n_world != next_rank) {
-                socket_to_close = ctx->send_socket;
-                ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
-                std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
-                ctx->send_socket->connect(send_endp);
-                ctx->next_node_ip = next_ip;
-                ctx->cparams.original_next_rank = next_rank;
-            }
-
-            if (next_rank != 0) {
-                try {
-                    auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
-                    zmq::send_multipart(*ctx->send_socket, msgs);
-                } catch (const zmq::error_t &e) {
-                    LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
-                    if (!dev_info_set) {
-                        delete[] dev_info_ptr;
-                    }
-                    return -1;
-                }
-            }
-        } else {
-            // only one node
-            ctx->next_node_ip = "";
+
+    // notify the next connected node
+    if (!ctx->next_node_ip.empty() && is_not_exit) {
+        GGML_ASSERT(ctx->send_socket != nullptr);
+        try {
+            auto msgs = topohelper_to_messages(topo_helper, n_world);
+            zmq::send_multipart(*ctx->send_socket, msgs);
+        } catch (const zmq::error_t & e) {
+            LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
+            return -1;
         }
     }
 
-    if (!dev_info_set) {
-        delete[] dev_info_ptr;
+    if (n_layer_window[my_rank] > 0) {
+        *node_type = NodeType::NODE_TYPE_WORKER;
+    } else if (topo_helper[my_rank].is_fowarder == 1) {
+        *node_type = NodeType::NODE_TYPE_FOWARDER;
+    } else {
+        *node_type = NodeType::NODE_TYPE_EXIT;
     }
 
+    if (ctx->send_socket != nullptr && *node_type != NodeType::NODE_TYPE_EXIT) {
+        // recv the whole view of all nodes
+        std::vector<zmq::message_t> msgs;
+        if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
+            return -1;
+        }
+        GGML_ASSERT(msgs.size() == n_world);
+        for (size_t i = 0; i < msgs.size(); i++) {
+            topo_helper[i].deserialize((char *)msgs[i].data());
+        }
+        // broadcast the whole view
+        if (next_connect_rank != 0) {
+            try {
+                zmq::send_multipart(*ctx->send_socket, msgs);
+            } catch (const zmq::error_t & e) {
+                LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
+                return -1;
+            }
+        }
+    }
+
+    for (size_t i = 0; i < n_world; i++) {
+        is_fowarder[i] = topo_helper[i].is_fowarder;
+    }
+
     if (socket_to_close != nullptr) {
         socket_to_close->close();
         delete socket_to_close;
     }
 
+    delete[] topo_helper;
     return 0;
 }
+
+LLAMA_API int llama_foward_messages(llama_context * ctx) {
+    zmq::message_t message;
+    int more = 1; // ZMQ_RCVMORE is exposed as an int option
+
+    while (more) {
+        ctx->recv_socket->recv(message, zmq::recv_flags::none);
+
+        size_t more_size = sizeof(more);
+        ctx->recv_socket->getsockopt(ZMQ_RCVMORE, &more, &more_size);
+
+        ctx->send_socket->send(message, more ? zmq::send_flags::sndmore : zmq::send_flags::none);
+    }
+    return 0;
+}
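`llama_foward_messages` is in essence a one-message pull-to-push relay: it drains every part of a multipart message and re-sends each part downstream, preserving part boundaries with the `sndmore` flag (the caller in `common.cpp` wraps it in a loop). A standalone sketch of the same relay pattern; the endpoints and ports are hypothetical, and it uses `zmq::message_t::more()`, which is equivalent to querying `ZMQ_RCVMORE` on the socket:

```cpp
// Sketch of a multipart pull-to-push relay, as a forwarder node runs it.
// Endpoints are illustrative; nothing here is part of the patch itself.
#include <zmq.hpp>

int main() {
    zmq::context_t ctx(1);
    zmq::socket_t upstream(ctx, zmq::socket_type::pull);
    zmq::socket_t downstream(ctx, zmq::socket_type::push);
    upstream.bind("tcp://*:9000");              // hypothetical ports
    downstream.connect("tcp://127.0.0.1:9001");

    while (true) {
        zmq::message_t part;
        (void) upstream.recv(part, zmq::recv_flags::none);
        // check for a following part before send() consumes this one
        bool more = part.more();
        downstream.send(part, more ? zmq::send_flags::sndmore : zmq::send_flags::none);
    }
}
```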