mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 23:49:04 +00:00
fix: adapt the new topo
Signed-off-by: DeEMO <yzzxrx@gmail.com>
This commit is contained in:
parent
df16b1876f
commit
8b61cb2fa4
3 changed files with 18 additions and 3 deletions
|
@ -1731,6 +1731,14 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
|||
n_gpu_layers[i] = n_gpu_layers_temp[i];
|
||||
}
|
||||
llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
|
||||
cparams.rank = update_rank;
|
||||
cparams.n_world = update_n_world;
|
||||
mparams.rank = update_rank;
|
||||
mparams.n_world = update_n_world;
|
||||
params.rank = update_rank;
|
||||
params.n_world = update_n_world;
|
||||
my_rank = update_rank;
|
||||
n_world = update_n_world;
|
||||
|
||||
// update n_layer_window and n_gpu_layers
|
||||
std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
|
||||
|
|
|
@ -143,8 +143,8 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
const uint32_t n_world = params.n_world;
|
||||
const uint32_t my_rank = params.rank;
|
||||
uint32_t n_world = params.n_world;
|
||||
uint32_t my_rank = params.rank;
|
||||
GGML_ASSERT(!(n_world == 1 && my_rank > 0));
|
||||
|
||||
// check that --n-layer-window and --world match
|
||||
|
@ -200,6 +200,9 @@ int main(int argc, char ** argv) {
|
|||
// load the model and apply lora adapter, if any
|
||||
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
llama_init_result llama_init = llama_init_from_gpt_params(params);
|
||||
// llama_init_from_gpt_params may have changed params.rank and params.n_world; re-read them
|
||||
my_rank = params.rank;
|
||||
n_world = params.n_world;
|
||||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
|
|
|
@ -2585,6 +2585,7 @@ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams m
|
|||
struct llama_cparams {
|
||||
uint32_t n_world;
|
||||
uint32_t rank;
|
||||
uint32_t original_next_rank; // original rank of the next node
|
||||
uint32_t n_layer_window[32];
|
||||
bool prefetch;
|
||||
bool force;
|
||||
|
@ -20399,6 +20400,7 @@ LLAMA_API int llama_rebuild_topo(llama_context *ctx,
|
|||
ctx->send_socket = new zmq::socket_t(*ctx->sock_context, zmq::socket_type::push);
|
||||
std::string send_endp = "tcp://" + next_ip + ":" + std::to_string(map_rank_to_port(next_rank, ctx->data_port));
|
||||
ctx->next_node_ip = next_ip;
|
||||
ctx->cparams.original_next_rank = next_rank;
|
||||
try {
|
||||
ctx->send_socket->connect(send_endp);
|
||||
auto msgs = dev_infos_to_messages(dev_info_ptr, n_world);
|
||||
|
@ -20457,7 +20459,8 @@ int llama_recv_layer_setup(struct llama_context * ctx, uint32_t * n_layer_window
|
|||
void llama_free_sockets(struct llama_context * ctx, char ** msg) {
|
||||
const uint32_t n_world = ctx->cparams.n_world;
|
||||
const uint32_t my_rank = ctx->cparams.rank;
|
||||
const uint32_t next_rank = (my_rank + 1) % n_world;
|
||||
// use the stored original_next_rank instead of recomputing (my_rank + 1) % n_world, to adapt to the new topology
|
||||
const uint32_t next_rank = ctx->cparams.original_next_rank;
|
||||
|
||||
if (n_world == 1) {
|
||||
return;
|
||||
|
@ -20508,6 +20511,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->cparams.n_world = params.n_world;
|
||||
ctx->cparams.rank = params.rank;
|
||||
ctx->cparams.force = params.force;
|
||||
ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue