diff --git a/Makefile b/Makefile
index 8d9f7410..39ae0b9f 100644
--- a/Makefile
+++ b/Makefile
@@ -952,7 +952,8 @@ OBJ_LLAMA = \
 	src/llama-grammar.o \
 	src/llama-sampling.o \
 	src/unicode.o \
-	src/unicode-data.o
+	src/unicode-data.o \
+	src/network-utils.o \
 
 OBJ_COMMON = \
 	common/profiler.o \
@@ -1141,6 +1142,11 @@ src/unicode-data.o: \
 	src/unicode-data.cpp \
 	src/unicode-data.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+src/network-utils.o: \
+	src/network-utils.cpp \
+	src/network-utils.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 src/llama.o: \
 	src/llama.cpp \
@@ -1149,6 +1155,7 @@ src/llama.o: \
 	src/llama-grammar.h \
 	src/llama-sampling.h \
 	src/unicode.h \
+	src/network-utils.h \
 	include/llama.h \
 	ggml/include/ggml-cuda.h \
 	ggml/include/ggml-metal.h \
diff --git a/common/arg.cpp b/common/arg.cpp
index 47d3c5e6..e282c80d 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -675,6 +675,20 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.rank = value;
         }
     ).set_env("LLAMA_ARG_RANK"));
+    add_opt(llama_arg(
+        {"--data-port"}, "N",
+        format("data port for distributed inference (default: %d)", params.data_port),
+        [](gpt_params & params, int value) {
+            params.data_port = value;
+        }
+    ).set_env("LLAMA_ARG_DATA_PORT"));
+    add_opt(llama_arg(
+        {"--signal-port"}, "N",
+        format("signal port for distributed inference (default: %d)", params.signal_port),
+        [](gpt_params & params, int value) {
+            params.signal_port = value;
+        }
+    ).set_env("LLAMA_ARG_SIGNAL_PORT"));
     add_opt(llama_arg(
         {"-lw", "--layer-window", "--n-layer-window"}, "N",
         format("number of layers to process in each compute (e.g., 16,16)"),
diff --git a/common/common.cpp b/common/common.cpp
index 79a60bb2..39b95d32 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -28,6 +28,7 @@
 #include 
 #include 
 #include 
+#include 
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include 
@@ -1681,6 +1682,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         cparams.n_layer_window[0]             = n_layers;
         mparams.n_layer_window[0]             = n_layers;
         llama_context_n_layer_window(lctx)[0] = n_layers;
+        llama_update_context_with_rankworld(lctx, 0, 1, 0, 1);
 
 #if defined(GGML_USE_METAL) || defined(GGML_USE_CUDA)
         params.n_gpu_layers = std::min((int32_t)n_layers, params.n_gpu_layers);
@@ -1722,6 +1724,8 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     }
 
     // sychronize device profile to the master node
+    NodeType node_type = NodeType::NODE_TYPE_WORKER;
+    char is_forwarder[32] = {0};
     if (my_rank == 0) {
         if (auto_schedule) {
             std::vector<device_info> dev_info_set(n_world);
@@ -1738,7 +1742,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
                 return iparams;
             }
             llama_bcast_layer_setup(lctx, n_layer_window, n_gpu_layers);
-            llama_rebuild_topo(lctx, n_layer_window, dev_info_set.data());
+            llama_rebuild_topo(lctx, n_layer_window, dev_info_set.data(), &node_type, is_forwarder);
         } else {
             // use the user-defined n_layer_window
             std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), n_layer_window);
@@ -1748,14 +1752,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (auto_schedule){
             llama_send_device_info(lctx, &dev_info);
             llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
-            llama_rebuild_topo    (lctx, n_layer_window, nullptr);
+            llama_rebuild_topo    (lctx, n_layer_window, nullptr, &node_type, is_forwarder);
         } else {
             llama_recv_layer_setup(lctx, n_layer_window, n_gpu_layers);
         }
     }
 
     // if this is a weak device, then exit
