From 2039e3b0c1bc9eaa957256b85352e01daf7aa15f Mon Sep 17 00:00:00 2001
From: DeEMO
Date: Wed, 11 Jun 2025 21:05:31 +0800
Subject: [PATCH] fix: send and recv meta

---
 common/common.cpp | 10 +++++++++-
 include/llama.h   |  4 +++-
 src/llama.cpp     | 17 +++++++++++++++--
 3 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index d40647ec..a21146b7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1761,6 +1761,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
 
     // update my rank and n_world
     uint32_t update_rank = 0, update_n_world = 1;
+    uint32_t worker_rank = 0, n_worker = 1;
     std::vector<uint32_t> n_layer_window_temp = {n_layer_window[0]}, n_gpu_layers_temp = {n_gpu_layers[0]};
     for (uint32_t i = 1; i < n_world; i++) {
@@ -1773,6 +1774,13 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         update_n_world++;
         n_layer_window_temp.push_back(n_layer_window[i]);
         n_gpu_layers_temp.push_back(n_gpu_layers[i]);
+
+        if (n_layer_window[i] > 0) {
+            if (i <= my_rank) {
+                worker_rank++;
+            }
+            n_worker++;
+        }
     }
 
     memset(n_layer_window, 0, n_world * sizeof(uint32_t));
@@ -1795,7 +1803,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     params.n_world = update_n_world;
     n_world = update_n_world;
 
-    llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
+    llama_update_context_with_rankworld(lctx, update_rank, update_n_world, worker_rank, n_worker);
 
     if(node_type == NodeType::NODE_TYPE_FORWARDER){
         //just foward
diff --git a/include/llama.h b/include/llama.h
index 21f77288..7b63e96f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -477,7 +477,9 @@ extern "C" {
     LLAMA_API void llama_update_context_with_rankworld(
                              struct llama_context * ctx,
                              uint32_t rank,
-                             uint32_t n_world);
+                             uint32_t n_world,
+                             uint32_t worker_rank,
+                             uint32_t n_worker);
 
     LLAMA_API struct llama_context * llama_new_context_with_model(
                              struct llama_model * model,
diff --git a/src/llama.cpp b/src/llama.cpp
index 281c7360..dd0540cf 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2597,6 +2597,9 @@ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams m
 struct llama_cparams {
     uint32_t n_world;
     uint32_t rank;
+    NodeType node_type;
+    uint32_t n_worker;
+    uint32_t worker_rank;
     uint32_t original_next_rank; // original rank of the next node
     uint32_t n_layer_window[32];
     bool prefetch;
@@ -18213,6 +18216,9 @@ static int llama_decode_internal(
     const uint32_t n_world = cparams.n_world;
     const uint32_t my_rank = cparams.rank;
+    const uint32_t n_worker = cparams.n_worker;
+    const uint32_t worker_rank = cparams.worker_rank;
+
     lctx.is_encoding = false;
     const uint32_t n_tokens_all = batch_all.n_tokens;
 
     if (my_rank != 0) {
@@ -18268,7 +18274,7 @@ static int llama_decode_internal(
     sync_meta meta;
     meta.n_ctx = cparams.n_ctx;
 
-    bool is_last_dev = (my_rank == n_world - 1);
+    bool is_last_dev = (worker_rank == n_worker - 1);
 
     if (my_rank != 0) {
         if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
@@ -20757,6 +20763,7 @@ int llama_rebuild_topo(llama_context * ctx,
     for(size_t i = 0; i < n_world; i++) {
         is_forwarder[i] = topo_helper[i].is_forwarder;
     }
+    ctx->cparams.node_type = *node_type;
 
     if (socket_to_close != nullptr) {
         socket_to_close->close();
@@ -20842,10 +20849,16 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
     }
 }
 
-void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
+void llama_update_context_with_rankworld(struct llama_context * ctx,
+                                         uint32_t rank,
+                                         uint32_t n_world,
+                                         uint32_t worker_rank,
+                                         uint32_t n_worker) {
     if (ctx) {
         ctx->cparams.rank = rank;
         ctx->cparams.n_world = n_world;
+        ctx->cparams.worker_rank = worker_rank;
+        ctx->cparams.n_worker = n_worker;
     }
 }
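
Note (not part of the patch): the change separates the full pipeline ring (rank / n_world, which can include pure forwarder nodes that hold no layers) from the set of nodes that actually compute layers (worker_rank / n_worker). The standalone sketch below is illustrative only; it uses hypothetical layer-window values to mirror the counting loop added in common.cpp and to show why the "last device" test in llama_decode_internal compares worker_rank against n_worker rather than my_rank against n_world.

// Illustrative sketch only (not from the patch): recompute worker_rank/n_worker
// from a made-up layer-window assignment, the same way the new loop in
// llama_init_from_gpt_params counts compute nodes.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical 4-node ring: node 3 only forwards (0 layers assigned).
    std::vector<uint32_t> n_layer_window = {16, 8, 8, 0};
    uint32_t my_rank = 2;                  // last node that actually computes layers
    uint32_t n_world = (uint32_t) n_layer_window.size();

    // Rank 0 is assumed to hold layers, matching the patch's initialization.
    uint32_t worker_rank = 0, n_worker = 1;
    for (uint32_t i = 1; i < n_world; i++) {
        if (n_layer_window[i] > 0) {
            if (i <= my_rank) {
                worker_rank++;
            }
            n_worker++;
        }
    }

    // Old check: my_rank == n_world - 1      -> 2 == 3, false (not treated as last).
    // New check: worker_rank == n_worker - 1 -> 2 == 2, true (correctly last).
    std::printf("rank %u/%u, worker %u/%u, is_last_dev=%d\n",
                my_rank, n_world, worker_rank, n_worker,
                (int) (worker_rank == n_worker - 1));
    return 0;
}

With a forwarder at the tail of the ring, the old rank-based test never marks the final compute node as the last device, which appears to be the situation this patch addresses.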