Mirror of https://github.com/Lizonghang/prima.cpp.git, synced 2025-09-05 23:39:05 +00:00
fix: send and recv meta

commit 2039e3b0c1 (parent d6c8d322cd)

3 changed files with 27 additions and 4 deletions
@@ -1761,6 +1761,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     // update my rank and n_world
     uint32_t update_rank = 0, update_n_world = 1;
+    uint32_t worker_rank = 0, n_worker = 1;
     std::vector<uint32_t> n_layer_window_temp = {n_layer_window[0]}, n_gpu_layers_temp = {n_gpu_layers[0]};
 
     for (uint32_t i = 1; i < n_world; i++) {
@@ -1773,6 +1774,13 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         update_n_world++;
         n_layer_window_temp.push_back(n_layer_window[i]);
         n_gpu_layers_temp.push_back(n_gpu_layers[i]);
+
+        if (n_layer_window[i] > 0) {
+            if (i <= my_rank) {
+                worker_rank++;
+            }
+            n_worker++;
+        }
     }
 
     memset(n_layer_window, 0, n_world * sizeof(uint32_t));
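The new counters separate a node's position among all nodes from its position among nodes that actually hold model layers: entries of n_layer_window that are zero still contribute to n_world but not to n_worker. A self-contained sketch of the counting logic added above, with a made-up layer-window assignment for illustration:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // Hypothetical 4-node setup: node 2 has no layers assigned.
        std::vector<uint32_t> n_layer_window = {16, 16, 0, 16};
        uint32_t my_rank = 3;                  // illustrative rank of this node
        uint32_t worker_rank = 0, n_worker = 1;

        for (uint32_t i = 1; i < n_layer_window.size(); i++) {
            if (n_layer_window[i] > 0) {
                if (i <= my_rank) {
                    worker_rank++;             // layer-holding nodes before (or at) me
                }
                n_worker++;                    // total nodes that actually hold layers
            }
        }
        // With the values above: worker_rank == 2, n_worker == 3.
        printf("worker_rank=%u n_worker=%u\n", worker_rank, n_worker);
        return 0;
    }

As in the hunk above, rank 0 is counted as a worker up front (n_worker starts at 1, worker_rank at 0), and the loop only inspects ranks 1 and higher.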
@@ -1795,7 +1803,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     params.n_world = update_n_world;
     n_world = update_n_world;
 
-    llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
+    llama_update_context_with_rankworld(lctx, update_rank, update_n_world, worker_rank, n_worker);
 
     if(node_type == NodeType::NODE_TYPE_FORWARDER){
        //just foward
@@ -477,7 +477,9 @@ extern "C" {
     LLAMA_API void llama_update_context_with_rankworld(
                     struct llama_context * ctx,
                     uint32_t rank,
-                    uint32_t n_world);
+                    uint32_t n_world,
+                    uint32_t worker_rank,
+                    uint32_t n_worker);
 
     LLAMA_API struct llama_context * llama_new_context_with_model(
                     struct llama_model * model,
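The header change simply widens the existing entry point; callers now pass the worker-relative position and count alongside the rank and world size, as the updated call site in the first file does. A minimal, hypothetical call against the new declaration (all values illustrative):

    // lctx is assumed to be a valid llama_context* obtained elsewhere.
    uint32_t update_rank = 1, update_n_world = 4;  // position / size across all nodes
    uint32_t worker_rank = 1, n_worker = 3;        // position / size among layer-holding nodes
    llama_update_context_with_rankworld(lctx, update_rank, update_n_world, worker_rank, n_worker);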
@@ -2597,6 +2597,9 @@ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams m
 struct llama_cparams {
     uint32_t n_world;
     uint32_t rank;
+    NodeType node_type;
+    uint32_t n_worker;
+    uint32_t worker_rank;
     uint32_t original_next_rank; // original rank of the next node
     uint32_t n_layer_window[32];
     bool prefetch;
@@ -18213,6 +18216,9 @@ static int llama_decode_internal(
     const uint32_t n_world = cparams.n_world;
     const uint32_t my_rank = cparams.rank;
 
+    const uint32_t n_worker    = cparams.n_worker;
+    const uint32_t worker_rank = cparams.worker_rank;
+
     lctx.is_encoding = false;
     const uint32_t n_tokens_all = batch_all.n_tokens;
     if (my_rank != 0) {
@@ -18268,7 +18274,7 @@ static int llama_decode_internal(
 
     sync_meta meta;
     meta.n_ctx = cparams.n_ctx;
-    bool is_last_dev = (my_rank == n_world - 1);
+    bool is_last_dev = (worker_rank == n_worker - 1);
 
     if (my_rank != 0) {
         if (llama_recv_meta(*lctx.recv_socket, &meta) == -1) {
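This one-line change looks like the core of the fix: whether a node is the last device for meta send/recv is now decided among worker nodes only, so a trailing node with an empty layer window is no longer treated as the last device. A minimal sketch of how the old and new predicates diverge, using hypothetical values for a setup where the final node holds no layers:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical 4-node setup where the last node holds no layers.
        // World view:  ranks 0..3, n_world = 4
        // Worker view: only ranks 0, 1, 2 hold layers, n_worker = 3
        uint32_t my_rank = 2, n_world = 4;       // this node, among all nodes
        uint32_t worker_rank = 2, n_worker = 3;  // this node, among layer-holding nodes

        bool old_is_last_dev = (my_rank == n_world - 1);       // false: rank 3 "looks" last
        bool new_is_last_dev = (worker_rank == n_worker - 1);  // true: rank 2 is the last worker
        printf("old=%d new=%d\n", old_is_last_dev, new_is_last_dev);
        return 0;
    }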
@@ -20757,6 +20763,7 @@ int llama_rebuild_topo(llama_context * ctx,
     for(size_t i = 0; i < n_world; i++) {
         is_forwarder[i] = topo_helper[i].is_forwarder;
     }
+    ctx->cparams.node_type = *node_type;
 
     if (socket_to_close != nullptr) {
         socket_to_close->close();
@@ -20842,10 +20849,16 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
         }
     }
 
-void llama_update_context_with_rankworld(struct llama_context * ctx, uint32_t rank, uint32_t n_world) {
+void llama_update_context_with_rankworld(struct llama_context * ctx,
+                                         uint32_t rank,
+                                         uint32_t n_world,
+                                         uint32_t worker_rank,
+                                         uint32_t n_worker) {
     if (ctx) {
         ctx->cparams.rank = rank;
         ctx->cparams.n_world = n_world;
+        ctx->cparams.worker_rank = worker_rank;
+        ctx->cparams.n_worker = n_worker;
     }
 }
 