-    if (n_layer_window[my_rank] <= 0) {
+    if (node_type == NodeType::NODE_TYPE_EXIT) {
         LOG_INF("No layer is assigned to me, exit.\n");
         llama_free(lctx);
         llama_free_model(model);
@@ -1764,10 +1768,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
 
     // update my rank and n_world
     uint32_t update_rank = 0, update_n_world = 1;
+    uint32_t worker_rank = 0, n_worker = 1;
     std::vector<uint32_t> n_layer_window_temp = {n_layer_window[0]}, n_gpu_layers_temp = {n_gpu_layers[0]};
 
     for (uint32_t i = 1; i < n_world; i++) {
-        if (n_layer_window[i] <= 0) {
+        if (n_layer_window[i] <= 0 && is_forwarder[i] == 0) {
             continue;
         }
         if (i <= my_rank) {
@@ -1776,6 +1781,13 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         update_n_world++;
         n_layer_window_temp.push_back(n_layer_window[i]);
         n_gpu_layers_temp.push_back(n_gpu_layers[i]);
+
+        if (n_layer_window[i] > 0) {
+            if (i <= my_rank) {
+                worker_rank++;
+            }
+            n_worker++;
+        }
     }
 
     memset(n_layer_window, 0, n_world * sizeof(uint32_t));
@@ -1798,8 +1810,26 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
 
     params.n_world = update_n_world;
     n_world        = update_n_world;
-    llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
-
+    llama_update_context_with_rankworld(lctx, update_rank, update_n_world, worker_rank, n_worker);
+
+    if (node_type == NodeType::NODE_TYPE_FORWARDER) {
+        // just forward
+        LOG_INF("No layer is assigned to me, and I serve as a network proxy.\n");
+        std::atomic<bool> should_exit{false};
+        auto t = std::thread([lctx, &should_exit]() {
+            while (!should_exit) {
+                llama_forward_messages(lctx);
+            }
+        });
+        char * stop_signal = nullptr;
+        llama_free_sockets(lctx, &stop_signal); // this will block until receive stop signal
+
+        should_exit = true;
+        t.join();
+
+        exit(0);
+    }
+
     // update n_layer_window and n_gpu_layers
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), cparams.n_layer_window);
@@ -2004,6 +2034,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     }
     cparams.master_ip = new char[params.master_ip.length() + 1];
     std::strcpy(cparams.master_ip, params.master_ip.c_str());
+    cparams.data_port   = params.data_port;
+    cparams.signal_port = params.signal_port;
 
     if (cparams.next_node_ip != nullptr) {
         delete[] cparams.next_node_ip;
diff --git a/common/common.h b/common/common.h
index 0a679213..c6ffe136 100644
--- a/common/common.h
+++ b/common/common.h
@@ -145,8 +145,10 @@ struct gpt_params {
     int32_t n_world             = 1;            // number of devices to use
    int32_t rank                 = 0;            // my rank for distributed inference
    uint32_t n_layer_window[32]  = {0};          // layer window size on each node
-    std::string master_ip       = "localhost";  // ip address of the master node
-    std::string next_node_ip    = "localhost";  // ip address of my next node
+    std::string master_ip       = "127.0.0.1";  // ip address of the master node
+    std::string next_node_ip    = "127.0.0.1";  // ip address of my next node
+    uint32_t data_port          = 9000;         // data port for distributed inference
+    uint32_t signal_port        = 10000;        // signal port for distributed inference
     bool prefetch               = false;        // prefetch layer weights
     bool keep_out_in_metal      = true;         // whether to keep output weights in metal memory, true by default
     bool force                  = false;        // force to start prefetching after computation
diff --git a/common/profiler.cpp b/common/profiler.cpp
index 0da13824..18fe795d 100644
--- a/common/profiler.cpp
+++ b/common/profiler.cpp
@@ -2621,7 +2621,7 @@ size_t serialize(const struct device_info * dev_info, char ** buffer) {
     return total_size;
 }
 
-void deserialize(const char * buffer, struct device_info * dev_info) {
+size_t deserialize(const char * buffer, struct device_info * dev_info) {
     const char * ptr = buffer;
 
     // rank
@@ -2821,6 +2821,32 @@ void deserialize(const char * buffer, struct device_info * dev_info) {
     ptr += sizeof(float);
 
     memcpy(&dev_info->gpu_props.cuda_mem_cpy_delay, ptr, sizeof(float));
+    ptr += sizeof(float);
 
     // no need to synchronize model flops and model params
-}
\ No newline at end of file
+    return ptr - buffer;
+}
+
+void TopoRebuildHelperInfo::deserialize(const char * buffer) {
+    size_t buffer_size = ::deserialize(buffer, &dev_info);
+    if (buffer_size == 0) {
+        LOG_ERR("%s: failed to deserialize device info\n", __func__);
+        return;
+    }
+    memcpy(&is_forwarder, buffer + buffer_size, 1);
+}
+
+size_t TopoRebuildHelperInfo::serialize(char ** buffer) const {
+    size_t buffer_size = ::serialize(&dev_info, buffer);
+    char * buffer_ = (char *)malloc(buffer_size + 1);
+    if (buffer_ == NULL) {
+        LOG_ERR("%s: failed to allocate %zu bytes for device info serialization\n",
+            __func__, buffer_size);
+        return 0;
+    }
+    memcpy(buffer_, *buffer, buffer_size);
+    memcpy(buffer_ + buffer_size, &is_forwarder, 1);
+    free(*buffer);
+    *buffer = buffer_;
+    return buffer_size + 1;
+}
diff --git a/common/profiler.h b/common/profiler.h
index 06741d6c..ff69a454 100644
--- a/common/profiler.h
+++ b/common/profiler.h
@@ -346,6 +346,18 @@ struct device_info {
         model_bytes() {}
 };
 
+struct TopoRebuildHelperInfo {
+    struct device_info dev_info;
+    char is_forwarder;
+
+    TopoRebuildHelperInfo():
+        dev_info(),
+        is_forwarder(0) {}
+
+    void deserialize(const char * buffer);
+    size_t serialize(char ** buffer) const;
+};
+
 enum profiler_backend_type {
     PROFILER_BACKEND_TYPE_CPU   = 0,
     PROFILER_BACKEND_TYPE_METAL = 1,
@@ -389,6 +401,6 @@ int device_has_blas (void);
 int device_has_sycl (void);
 
 size_t serialize  (const struct device_info * dev_info, char ** buffer);
-void deserialize(const char * buffer, struct device_info * dev_info);
+size_t deserialize(const char * buffer, struct device_info * dev_info);
 
 #endif // PROFILER_H
diff --git a/include/llama.h b/include/llama.h
index 4c39b063..3c220562 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -330,6 +330,8 @@ extern "C" {
         bool keep_out_in_metal; // whether to keep output weights in metal memory
         char * master_ip;       // ip address of the master node
         char * next_node_ip;    // ip address of the next node
+        uint32_t data_port;     // data port for distributed inference
+        uint32_t signal_port;   // signal port for distributed inference
         uint32_t n_ctx;         // text context, 0 = from model
         uint32_t n_predict;     // number of tokens to predict
         uint32_t n_batch;       // logical maximum batch size that can be submitted to llama_decode
@@ -448,6 +450,12 @@ extern "C" {
             struct llama_model_params params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
+
+    enum NodeType {
+        NODE_TYPE_WORKER,
+        NODE_TYPE_FORWARDER,
+        NODE_TYPE_EXIT,
+    };
 
     LLAMA_API void llama_init_sockets (struct llama_context * ctx, uint32_t n_world, uint32_t my_rank);
     LLAMA_API void llama_free_sockets (struct llama_context * ctx, char ** msg);
@@ -455,7 +463,12 @@ extern "C" {
     LLAMA_API int  llama_send_device_info (struct llama_context * ctx, struct device_info * dev_info);
     LLAMA_API int  llama_bcast_startup_args(struct llama_context * ctx, uint32_t rank, struct startup_args * args);
     LLAMA_API int  llama_bcast_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers);
-    LLAMA_API int  llama_rebuild_topo      (struct llama_context * ctx, uint32_t * n_layer_window, struct device_info * dev_info_set);
+    LLAMA_API int  llama_rebuild_topo      (struct llama_context * ctx,
+                                            uint32_t * n_layer_window,
+                                            struct device_info * dev_info_set,
+                                            NodeType * node_type,
+                                            char * is_forwarder);
+    LLAMA_API int  llama_forward_messages (struct llama_context * ctx);
     LLAMA_API int  llama_recv_layer_setup (struct llama_context * ctx, uint32_t * n_layer_window, uint32_t * n_gpu_layers);
 
     LLAMA_API int llm_load_tensors(
@@ -466,7 +479,9 @@ extern "C" {
     LLAMA_API void llama_update_context_with_rankworld(
                         struct llama_context * ctx,
                         uint32_t rank,
-                        uint32_t n_world);
+                        uint32_t n_world,
+                        uint32_t worker_rank,
+                        uint32_t n_worker);
 
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
diff --git a/src/llama.cpp b/src/llama.cpp
index 7cd74983..8b5af567 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -11,6 +11,7 @@
 #include "ggml-backend.h"
 #include "profiler.h"
+#include "network-utils.h"
 
 #ifdef GGML_USE_RPC
 #  include "ggml-rpc.h"
@@ -173,12 +174,12 @@ static void zeros(std::ofstream & file, size_t n) {
 }
 
 // zmq helpers
-static std::vector<zmq::message_t> dev_infos_to_messages(const device_info* infos,
-                                                          uint32_t n_world){
+static std::vector<zmq::message_t> topohelper_to_messages(const TopoRebuildHelperInfo* infos,
+                                                           uint32_t n_world){
     std::vector<zmq::message_t> res;
     for (uint32_t i = 0; i < n_world; ++i) {
         char * buffer = nullptr;
-        size_t buffer_size = serialize(&infos[i], &buffer);
+        size_t buffer_size = infos[i].serialize(&buffer);
         res.emplace_back(buffer, buffer_size);
         free(buffer);
     }
@@ -2596,6 +2597,9 @@ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams m
 struct llama_cparams {
     uint32_t n_world;
     uint32_t rank;
+    NodeType node_type;
+    uint32_t n_worker;
+    uint32_t worker_rank;
     uint32_t original_next_rank; // original rank of the next node
     uint32_t n_layer_window[32];
     bool prefetch;
@@ -3434,8 +3438,8 @@ struct llama_context {
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 
     // sockets
-    std::string master_ip    = "localhost";
-    std::string next_node_ip = "localhost";
+    std::string master_ip    = "127.0.0.1";
+    std::string next_node_ip = "127.0.0.1";
     uint32_t data_port       = 9000;
     uint32_t signal_port     = 10000;
     zmq::context_t * sock_context = nullptr;
@@ -18221,6 +18225,9 @@ static int llama_decode_internal(
     const uint32_t n_world = cparams.n_world;
     const uint32_t my_rank = cparams.rank;
+    const uint32_t n_worker    = cparams.n_worker;
+    const uint32_t worker_rank = cparams.worker_rank;
+
     lctx.is_encoding = false;
     const uint32_t n_tokens_all = batch_all.n_tokens;
 
     if (my_rank != 0) {
@@ -18276,7 +18283,7 @@ static int llama_decode_internal(
 
     sync_meta meta;
     meta.n_ctx = cparams.n_ctx;
-    bool is_last_dev = (my_rank == n_world - 1);
+    bool is_last_dev = (worker_rank == n_worker - 1);
 
     if (my_rank != 0) {
         if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
@@ -20259,6 +20266,8 @@ struct llama_context_params llama_context_default_params() {
         /*.keep_out_in_metal           =*/ true,
         /*.master_ip                   =*/ nullptr,
         /*.next_node_ip                =*/ nullptr,
+        /*.data_port                   =*/ 9000,
+        /*.signal_port                 =*/ 10000,
         /*.n_ctx                       =*/ 512,
         /*.n_predict                   =*/ 512,
         /*.n_batch                     =*/ 2048,
@@ -20445,6 +20454,27 @@ static uint32_t map_rank_to_port(uint32_t rank, uint32_t data_port) {
     return data_port + rank;
 }
 
+static std::string try_connect(llama_context * ctx, uint32_t rank, TopoRebuildHelperInfo * infos, uint32_t n_world, zmq::socket_t ** socket) {
+    auto prev_rank = (rank - 1 + n_world) % n_world;
+    std::string ip = infos[prev_rank].dev_info.next_ip;
+    auto port = map_rank_to_port(rank, ctx->data_port);
+
+    if (!is_port_open(ip, port)) {
+        *socket = nullptr;
+        return "";
+    }
+    std::string send_endp = "tcp://" + ip + ":" + std::to_string(port);
+    *socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
+    try {
+        (*socket)->connect(send_endp);
+    } catch (const zmq::error_t& e) {
+        delete *socket;
+        *socket = nullptr;
+        return "";
+    }
+    return ip;
+}
+
 void llama_init_sockets(struct llama_context * ctx, uint32_t n_world, uint32_t my_rank) {
     if (n_world == 1) {
         return;
@@ -20619,95 +20649,165 @@ int llama_bcast_layer_setup(struct llama_context * ctx, uint32_t * n_layer_windo
     return 0;
 }
 
-int llama_rebuild_topo(llama_context * ctx, uint32_t * n_layer_window, device_info * dev_info_set) {
+int llama_rebuild_topo(llama_context * ctx,
+                       uint32_t * n_layer_window,
+                       device_info * dev_info_set,
+                       NodeType * node_type,
+                       char * is_forwarder) {
     uint32_t n_world = ctx->cparams.n_world;
     uint32_t my_rank = ctx->cparams.rank;
-    device_info * dev_info_ptr = nullptr;
+    TopoRebuildHelperInfo * topo_helper = new TopoRebuildHelperInfo[n_world];
 
-    if (dev_info_set == nullptr) {
+    if (dev_info_set == nullptr){
+        // for rank != 0, recv all devices info
         std::vector<zmq::message_t> msgs;
         if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
             return -1;
         }
-        dev_info_ptr = new device_info[n_world];
         for (size_t i = 0; i < msgs.size(); i++) {
-            deserialize((const char *)msgs[i].data(), &dev_info_ptr[i]);
+            topo_helper[i].deserialize((char *)msgs[i].data());
         }
         GGML_ASSERT(msgs.size() == n_world);
     } else {
-        dev_info_ptr = dev_info_set;
+        for (size_t i = 0; i < n_world; i++) {
+            topo_helper[i].dev_info = dev_info_set[i];
+            topo_helper[i].is_forwarder = 0;
+        }
     }
     GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
 
-    // notify next rank
     auto next_rank = (my_rank + 1) % n_world;
-    if (n_layer_window[next_rank] <= 0 && next_rank != 0) {
+    auto next_connect_rank = (my_rank + 1) % n_world;
+    zmq::socket_t * socket_to_close = nullptr;
+    bool is_not_exit = n_layer_window[my_rank] > 0 || topo_helper[my_rank].is_forwarder == 1;
+    if (is_not_exit) {
+        // reconstruct socket to the next valid rank
+        auto current_rank = my_rank;
+        std::vector<uint32_t> nodes;
+        auto next_rank_ = next_rank;
+        while (next_rank_ != my_rank) {
+            nodes.push_back(next_rank_);
+            if (n_layer_window[next_rank_] > 0) {
+                break;
+            }
+            next_rank_ = (next_rank_ + 1) % n_world;
+            current_rank = (current_rank + 1) % n_world;
+        }
+        if (next_rank_ == my_rank) {
+            // only one node
+            ctx->next_node_ip = "";
+            socket_to_close = ctx->send_socket;
+            ctx->send_socket = nullptr;
+        } else {
+            // iterate node reverse
+            zmq::socket_t * socket = nullptr;
+            std::string ip;
+            for (int i = nodes.size() - 1; i > 0; --i) {
+                auto rank = nodes[i];
+                ip = try_connect(ctx, rank, topo_helper, n_world, &socket);
+                if (!ip.empty()) {
+                    next_connect_rank = rank;
+                    break;
+                }
+            }
+            topo_helper[next_connect_rank].is_forwarder = 1;
+            if (next_connect_rank != next_rank) {
+                // reset socket
+                GGML_ASSERT(socket != nullptr);
+                GGML_ASSERT(!ip.empty());
+                socket_to_close = ctx->send_socket;
+                ctx->send_socket = socket;
+                ctx->next_node_ip = ip;
+                ctx->cparams.original_next_rank = next_connect_rank;
+            }
+        }
+    } else if (n_layer_window[next_rank] <= 0 && topo_helper[next_rank].is_forwarder == 0) {
+        socket_to_close = ctx->send_socket;
+    }
+
+    // notify next exiting node
+    if (socket_to_close != nullptr) {
+        GGML_ASSERT(n_layer_window[next_rank] <= 0 && topo_helper[next_rank].is_forwarder == 0);
         try {
-            auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
-            ctx->send_socket->set(zmq::sockopt::linger, 3500);
-            zmq::send_multipart(*ctx->send_socket, msgs);
+            auto msgs = topohelper_to_messages(topo_helper, n_world);
+            socket_to_close->set(zmq::sockopt::linger, 3500);
+            zmq::send_multipart(*socket_to_close, msgs);
         } catch (const zmq::error_t& e) {
             LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
-            if(!dev_info_set){
-                delete[] dev_info_ptr;
-            }
             return -1;
         }
     }
-
-    zmq::socket_t * socket_to_close = nullptr;
+
+    // notify next connect node
+    if (!ctx->next_node_ip.empty() && is_not_exit) {
+        GGML_ASSERT(ctx->send_socket != nullptr);
+        try {
+            auto msgs = topohelper_to_messages(topo_helper, n_world);
+            zmq::send_multipart(*ctx->send_socket, msgs);
+        } catch (const zmq::error_t& e) {
+            LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
+            return -1;
+        }
+    }
+
     if (n_layer_window[my_rank] > 0) {
-        // reconstruct socket to the next valid rank
-        std::string next_ip;
-        auto current_rank = my_rank;
-
-        while (next_rank != my_rank) {
-            if (n_layer_window[next_rank] > 0) {
-                next_ip = dev_info_ptr[current_rank].next_ip;
-                break;
-            }
-            next_rank = (next_rank + 1) % n_world;
-            current_rank = (current_rank + 1) % n_world;
-        }
-
-        if (!next_ip.empty()) {
-            if ((my_rank + 1) % n_world != next_rank) {
-                socket_to_close = ctx->send_socket;
-                ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
-                std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
-                ctx->send_socket->connect(send_endp);
-                ctx->next_node_ip = next_ip;
-                ctx->cparams.original_next_rank = next_rank;
-            }
-
-            if (next_rank != 0) {
-                try {
-                    auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
-                    zmq::send_multipart(*ctx->send_socket, msgs);
-                } catch (const zmq::error_t &e) {
-                    LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
-                    if(!dev_info_set){
-                        delete[] dev_info_ptr;
-                    }
-                    return -1;
-                }
-            }
-        } else {
-            // only one node
-            ctx->next_node_ip = "";
-        }
+        *node_type = NodeType::NODE_TYPE_WORKER;
+    } else if (topo_helper[my_rank].is_forwarder == 1) {
+        *node_type = NodeType::NODE_TYPE_FORWARDER;
+    } else {
+        *node_type = NodeType::NODE_TYPE_EXIT;
     }
-    if (!dev_info_set) {
-        delete[] dev_info_ptr;
+    if (ctx->send_socket != nullptr && *node_type != NodeType::NODE_TYPE_EXIT) {
+        // recv the whole view of all nodes
+        std::vector<zmq::message_t> msgs;
+        if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
+            return -1;
+        }
+        GGML_ASSERT(msgs.size() == n_world);
+        for (size_t i = 0; i < msgs.size(); i++) {
+            topo_helper[i].deserialize((char *)msgs[i].data());
+        }
+        // broadcast the whole view
+        if (next_connect_rank != 0) {
+            try {
+                zmq::send_multipart(*ctx->send_socket, msgs);
+            } catch (const zmq::error_t& e) {
+                LLAMA_LOG_INFO("Failed to send data: %s\n", e.what());
+                return -1;
+            }
+        }
     }
+    for (size_t i = 0; i < n_world; i++) {
+        is_forwarder[i] = topo_helper[i].is_forwarder;
+    }
+    ctx->cparams.node_type = *node_type;
 
-    if(socket_to_close != nullptr){
+    if (socket_to_close != nullptr) {
         socket_to_close->close();
         delete socket_to_close;
     }
+    delete [] topo_helper;
+    return 0;
+}
+
+int llama_forward_messages(llama_context * ctx) {
+    zmq::message_t message;
+    int more = true;
+    int timeout_ms = 10;
+    ctx->recv_socket->setsockopt(ZMQ_RCVTIMEO, &timeout_ms, sizeof(timeout_ms));
+    while (more) {
+        auto recv_result = ctx->recv_socket->recv(message, zmq::recv_flags::none);
+        if (!recv_result) {
+            return -1;
+        }
+        size_t more_size = sizeof(more);
+        ctx->recv_socket->getsockopt(ZMQ_RCVMORE, &more, &more_size);
+
+        ctx->send_socket->send(message,
+            more ? zmq::send_flags::sndmore : zmq::send_flags::none);
+    }
     return 0;
 }
 
@@ -20772,10 +20872,16 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
     }
 }
 
-void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
+void llama_update_context_with_rankworld(struct llama_context * ctx,
+                                         uint32_t rank,
+                                         uint32_t n_world,
+                                         uint32_t worker_rank,
+                                         uint32_t n_worker) {
     if (ctx) {
         ctx->cparams.rank    = rank;
         ctx->cparams.n_world = n_world;
+        ctx->cparams.worker_rank = worker_rank;
+        ctx->cparams.n_worker    = n_worker;
     }
 }
@@ -20792,6 +20898,8 @@ struct llama_context * llama_new_context_with_model(
 
     ctx->master_ip    = params.master_ip;
     ctx->next_node_ip = params.next_node_ip;
+    ctx->data_port    = params.data_port;
+    ctx->signal_port  = params.signal_port;
     ctx->cparams.n_world = params.n_world;
     ctx->cparams.rank    = params.rank;
     ctx->cparams.force   = params.force;
diff --git a/src/network-utils.cpp b/src/network-utils.cpp
new file mode 100644
index 00000000..e7fa5ab1
--- /dev/null
+++ b/src/network-utils.cpp
@@ -0,0 +1,26 @@
+#include "network-utils.h"
+
+#include <arpa/inet.h>
+#include <sys/socket.h>
+#include <sys/time.h>
+#include <unistd.h>
+
+bool is_port_open(const std::string& ip, uint32_t port, int timeout_sec) {
+    int sock = socket(AF_INET, SOCK_STREAM, 0);
+    if (sock < 0) return false;
+
+    struct timeval tv;
+    tv.tv_sec  = timeout_sec;
+    tv.tv_usec = 0;
+    setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
+    setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv));
+
+    struct sockaddr_in server;
+    server.sin_addr.s_addr = inet_addr(ip.c_str());
+    server.sin_family      = AF_INET;
+    server.sin_port        = htons(port);
+
+    int res = connect(sock, (struct sockaddr *)&server, sizeof(server));
+    close(sock);
+    return res == 0;
+}
\ No newline at end of file
diff --git a/src/network-utils.h b/src/network-utils.h
new file mode 100644
index 00000000..7a35475a
--- /dev/null
+++ b/src/network-utils.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <string>
+
+typedef unsigned int uint32_t;
+
+bool is_port_open(const std::string& ip, uint32_t port, int timeout_sec = 2);
\ No newline at end of file