synchronize device info

Lizonghang 2024-11-07 22:02:01 +04:00
parent ef7fdf70cc
commit 53cb3a6069
5 changed files with 408 additions and 73 deletions


@@ -833,10 +833,6 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
model = llama_load_model_from_file(params.model.c_str(), mparams);
}
// profile devices and determine the best setup
device_info dev_info;
llama_profile_device(&dev_info, model, params.model.c_str());
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
return iparams;
@@ -866,17 +862,35 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
}
}
auto cparams = llama_context_params_from_gpt_params(params);
// get device profile
device_info dev_info;
dev_info.rank = params.rank;
llama_profile_device(&dev_info, model, params.model.c_str());
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
// create llama context
struct llama_context_params cparams = llama_context_params_from_gpt_params(params);
llama_context * lctx = llama_new_context_with_model(model, cparams);
// initialize sockets
llama_init_sockets(lctx, cparams.n_world, cparams.rank);
// synchronize device profile to the master node
struct device_info * dev_info_set = nullptr;
if (params.rank == 0) {
dev_info_set = (struct device_info *)malloc(cparams.n_world * sizeof(struct device_info));
dev_info_set[0] = dev_info;
llama_collect_device_info(dev_info_set, lctx);
device_print_props(dev_info_set, cparams.n_world);
} else {
llama_send_device_info(&dev_info, lctx);
}
if (llama_context_setup_backend(lctx) == nullptr) {
LOG_ERR("%s: failed to setup context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return iparams;
}
llama_init_sockets(lctx, cparams.n_world, cparams.rank);
if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
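
The hunk above has rank 0 allocate a dev_info_set array, keep its own profile in slot 0, and gather the remaining entries via llama_collect_device_info, while every other rank sends its profile with llama_send_device_info. The standalone C++ sketch below illustrates that collect/send pattern with a mocked transport; the struct fields and helper names used here (send_device_info, collect_device_info, flops_gflops, mem_free_bytes) are illustrative assumptions, not the repository's actual API.

#include <cstdint>
#include <cstdio>
#include <vector>

// trimmed stand-in for the real device_info struct (fields are assumptions)
struct device_info {
    uint32_t rank;
    float    flops_gflops;
    uint64_t mem_free_bytes;
};

// mocked transport standing in for the sockets set up by llama_init_sockets:
// worker ranks "send", rank 0 "collects"
static std::vector<device_info> g_wire;

void send_device_info(const device_info & info) { g_wire.push_back(info); }

void collect_device_info(device_info * out, uint32_t n_world) {
    // workers push in rank order 1..n_world-1, so g_wire[r-1] belongs to rank r
    for (uint32_t r = 1; r < n_world; ++r) out[r] = g_wire[r - 1];
}

int main() {
    const uint32_t n_world = 3;

    // worker ranks profile themselves and send the result to the master
    for (uint32_t r = 1; r < n_world; ++r) {
        device_info info = { r, 100.0f * r, 8ull << 30 };
        send_device_info(info);
    }

    // rank 0 keeps its own profile in slot 0 (mirroring dev_info_set[0] = dev_info)
    // and gathers the remaining entries
    std::vector<device_info> dev_info_set(n_world);
    dev_info_set[0] = { 0, 250.0f, 16ull << 30 };
    collect_device_info(dev_info_set.data(), n_world);

    // print the gathered profiles, roughly what device_print_props would do
    for (const auto & d : dev_info_set) {
        printf("rank %u: %.1f GFLOPS, %llu bytes free\n",
               (unsigned) d.rank, d.flops_gflops,
               (unsigned long long) d.mem_free_bytes);
    }
    return 0;
}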