diff --git a/common/common.cpp b/common/common.cpp
index 86de0c06..9c373a9a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -896,7 +896,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
 
     device_info dev_info;
     dev_info.rank = params.rank;
-    llama_profile_device(&dev_info, model, ml, params.cpuparams.n_threads);
+    llama_profile_device(&dev_info, model, ml, params.n_predict, params.cpuparams.n_threads);
 
     // create llama context
     struct llama_context_params cparams = llama_context_params_from_gpt_params(params);
diff --git a/include/llama.h b/include/llama.h
index e21b056a..886f696d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -415,6 +415,7 @@ extern "C" {
             struct device_info * dev_info,
             struct llama_model * model,
             struct llama_model_loader * ml,
+            int n_predict,
             int n_threads);
 
     LLAMA_API ggml_backend_buffer_type_t llama_dev_buffer_type(struct llama_model * model, int device);
diff --git a/src/llama.cpp b/src/llama.cpp
index ad6b41be..f5f9889f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3570,7 +3570,7 @@ static bool is_dtype_exist(struct model_params * n_params, enum ggml_type dtype) {
     }
 }
 
-void llama_profile_device(device_info * dev_info, struct llama_model * model, llama_model_loader * ml, int n_threads) {
+void llama_profile_device(device_info * dev_info, struct llama_model * model, llama_model_loader * ml, int n_predict, int n_threads) {
     dev_info->device_name = device_name();
     dev_info->cpu_props.cores = device_cpu_cores();
@@ -3584,7 +3584,7 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
     struct model_params * n_params = &dev_info->model_params;
     if (dev_info->rank == 0) {
         enum ggml_type inp_embd_dtype = GGML_TYPE_F32;
-        llama_model_n_flops(model, ml, n_flops, n_params, 1, 32, &inp_embd_dtype);
+        llama_model_n_flops(model, ml, n_flops, n_params, 1, n_predict, &inp_embd_dtype);
         n_flops->inp_embd_ms = device_inp_embd_delay(model, inp_embd_dtype, 1, n_threads);
     }
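
For context, a minimal sketch of the updated call path from a caller's side, assuming the llama.cpp / common.h headers touched above and an already-loaded model and loader; the helper name profile_example is hypothetical and only mirrors the call site changed in this diff:

// Sketch only: `profile_example` is a hypothetical wrapper, not part of this change.
// It shows params.n_predict flowing into llama_profile_device(), which in turn
// forwards it to llama_model_n_flops() in place of the previously hard-coded 32.
static void profile_example(struct llama_model * model,
                            struct llama_model_loader * ml,
                            const gpt_params & params) {
    device_info dev_info;
    dev_info.rank = params.rank;

    llama_profile_device(&dev_info, model, ml, params.n_predict, params.cpuparams.n_threads);
}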