From 8b61cb2fa4b30952bc303a587dc1d77bfc8af5dd Mon Sep 17 00:00:00 2001
From: DeEMO
Date: Fri, 16 May 2025 17:03:36 +0800
Subject: [PATCH] fix: adapt to the new topology

After llama_rebuild_topo() renumbers the nodes, the rank and world size
cached in cparams, mparams, and gpt_params (and the local copies in
main) become stale, so propagate the updated values once
llama_init_from_gpt_params() returns. Also record the rank that the
send socket actually connects to in cparams.original_next_rank, so that
llama_free_sockets() addresses the real ring neighbor instead of
recomputing (my_rank + 1) % n_world against the renumbered world.

Signed-off-by: DeEMO
---
 common/common.cpp      | 8 ++++++++
 examples/main/main.cpp | 7 +++++--
 src/llama.cpp          | 6 +++++-
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index a98337d3..35d285c6 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1731,6 +1731,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         n_gpu_layers[i] = n_gpu_layers_temp[i];
     }
     llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
+    cparams.rank = update_rank;
+    cparams.n_world = update_n_world;
+    mparams.rank = update_rank;
+    mparams.n_world = update_n_world;
+    params.rank = update_rank;
+    params.n_world = update_n_world;
+    my_rank = update_rank;
+    n_world = update_n_world;
 
     // update n_layer_window and n_gpu_layers
     std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 39d4b60c..04680373 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -143,8 +143,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const uint32_t n_world = params.n_world;
-    const uint32_t my_rank = params.rank;
+    uint32_t n_world = params.n_world;
+    uint32_t my_rank = params.rank;
     GGML_ASSERT(!(n_world == 1 && my_rank > 0));
 
     // check if --n-layer-window and --world is matched
@@ -200,6 +200,9 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
     llama_init_result llama_init = llama_init_from_gpt_params(params);
+    // rank and world size may have been updated during a topology rebuild
+    my_rank = params.rank;
+    n_world = params.n_world;
 
     model = llama_init.model;
     ctx = llama_init.context;
diff --git a/src/llama.cpp b/src/llama.cpp
index a0b91edd..f083e8e5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2585,6 +2585,7 @@ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams m
 struct llama_cparams {
     uint32_t n_world;
     uint32_t rank;
+    uint32_t original_next_rank; // original rank of the next node
     uint32_t n_layer_window[32];
     bool prefetch;
     bool force;
@@ -20399,6 +20400,7 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
         ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
         std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
         ctx->next_node_ip = next_ip;
+        ctx->cparams.original_next_rank = next_rank;
         try {
             ctx->send_socket->connect(send_endp);
             auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
@@ -20457,7 +20459,8 @@ int llama_recv_layer_setup(struct llama_context * ctx, uint32_t * n_layer_window
 void llama_free_sockets(struct llama_context * ctx, char ** msg) {
     const uint32_t n_world = ctx->cparams.n_world;
     const uint32_t my_rank = ctx->cparams.rank;
-    const uint32_t next_rank = (my_rank + 1) % n_world;
+    // the topology may have been rebuilt, so use the recorded next_rank
+    const uint32_t next_rank = ctx->cparams.original_next_rank;
 
     if (n_world == 1) {
         return;
@@ -20508,6 +20511,7 @@ struct llama_context * llama_new_context_with_model(
     ctx->cparams.n_world = params.n_world;
     ctx->cparams.rank = params.rank;
     ctx->cparams.force = params.force;
+    ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;
 
     return ctx;
 }
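
The standalone sketch below (not part of the patch; the map_rank_to_port()
body and the concrete rank/port numbers are made-up stand-ins) illustrates
why llama_free_sockets() must reuse the neighbor recorded at connect time
rather than recomputing it after the ranks have been renumbered.

// sketch.cpp -- illustrative only, assumes each rank listens on
// base_port + rank; the real map_rank_to_port() may differ.
#include <cstdint>
#include <cstdio>

static uint32_t map_rank_to_port(uint32_t rank, uint32_t base_port) {
    return base_port + rank; // stand-in port scheme
}

int main() {
    const uint32_t base_port = 9000;

    // at connect time, llama_rebuild_topo() picks this node's next hop
    // and dials its port; suppose that hop is the device holding rank 2
    const uint32_t original_next_rank = 2; // recorded in cparams

    // the rebuild then renumbers the world, e.g. to my_rank=0, n_world=3
    const uint32_t my_rank = 0, n_world = 3;

    // recomputing the neighbor from the new numbering targets port 9001,
    // but the PUSH socket is actually connected to port 9002
    printf("recomputed: port %u\n", map_rank_to_port((my_rank + 1) % n_world, base_port));
    printf("recorded:   port %u\n", map_rank_to_port(original_next_rank, base_port));
    return 0;
}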