added override tensor

This commit is contained in:
Concedo 2025-04-20 20:56:17 +08:00
parent 17360a3b32
commit 2ed6850c0b
3 changed files with 59 additions and 3 deletions

View file

@ -2172,6 +2172,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
std::vector<llama_model_kv_override> kvos; //ensure it keeps in scope until model is created
std::vector<llama_model_tensor_buft_override> tenos; //ensure it keeps in scope until model is created
std::vector<std::string> temp_tensor_names; //store temp tensor names to have mem references.
if(inputs.moe_experts>0)
{
printf("\nOverriding number of experts to %d\n",inputs.moe_experts);
@ -2195,13 +2197,58 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
{
printf("\nAttempting to apply KV override: %s...\n",override_kv.c_str());
bool kvo_ok = string_parse_kv_override(override_kv.c_str(),kvos);
LLAMA_LOG_INFO("\nKV override result: %s\n",(kvo_ok?"success":"failed"));
LLAMA_LOG_INFO("\nKV override parse: %s\n",(kvo_ok?"success":"failed"));
fflush(stdout);
}
if(kvos.size()>0)
{
kvos.emplace_back();
kvos.back().key[0] = 0;
model_params.kv_overrides = kvos.data();
}
//handle override tensor
std::string tensoroverrides = inputs.override_tensors;
if(tensoroverrides!="" && ggml_backend_dev_count()>1)
{
printf("Handling Override Tensors for backends: ");
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
std::string name = ggml_backend_buft_name(buft);
printf("%s ", name.c_str());
buft_list[name] = buft;
}
}
printf("\n\n");
for (const auto & override : string_split<std::string>(tensoroverrides, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
printf("\nInvalid Override Tensor: %s\n",override.c_str());
continue;
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("\nUnknown Buffer Type: %s\n",buffer_type.c_str());
continue;
}
llama_model_tensor_buft_override nto;
temp_tensor_names.push_back(tensor_name);
nto.pattern = temp_tensor_names[temp_tensor_names.size()-1].c_str();
nto.buft = buft_list.at(buffer_type);
tenos.push_back(nto);
printf("Override Tensor: %s to %s\n",tensor_name.c_str(),buffer_type.c_str());
}
}
if(tenos.size()>0)
{
tenos.push_back({nullptr, nullptr});
model_params.tensor_buft_overrides = tenos.data();
}
llama_model * llamamodel = llama_model_load_from_file(kcpp_data->model_filename.c_str(), model_params);
if(overwriteRope)