fix: send and recv meta

This commit is contained in:
DeEMO 2025-06-11 21:05:31 +08:00 committed by DeEMO
parent d6c8d322cd
commit 2039e3b0c1
3 changed files with 27 additions and 4 deletions

View file

@ -2597,6 +2597,9 @@ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams m
struct llama_cparams {
uint32_t n_world;
uint32_t rank;
NodeType node_type;
uint32_t n_worker;
uint32_t worker_rank;
uint32_t original_next_rank; // original rank of the next node
uint32_t n_layer_window[32];
bool prefetch;
@ -18213,6 +18216,9 @@ static int llama_decode_internal(
const uint32_t n_world = cparams.n_world;
const uint32_t my_rank = cparams.rank;
const uint32_t n_worker = cparams.n_worker;
const uint32_t worker_rank = cparams.worker_rank;
lctx.is_encoding = false;
const uint32_t n_tokens_all = batch_all.n_tokens;
if (my_rank != 0) {
@ -18268,7 +18274,7 @@ static int llama_decode_internal(
sync_meta meta;
meta.n_ctx = cparams.n_ctx;
bool is_last_dev = (my_rank == n_world - 1);
bool is_last_dev = (worker_rank == n_worker - 1);
if (my_rank != 0) {
if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
@ -20757,6 +20763,7 @@ int llama_rebuild_topo(llama_context * ctx,
for(size_t i = 0; i < n_world; i++) {
is_forwarder[i] = topo_helper[i].is_forwarder;
}
ctx->cparams.node_type = *node_type;
if (socket_to_close != nullptr) {
socket_to_close->close();
@ -20842,10 +20849,16 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
}
}
void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
void llama_update_context_with_rankworld(struct llama_context * ctx,
uint32_t rank,
uint32_t n_world,
uint32_t worker_rank,
uint32_t n_worker) {
if (ctx) {
ctx->cparams.rank = rank;
ctx->cparams.n_world = n_world;
ctx->cparams.worker_rank = worker_rank;
ctx->cparams.n_worker = n_worker;
}
}