mirror of https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-06 19:49:02 +00:00
fix w init error
parent 4948b1004c
commit 9e4ba4f06a
1 changed file with 26 additions and 6 deletions
@@ -832,7 +832,7 @@ std::string fs_get_cache_file(const std::string & filename) {
     return cache_directory + filename;
 }

-static void assign_device(
+static bool assign_device(
     uint32_t n_world,
     uint32_t my_rank,
     const device_info * dev_info_set,
@@ -849,7 +849,7 @@ static void assign_device(
     const uint32_t n_layer = llama_model_n_layers(model);
     if (n_world == 1) {
         n_layer_window[0] = n_layer;
-        return;
+        return true;
     }

     const device_info &master = dev_info_set[0];
@@ -958,6 +958,11 @@ static void assign_device(
         w[m] = std::round(mem_budget[m] / total_mem_budget * n_layer);
         n[m] = 0;
     }
+    // adjust w[m] to ensure L mod W = 0
+    int diff = n_layer - std::accumulate(w.begin(), w.end(), 0);
+    auto device = (diff > 0) ? std::max_element(mem_budget.begin(), mem_budget.end())
+                             : std::min_element(mem_budget.begin(), mem_budget.end());
+    w[std::distance(mem_budget.begin(), device)] += diff;

 #if defined(USE_HIGHS)
     // stores the actual read bandwidth (GB/s) for each device
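This hunk is the "w init error" of the commit title: each w[m] is seeded by rounding the device's proportional share of the n_layer layers, and independent rounding can leave the sum above or below n_layer, so the solver's L = k * W constraint is violated before the first iteration. The added lines push the residual onto the device with the largest memory budget (when layers are missing) or pull it from the one with the smallest (when there is an excess), so the initial windows sum exactly to n_layer and the constraint holds with k = 1. A minimal standalone sketch of the same adjustment follows; the main() harness and the example budgets are hypothetical, while the diff/device/w lines are taken from the commit:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const int n_layer = 32;                           // L: total model layers
    std::vector<double> mem_budget = {6.0, 5.0, 4.0}; // per-device budgets (GB), hypothetical
    double total_mem_budget = std::accumulate(mem_budget.begin(), mem_budget.end(), 0.0);

    // proportional split with per-device rounding, as in assign_device:
    // 12.8 -> 13, 10.67 -> 11, 8.53 -> 9, which sums to 33, not 32
    std::vector<int> w(mem_budget.size());
    for (size_t m = 0; m < w.size(); ++m) {
        w[m] = (int)std::round(mem_budget[m] / total_mem_budget * n_layer);
    }

    // the committed fix: hand the residual to the largest budget when
    // layers are missing, or take it from the smallest when in excess
    int diff = n_layer - std::accumulate(w.begin(), w.end(), 0);
    auto device = (diff > 0) ? std::max_element(mem_budget.begin(), mem_budget.end())
                             : std::min_element(mem_budget.begin(), mem_budget.end());
    w[std::distance(mem_budget.begin(), device)] += diff;

    printf("w = {%d, %d, %d}\n", w[0], w[1], w[2]); // w = {13, 11, 8}, sum = 32
    return 0;
}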
@@ -1066,7 +1071,13 @@ static void assign_device(
     while (true) {
         int W = std::accumulate(w.begin(), w.end(), 0);
         int cur_k = (int)n_layer / W;
-        GGML_ASSERT(W > 1 && (int)n_layer % W == 0 && "Constraint: L = k * W must hold\n");
+
+        if (W <= 1 || (int)n_layer % W != 0) {
+            LOG_INF("Constraint: L = k * W must hold, but W = %d, L = %d\n", W, n_layer);
+            fflush(stdout);
+            fflush(stderr);
+            return false;
+        }

         if (!assign_sets(cur_k)) break;
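The solver loop partitions the n_layer layers into k full passes over the combined window W, so it needs W > 1 and n_layer % W == 0. Previously a violation tripped GGML_ASSERT and aborted the whole process; the new branch logs the offending W and L and returns false so the caller can fail gracefully. A small sketch of the invariant; window_tiles_layers is a hypothetical helper, not part of prima.cpp:

#include <cstdio>

// the combined layer window W must tile the L layers in k full rounds
static bool window_tiles_layers(int n_layer, int W) {
    return W > 1 && n_layer % W == 0;
}

int main() {
    printf("%d\n", window_tiles_layers(32, 8)); // 1: L = 4 * 8 holds, k = 4
    printf("%d\n", window_tiles_layers(32, 5)); // 0: 32 = 6 * 5 + 2, two layers left over
    return 0;
}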
@@ -1380,6 +1391,8 @@ static void assign_device(
     // copy value from w and n to n_layer_window and n_gpu_layers, respectively
     std::copy(w.begin(), w.end(), n_layer_window);
     std::copy(n.begin(), n.end(), n_gpu_layers);
+
+    return true;
 }

 //
@@ -1465,7 +1478,13 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     uint32_t n_layer_window[32] = {0}, n_gpu_layers[32] = {0};
     if (my_rank == 0) {
         // automatically determine n_layer_window and n_gpu_layers
-        assign_device(n_world, my_rank, dev_info_set, n_layer_window, n_gpu_layers, model, cparams);
+        if (!assign_device(n_world, my_rank, dev_info_set, n_layer_window, n_gpu_layers, model, cparams)) {
+            LOG_ERR("%s: Invalid allocation by HiGHS solver\n", __func__);
+            llama_free(lctx);
+            llama_free_model(model);
+            return iparams;
+        }
+
         // synchronize the new n_layer_window and n_gpu_layers to other nodes
         llama_bcast_layer_setup(lctx, n_layer_window, n_gpu_layers);
     } else {
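With assign_device now returning bool, an infeasible allocation from the HiGHS solver surfaces on rank 0 as an ordinary initialization failure (log, free the context and model, return iparams), matching the other error paths in llama_init_from_gpt_params instead of killing every node with a failed assert.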
@@ -1494,6 +1513,8 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {

     if (!mparams.vocab_only && llm_load_tensors(ml, model, mparams) < 0) {
         LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
+        llama_free(lctx);
         llama_free_model(model);
         return iparams;
     }
@@ -1501,6 +1522,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {

     if (llama_context_setup_backend(model, cparams, lctx) == nullptr) {
         LOG_ERR("%s: failed to setup context with model '%s'\n", __func__, params.model.c_str());
+        llama_free(lctx);
         llama_free_model(model);
         return iparams;
     }
@@ -1513,7 +1535,6 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     if (cvec.n_embd == -1) {
         llama_free(lctx);
         llama_free_model(model);
-
         return iparams;
     }

@@ -1526,7 +1547,6 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (err) {
             llama_free(lctx);
             llama_free_model(model);
-
             return iparams;
         }
     }