refactor: add zmq helper to generate message

Signed-off-by: DeEMO <yzzxrx@gmail.com>
This commit is contained in:
DeEMO 2025-05-16 16:02:25 +08:00
parent 0ad009a2f4
commit df16b1876f

View file

@ -161,6 +161,19 @@ static void zeros(std::ofstream & file, size_t n) {
}
}
// zmq helpers
static std::vector<zmq::message_t> dev_infos_to_messages(const device_info* infos,
uint32_t n_world){
std::vector<zmq::message_t> res;
for (uint32_t i = 0; i < n_world; ++i) {
char * buffer = nullptr;
size_t buffer_size = serialize(&infos[i], &buffer);
res.emplace_back(buffer, buffer_size);
free(buffer);
}
return res;
}
LLAMA_ATTRIBUTE_FORMAT(1, 2)
static std::string format(const char * fmt, ...) {
va_list ap;
@ -20334,10 +20347,10 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
device_info *dev_info_set) {
uint32_t n_world = ctx->cparams.n_world;
uint32_t my_rank = ctx->cparams.rank;
std::vector<zmq::message_t> msgs;
device_info* dev_info_ptr = nullptr;
if (dev_info_set == nullptr){
// for rank!=0, recv all devices info
std::vector<zmq::message_t> msgs;
if (!zmq::recv_multipart(*ctx->recv_socket, std::back_inserter(msgs))) {
return -1;
}
@ -20345,24 +20358,18 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
for (size_t i = 0; i < msgs.size(); i++) {
deserialize((const char *)msgs[i].data(), &dev_info_ptr[i]);
}
GGML_ASSERT(msgs.size() == n_world);
}else{
char * buffer = nullptr;
for(size_t i = 0; i < n_world; i++) {
size_t buffer_size = serialize(&dev_info_set[i], &buffer);
msgs.emplace_back(buffer, buffer_size);
free(buffer);
}
dev_info_ptr = dev_info_set;
}
GGML_ASSERT(ctx != nullptr && ctx->send_socket != nullptr);
GGML_ASSERT(msgs.size() == n_world);
// notify next rank
auto next_rank = (my_rank + 1) % n_world;
if(n_layer_window[next_rank] <= 0 && next_rank != 0){
try {
auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
ctx->send_socket->set(zmq::sockopt::linger, 3500);
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t& e) {
@ -20394,6 +20401,7 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
ctx->next_node_ip = next_ip;
try {
ctx->send_socket->connect(send_endp);
auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
zmq::send_multipart(*ctx->send_socket, msgs);
} catch (const zmq::error_t &e) {
LLAMA_LOG_INFO("Error binding/connecting recv socket to endpoint: %s", e.what());
@ -20477,7 +20485,7 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
void llama_update_context_with_rankworld(struct llama_context * ctx,
uint32_t rank,
uint32_t n_world) {
uint32_t n_world) {
if(ctx) {
ctx->cparams.rank = rank;
ctx->cparams.n_world = n_world;