mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-06 05:59:03 +00:00
support reconnection
This commit is contained in:
parent e50b3aa473
commit d1b97f798e
5 changed files with 226 additions and 69 deletions
@@ -1717,6 +1717,8 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     }
 
     // sychronize device profile to the master node
+    NodeType node_type;
+    char is_fowarder[32] = {0};
     if (my_rank == 0) {
         if (auto_schedule) {
             std::vector<device_info> dev_info_set(n_world);
@@ -1743,14 +1745,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (auto_schedule){
             llama_send_device_info(lctx, &dev_info);
             llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
-            llama_rebuild_topo (lctx, n_layer_window, nullptr);
+            llama_rebuild_topo (lctx, n_layer_window, nullptr, &node_type, is_fowarder);
         } else {
             llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
         }
     }
 
     // if this is a weak device, then exit
-    if (n_layer_window[my_rank] <= 0) {
+    if (node_type == NodeType::NODE_TYPE_EXIT) {
         LOG_INF("No layer is assigned to me, exit.\n");
         llama_free(lctx);
         llama_free_model(model);
@@ -1762,7 +1764,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     std::vector<uint32_t> n_layer_window_temp = {n_layer_window[0]}, n_gpu_layers_temp = {n_gpu_layers[0]};
 
     for (uint32_t i = 1; i < n_world; i++) {
-        if (n_layer_window[i] <= 0) {
+        if (n_layer_window[i] <= 0 && is_fowarder[i] == 0) {
             continue;
         }
         if (i <= my_rank) {
@@ -1794,7 +1796,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     n_world = update_n_world;
 
     llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
 
+    if(node_type == NodeType::NODE_TYPE_EXIT){
+        //just foward
+        while (true) {
+            llama_foward_messages(lctx);
+        }
+    }
+
     // update n_layer_window and n_gpu_layers
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), cparams.n_layer_window);
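
For reference, the filtering in the hunks above boils down to: drop every rank that received no layers and was not kept as a forwarder, then renumber the survivors. The following is a standalone sketch of that idea only; compact_rank_world and RankWorld are hypothetical names, not code from this repository:

#include <cstdint>
#include <cstdio>
#include <vector>

struct RankWorld { uint32_t rank; uint32_t n_world; };

// Drop every node that has no layers and was not promoted to a forwarder,
// then renumber the survivors (mirrors the loop over n_layer_window/is_fowarder above).
static RankWorld compact_rank_world(const std::vector<uint32_t> & n_layer_window,
                                    const std::vector<char> & is_fowarder,
                                    uint32_t my_rank) {
    uint32_t new_rank = 0, new_world = 0;
    for (uint32_t i = 0; i < n_layer_window.size(); i++) {
        bool keep = n_layer_window[i] > 0 || is_fowarder[i] != 0;
        if (!keep) continue;
        if (i < my_rank) new_rank++;   // survivors before me shift my rank down
        new_world++;
    }
    return { new_rank, new_world };
}

int main() {
    // ranks 1 and 3 got no layers; rank 3 stays as a forwarder, rank 1 drops out
    std::vector<uint32_t> layers = { 16, 0, 16, 0 };
    std::vector<char>     fwd    = {  0, 0,  0, 1 };
    RankWorld rw = compact_rank_world(layers, fwd, /*my_rank=*/2);
    std::printf("new rank = %u, new world size = %u\n", rw.rank, rw.n_world);  // prints 1 and 3
    return 0;
}
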
@@ -2621,7 +2621,7 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
     return total_size;
 }
 
-void deserialize(const char * buffer, struct device_info * dev_info) {
+size_t deserialize(const char * buffer, struct device_info * dev_info) {
     const char * ptr = buffer;
 
     // rank
@@ -2821,6 +2821,32 @@ void deserialize(const char * buffer, struct device_info * dev_info) {
     ptr += sizeof(float);
 
     memcpy(&dev_info->gpu_props.cuda_mem_cpy_delay, ptr, sizeof(float));
     ptr += sizeof(float);
 
     // no need to synchronize model flops and model params
-}
+
+    return ptr - buffer;
+}
+
+void TopoRebuildHelperInfo::deserialize(const char *buffer) {
+    size_t buffer_size = ::deserialize(buffer, &dev_info);
+    if (buffer_size == 0) {
+        LOG_ERR("%s: failed to deserialize device info\n", __func__);
+        return;
+    }
+    memcpy(&is_fowarder, buffer + buffer_size, 1);
+}
+
+size_t TopoRebuildHelperInfo::serialize(char **buffer) const{
+    size_t buffer_size = ::serialize(&dev_info, buffer);
+    char* buffer_ = (char*)malloc(buffer_size+1);
+    if (buffer_ == NULL) {
+        LOG_ERR("%s: failed to allocate %zu bytes for device info serialization\n",
+                __func__, buffer_size);
+        return 0;
+    }
+    memcpy(buffer_, *buffer, buffer_size);
+    memcpy(buffer_ + buffer_size, &is_fowarder, 1);
+    free(*buffer);
+    *buffer = buffer_;
+    return buffer_size + 1;
+}
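
The two methods above define the wire format used during topology rebuild: a TopoRebuildHelperInfo message is the serialized device_info blob followed by a single is_fowarder flag byte, and deserialize reads the flag back from that same offset. Below is a simplified, self-contained sketch of that framing; Payload, pack, unpack and pack_with_flag are stand-in names, not functions from this repository:

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct Payload { uint32_t rank; float flops; };   // stand-in for device_info

// stand-in for ::serialize(): writes the payload and returns the byte count
static size_t pack(const Payload & p, char ** buffer) {
    *buffer = (char *) malloc(sizeof(Payload));
    memcpy(*buffer, &p, sizeof(Payload));
    return sizeof(Payload);
}

// stand-in for ::deserialize(): reads the payload and returns the bytes consumed
static size_t unpack(const char * buffer, Payload & p) {
    memcpy(&p, buffer, sizeof(Payload));
    return sizeof(Payload);
}

// same framing as TopoRebuildHelperInfo::serialize: payload first, one flag byte last
static size_t pack_with_flag(const Payload & p, char flag, char ** buffer) {
    char * body = nullptr;
    size_t body_size = pack(p, &body);
    char * framed = (char *) malloc(body_size + 1);
    memcpy(framed, body, body_size);
    memcpy(framed + body_size, &flag, 1);
    free(body);
    *buffer = framed;
    return body_size + 1;
}

int main() {
    Payload p = { 3, 42.0f };
    char * buf = nullptr;
    size_t n = pack_with_flag(p, /*is_fowarder=*/1, &buf);

    Payload q{};
    size_t used = unpack(buf, q);
    char flag = buf[used];                       // the trailing flag byte
    std::printf("%zu bytes, rank=%u, flag=%d\n", n, q.rank, flag);
    free(buf);
    return 0;
}
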
@@ -346,6 +346,18 @@ struct device_info {
         model_bytes() {}
 };
 
+struct TopoRebuildHelperInfo{
+    struct device_info dev_info;
+    char is_fowarder;
+
+    TopoRebuildHelperInfo():
+        dev_info(),
+        is_fowarder(0){}
+
+    void deserialize(const char * buffer);
+    size_t serialize(char ** buffer) const;
+};
+
 enum profiler_backend_type {
     PROFILER_BACKEND_TYPE_CPU   = 0,
     PROFILER_BACKEND_TYPE_METAL = 1,
@@ -389,6 +401,6 @@ int device_has_blas (void);
 int device_has_sycl (void);
 
 size_t serialize (const struct device_info * dev_info, char ** buffer);
-void deserialize(const char * buffer, struct device_info * dev_info);
+size_t deserialize(const char * buffer, struct device_info * dev_info);
 
 #endif // PROFILER_H
@@ -448,6 +448,12 @@ extern "C" {
             struct llama_model_params params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
+    enum NodeType{
+        NODE_TYPE_WORKER,
+        NODE_TYPE_FOWARDER,
+        NODE_TYPE_EXIT,
+    };
+
     LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
     LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
@@ -455,7 +461,12 @@ extern "C" {
     LLAMA_API int llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
     LLAMA_API int llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
     LLAMA_API int llama_bcast_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers);
-    LLAMA_API int llama_rebuild_topo      (struct llama_context * ctx, uint32_t * n_layer_window, struct device_info * dev_info_set);
+    LLAMA_API int llama_rebuild_topo      (struct llama_context * ctx,
+                                           uint32_t * n_layer_window,
+                                           struct device_info * dev_info_set,
+                                           NodeType* node_type,
+                                           char * is_fowarder);
+    LLAMA_API int llama_foward_messages   (struct llama_context * ctx);
     LLAMA_API int llama_recv_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers);
 
     LLAMA_API int llm_load_tensors(
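
Taken together with the common.cpp changes above, the extended llama_rebuild_topo is called roughly as follows on a non-master rank. This is a fragment only, not part of the commit; it assumes the prima.cpp headers plus an initialized lctx and n_layer_window:

// sketch of the new call shape; mirrors the call site added in common.cpp above
NodeType node_type;
char is_fowarder[32] = {0};                      // one flag per rank, 32 ranks max here

if (llama_rebuild_topo(lctx, n_layer_window, nullptr, &node_type, is_fowarder) != 0) {
    // topology rebuild failed (e.g. a peer went away and could not be reached)
}

if (node_type == NodeType::NODE_TYPE_EXIT) {
    // this rank got no layers and is not needed as a forwarder
}
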
src/llama.cpp: 221 changes
@@ -173,12 +173,12 @@ static void zeros(std::ofstream & file, size_t n) {
 }
 
 // zmq helpers
-static std::vector<zmq::message_t> dev_infos_to_messages(const device_info* infos,
-                                                         uint32_t n_world){
+static std::vector<zmq::message_t> topohelper_to_messages(const TopoRebuildHelperInfo* infos,
+                                                          uint32_t n_world){
     std::vector<zmq::message_t> res;
     for (uint32_t i = 0; i < n_world; ++i) {
         char * buffer = nullptr;
-        size_t buffer_size = serialize(&infos[i], &buffer);
+        size_t buffer_size = infos[i].serialize(&buffer);
         res.emplace_back(buffer, buffer_size);
         free(buffer);
     }
@@ -20428,6 +20428,39 @@ static uint32_t map_rank_to_port(uint32_t rank, uint32_t data_port) {
     return data_port + rank;
 }
 
+static std::string try_connect(llama_context *ctx, uint32_t rank, TopoRebuildHelperInfo* infos, uint32_t n_world, zmq::socket_t** socket){
+    auto prv_rank = (rank - 1 + n_world) % n_world;
+    std::string ip = infos[prv_rank].dev_info.next_ip;
+    std::string send_endp = "tcp://" + ip + ":" + std::to_string(map_rank_to_port(rank, ctx->data_port));
+    *socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
+    int events = 0;
+    try {
+        (*socket)->set(zmq::sockopt::linger, 0);
+        (*socket)->set(zmq::sockopt::sndtimeo, 500);
+
+        (*socket)->connect(send_endp);
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+        size_t events_size = sizeof(events);
+        (*socket)->getsockopt(ZMQ_EVENTS, &events, &events_size);
+
+    } catch (const zmq::error_t& e) {
+        delete *socket;
+        *socket = nullptr;
+        return "";
+    }
+
+    if((events & ZMQ_POLLOUT) != 0){
+        return ip;
+    }else{
+        delete *socket;
+        *socket = nullptr;
+        return "";
+    }
+
+}
+
 void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t my_rank) {
     if (n_world == 1) {
         return;
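
try_connect above probes reachability by connecting a PUSH socket with a short send timeout, sleeping briefly, and then checking whether ZMQ_POLLOUT is set in the socket's ZMQ_EVENTS; for a PUSH socket that bit is normally only set once a peer has accepted the connection. Below is a minimal standalone version of the same probe, not code from this commit; can_reach and the endpoint are placeholders:

#include <chrono>
#include <string>
#include <thread>
#include <zmq.hpp>

// Returns true if a PUSH connection to `endpoint` becomes writable within ~500 ms.
static bool can_reach(zmq::context_t & ctx, const std::string & endpoint) {
    zmq::socket_t sock(ctx, zmq::socket_type::push);
    sock.set(zmq::sockopt::linger, 0);        // drop pending data on close
    sock.set(zmq::sockopt::sndtimeo, 500);    // do not block forever on send
    try {
        sock.connect(endpoint);
    } catch (const zmq::error_t &) {
        return false;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
    int events = sock.get(zmq::sockopt::events);
    return (events & ZMQ_POLLOUT) != 0;       // writable => a peer accepted the connection
}

int main() {
    zmq::context_t ctx(1);
    bool ok = can_reach(ctx, "tcp://127.0.0.1:9000");  // placeholder endpoint
    return ok ? 0 : 1;
}
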
@@ -20602,95 +20635,161 @@ int llama_bcast_layer_setup(struct llama_context * ctx, uint32_t * n_layer_windo
     return 0;
 }
 
-int llama_rebuild_topo(llama_context * ctx, uint32_t * n_layer_window, device_info * dev_info_set) {
+int llama_rebuild_topo(llama_context * ctx,
+                       uint32_t * n_layer_window,
+                       device_info * dev_info_set,
+                       NodeType * node_type,
+                       char * is_fowarder) {
     uint32_t n_world = ctx->cparams.n_world;
     uint32_t my_rank = ctx->cparams.rank;
-    device_info * dev_info_ptr = nullptr;
+    TopoRebuildHelperInfo* topo_helper = new TopoRebuildHelperInfo[n_world];
 
-    if (dev_info_set == nullptr) {
+    if (dev_info_set == nullptr){
         // for rank!=0, recv all devices info
         std::vector<zmq::message_t> msgs;
         if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
             return -1;
         }
-        dev_info_ptr = new device_info[n_world];
         for (size_t i = 0; i < msgs.size(); i++) {
-            deserialize((const char *)msgs[i].data(), &dev_info_ptr[i]);
+            topo_helper[i].deserialize((char *)msgs[i].data());
         }
         GGML_ASSERT(msgs.size() == n_world);
     } else {
-        dev_info_ptr = dev_info_set;
+        for (size_t i = 0; i < n_world; i++) {
+            topo_helper[i].dev_info = dev_info_set[i];
+            topo_helper[i].is_fowarder = 0;
+        }
     }
 
     GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
 
     // notify next rank
     auto next_rank = (my_rank + 1) % n_world;
-    if (n_layer_window[next_rank] <= 0 && next_rank != 0) {
+    auto next_connect_rank = (my_rank + 1) % n_world;
+    zmq::socket_t* socket_to_close = nullptr;
+    bool is_not_exit = n_layer_window[my_rank] > 0 || topo_helper[my_rank].is_fowarder == 1;
+    if (is_not_exit){
+        // reconstruct socket to the next valid rank
+        auto current_rank = my_rank;
+        std::vector<uint32_t> nodes;
+        auto next_rank_ = next_rank;
+        while (next_rank_ != my_rank) {
+            nodes.push_back(next_rank_);
+            if (n_layer_window[next_rank_] > 0) {
+                break;
+            }
+            next_rank_ = (next_rank_ + 1) % n_world;
+            current_rank = (current_rank + 1) % n_world;
+        }
+        if (next_rank_ == my_rank) {
+            // only one node
+            ctx->next_node_ip = "";
+            socket_to_close = ctx->send_socket;
+            ctx->send_socket = nullptr;
+        } else {
+            // iterate node reverse
+            zmq::socket_t* socket = nullptr;
+            std::string ip;
+            for (int i = nodes.size() - 1; i > 0; --i) {
+                auto rank = nodes[i];
+                ip = try_connect(ctx, rank, topo_helper, n_world, &socket);
+                if(!ip.empty()){
+                    topo_helper[rank].is_fowarder = 1;
+                    next_connect_rank = rank;
+                    break;
+                }
+            }
+            if(next_connect_rank != next_rank){
+                // reset socket
+                GGML_ASSERT(socket != nullptr);
+                GGML_ASSERT(!ip.empty());
+                socket_to_close = ctx->send_socket;
+                ctx->send_socket = socket;
+                ctx->next_node_ip = ip;
+                ctx->cparams.original_next_rank = next_connect_rank;
+            }
+        }
+    }else if(n_layer_window[next_rank] <= 0 && topo_helper[my_rank].is_fowarder == 0){
+        socket_to_close = ctx->send_socket;
+    }
+
+    // notify next exiting node
+    if (socket_to_close != nullptr) {
+        GGML_ASSERT(n_layer_window[next_rank] <= 0 && topo_helper[next_rank].is_fowarder == 0);
         try {
-            auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
-            ctx->send_socket->set(zmq::sockopt::linger, 3500);
-            zmq::send_multipart(*ctx->send_socket, msgs);
+            auto msgs = topohelper_to_messages(topo_helper, n_world);
+            socket_to_close->set(zmq::sockopt::linger, 3500);
+            zmq::send_multipart(*socket_to_close, msgs);
         } catch (const zmq::error_t& e) {
             LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
-            if(!dev_info_set){
-                delete[] dev_info_ptr;
-            }
             return -1;
         }
     }
 
-    zmq::socket_t * socket_to_close = nullptr;
-    if (n_layer_window[my_rank] > 0) {
-        // reconstruct socket to the next valid rank
-        std::string next_ip;
-        auto current_rank = my_rank;
-
-        while (next_rank != my_rank) {
-            if (n_layer_window[next_rank] > 0) {
-                next_ip = dev_info_ptr[current_rank].next_ip;
-                break;
-            }
-            next_rank = (next_rank + 1) % n_world;
-            current_rank = (current_rank + 1) % n_world;
-        }
-
-        if (!next_ip.empty()) {
-            if ((my_rank + 1) % n_world != next_rank) {
-                socket_to_close = ctx->send_socket;
-                ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
-                std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
-                ctx->send_socket->connect(send_endp);
-                ctx->next_node_ip = next_ip;
-                ctx->cparams.original_next_rank = next_rank;
-            }
-
-            if (next_rank != 0) {
-                try {
-                    auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
-                    zmq::send_multipart(*ctx->send_socket, msgs);
-                } catch (const zmq::error_t &e) {
-                    LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
-                    if(!dev_info_set){
-                        delete[] dev_info_ptr;
-                    }
-                    return -1;
-                }
-            }
-        } else {
-            // only one node
-            ctx->next_node_ip = "";
-        }
-    }
-
-    if (!dev_info_set) {
-        delete[] dev_info_ptr;
-    }
+    // notify next connect node
+    if(!ctx->next_node_ip.empty() && is_not_exit){
+        GGML_ASSERT(ctx->send_socket != nullptr);
+        try {
+            auto msgs = topohelper_to_messages(topo_helper, n_world);
+            zmq::send_multipart(*ctx->send_socket, msgs);
+        } catch (const zmq::error_t& e) {
+            LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
+            return -1;
+        }
+    }
+
+    if(n_layer_window[my_rank] > 0){
+        *node_type = NodeType::NODE_TYPE_WORKER;
+    }else if (topo_helper[my_rank].is_fowarder == 1){
+        *node_type = NodeType::NODE_TYPE_FOWARDER;
+    }else{
+        *node_type = NodeType::NODE_TYPE_EXIT;
+    }
+
+    if(ctx->send_socket != nullptr && *node_type!=NodeType::NODE_TYPE_EXIT){
+        // recv the whole view of all nodes
+        std::vector<zmq::message_t> msgs;
+        if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
+            return -1;
+        }
+        GGML_ASSERT(msgs.size() == n_world);
+        for (size_t i = 0; i < msgs.size(); i++) {
+            topo_helper[i].deserialize((char *)msgs[i].data());
+        }
+        // broadcast the whole view
+        if(next_connect_rank!=0){
+            try {
+                zmq::send_multipart(*ctx->send_socket, msgs);
+            } catch (const zmq::error_t& e) {
+                LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
+                return -1;
+            }
+        }
+    }
+    for(size_t i = 0; i < n_world; i++) {
+        is_fowarder[i] = topo_helper[i].is_fowarder;
+    }
 
     if(socket_to_close != nullptr){
         socket_to_close->close();
         delete socket_to_close;
     }
+    delete [] topo_helper;
     return 0;
 }
+
+LLAMA_API int llama_foward_messages(llama_context *ctx) {
+    zmq::message_t message;
+    bool more = true;
+
+    while (more) {
+        ctx->recv_socket->recv(message, zmq::recv_flags::none);
+        size_t more_size = sizeof(more);
+        ctx->recv_socket->getsockopt(ZMQ_RCVMORE, &more, &more_size);
+
+        ctx->send_socket->send(message,
+                               more ? zmq::send_flags::sndmore : zmq::send_flags::none);
+    }
+    return 0;
+}
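
The reconnection logic added to llama_rebuild_topo can be read as: collect the ranks between this node and the next rank that still owns layers, probe them farthest-first with try_connect, make the first reachable one the new next hop, and flag it as a forwarder so it survives the rank compaction even with zero layers. The following is a simplified standalone sketch of that selection only; pick_next_hop and reachable are hypothetical names, and the commit additionally skips re-probing the immediate neighbour, for which a live socket already exists:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Pick the new next hop for `my_rank`: walk forward to the next rank that still owns
// layers, then probe candidates farthest-first; the first reachable one becomes the
// next hop and is kept in the ring as a forwarder. Returns -1 if nobody answers.
static int pick_next_hop(uint32_t my_rank,
                         const std::vector<uint32_t> & n_layer_window,
                         std::vector<char> & is_fowarder,
                         const std::function<bool(uint32_t)> & reachable) {
    const uint32_t n_world = (uint32_t) n_layer_window.size();
    std::vector<uint32_t> candidates;
    uint32_t r = (my_rank + 1) % n_world;
    while (r != my_rank) {
        candidates.push_back(r);
        if (n_layer_window[r] > 0) break;      // stop at the next live worker
        r = (r + 1) % n_world;
    }
    if (candidates.empty()) return -1;         // single-node world
    for (int i = (int) candidates.size() - 1; i >= 0; --i) {
        uint32_t cand = candidates[i];
        if (reachable(cand)) {
            is_fowarder[cand] = 1;             // keep this node around, even with 0 layers
            return (int) cand;
        }
    }
    return -1;
}

int main() {
    // rank 1 lost all layers; rank 2 is the next live worker but is unreachable,
    // so rank 0 falls back to rank 1 acting as a pure forwarder.
    std::vector<uint32_t> layers = { 16, 0, 16 };
    std::vector<char>     fwd    = {  0, 0,  0 };
    auto reachable = [](uint32_t rank) { return rank == 1; };
    int hop = pick_next_hop(0, layers, fwd, reachable);
    std::printf("next hop for rank 0: %d (forwarder flag of rank 1: %d)\n", hop, fwd[1]);
    return 0;
}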