llama_profile_device: add arg n_predict

This commit is contained in:
Lizonghang 2024-12-06 16:37:25 +04:00
parent a46d56cc60
commit cd823546dd
3 changed files with 4 additions and 3 deletions

View file

@ -3570,7 +3570,7 @@ static bool is_dtype_exist(struct model_params * n_params, enum ggml_type dtype)
}
}
void llama_profile_device(device_info * dev_info, struct llama_model * model, llama_model_loader * ml, int n_threads) {
void llama_profile_device(device_info * dev_info, struct llama_model * model, llama_model_loader * ml, int n_predict, int n_threads) {
dev_info->device_name = device_name();
dev_info->cpu_props.cores = device_cpu_cores();
@ -3584,7 +3584,7 @@ void llama_profile_device(device_info * dev_info, struct llama_model * model, ll
struct model_params * n_params = &dev_info->model_params;
if (dev_info->rank == 0) {
enum ggml_type inp_embd_dtype = GGML_TYPE_F32;
llama_model_n_flops(model, ml, n_flops, n_params, 1, 32, &inp_embd_dtype);
llama_model_n_flops(model, ml, n_flops, n_params, 1, n_predict, &inp_embd_dtype);
n_flops->inp_embd_ms = device_inp_embd_delay(model, inp_embd_dtype, 1, n_threads);
}