mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-07 01:29:02 +00:00
update rank and n_world
Signed-off-by: DeEMO <yzzxrx@gmail.com>
This commit is contained in:
parent
fdd6694633
commit
cc46aa9828
3 changed files with 38 additions and 1 deletions
|
@ -1702,6 +1702,30 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//update rank and n_world for consistency
|
||||||
|
uint32_t update_rank = 0;
|
||||||
|
uint32_t update_n_world = 1;
|
||||||
|
std::vector<uint32_t> n_layer_window_temp = {n_layer_window[0]};
|
||||||
|
std::vector<uint32_t> n_gpu_layers_temp = {n_gpu_layers[0]};
|
||||||
|
for(auto i=1; i<n_world; i++) {
|
||||||
|
if(n_layer_window[i] <= 0 ){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if(i <= my_rank){
|
||||||
|
update_rank++;
|
||||||
|
}
|
||||||
|
update_n_world++;
|
||||||
|
n_layer_window_temp.push_back(n_layer_window[i]);
|
||||||
|
n_gpu_layers_temp.push_back(n_gpu_layers[i]);
|
||||||
|
}
|
||||||
|
memset(n_layer_window, 0, n_world * sizeof(uint32_t));
|
||||||
|
memset(n_gpu_layers, 0, n_world * sizeof(uint32_t));
|
||||||
|
for (auto i=0; i<update_n_world; i++) {
|
||||||
|
n_layer_window[i] = n_layer_window_temp[i];
|
||||||
|
n_gpu_layers[i] = n_gpu_layers_temp[i];
|
||||||
|
}
|
||||||
|
llama_update_context_with_rankworld(lctx, update_rank, update_n_world);
|
||||||
|
|
||||||
// update n_layer_window and n_gpu_layers
|
// update n_layer_window and n_gpu_layers
|
||||||
std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
|
std::copy(std::begin(n_layer_window), std::end(n_layer_window), params.n_layer_window);
|
||||||
std::copy(std::begin(n_layer_window), std::end(n_layer_window), cparams.n_layer_window);
|
std::copy(std::begin(n_layer_window), std::end(n_layer_window), cparams.n_layer_window);
|
||||||
|
|
|
@ -462,7 +462,11 @@ extern "C" {
|
||||||
struct llama_model_loader * ml,
|
struct llama_model_loader * ml,
|
||||||
struct llama_model * model,
|
struct llama_model * model,
|
||||||
struct llama_model_params params);
|
struct llama_model_params params);
|
||||||
|
|
||||||
|
LLAMA_API void llama_update_context_with_rankworld(struct llama_context * ctx,
|
||||||
|
uint32_t rank,
|
||||||
|
uint32_t n_world);
|
||||||
|
|
||||||
LLAMA_API struct llama_context * llama_new_context_with_model(
|
LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||||
struct llama_model * model,
|
struct llama_model * model,
|
||||||
struct llama_context_params params);
|
struct llama_context_params params);
|
||||||
|
|
|
@ -20475,6 +20475,15 @@ void llama_free_sockets(struct llama_context * ctx, char ** msg) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_update_context_with_rankworld(struct llama_context * ctx,
|
||||||
|
uint32_t rank,
|
||||||
|
uint32_t n_world) {
|
||||||
|
if(ctx) {
|
||||||
|
ctx->cparams.rank = rank;
|
||||||
|
ctx->cparams.n_world = n_world;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_context * llama_new_context_with_model(
|
struct llama_context * llama_new_context_with_model(
|
||||||
struct llama_model * model,
|
struct llama_model * model,
|
||||||
struct llama_context_params params) {
|
struct llama_context_params params) {
|
||||||
|
|
Loading…
Add table
Reference in a new issue