From df16b1876f316958dbef6e3843f63e1888cb4725 Mon Sep 17 00:00:00 2001 From: DeEMO Date: Fri, 16 May 2025 16:02:25 +0800 Subject: [PATCH] refactor: add zmq helper to generate message Signed-off-by: DeEMO --- src/llama.cpp | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 5c640b90..a0b91edd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -161,6 +161,19 @@ static void zeros(std::ofstream & file, size_t n) { } } +// zmq helpers +static std::vector dev_infos_to_messages(const device_info* infos, + uint32_t n_world){ + std::vector res; + for (uint32_t i = 0; i < n_world; ++i) { + char * buffer = nullptr; + size_t buffer_size = serialize(&infos[i], &buffer); + res.emplace_back(buffer, buffer_size); + free(buffer); + } + return res; +} + LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) { va_list ap; @@ -20334,10 +20347,10 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx, device_info *dev_info_set) { uint32_t n_world = ctx->cparams.n_world; uint32_t my_rank = ctx->cparams.rank; - std::vector msgs; device_info* dev_info_ptr = nullptr; if (dev_info_set == nullptr){ // for rank!=0, recv all devices info + std::vector msgs; if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) { return -1; } @@ -20345,24 +20358,18 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx, for (size_t i = 0; i < msgs.size(); i++) { deserialize((const char *)msgs[i].data(), &dev_info_ptr[i]); } + GGML_ASSERT(msgs.size() == n_world); }else{ - char * buffer = nullptr; - for(size_t i = 0; i < n_world; i++) { - size_t buffer_size = serialize(&dev_info_set[i], &buffer); - msgs.emplace_back(buffer, buffer_size); - - free(buffer); - } dev_info_ptr = dev_info_set; } GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr); - GGML_ASSERT(msgs.size() == n_world); // notify next rank auto next_rank = (my_rank + 1) % n_world; if(n_layer_window[next_rank] <= 0 && next_rank != 0){ try { + auto msgs = dev_infos_to_messages(dev_info_ptr, n_world); ctx->send_socket->set(zmq::sockopt::linger, 3500); zmq::send_multipart(*ctx->send_socket, msgs); } catch (const zmq::error_t& e) { @@ -20394,6 +20401,7 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx, ctx->next_node_ip = next_ip; try { ctx->send_socket->connect(send_endp); + auto msgs = dev_infos_to_messages(dev_info_ptr, n_world); zmq::send_multipart(*ctx->send_socket, msgs); } catch (const zmq::error_t &e) { LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what()); @@ -20477,7 +20485,7 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) { void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, - uint32_t n_world) { + uint32_t n_world) { if(ctx) { ctx->cparams.rank = rank; ctx->cparams.n_world = n_world;