mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 00:54:41 +00:00
added override tensor
This commit is contained in:
parent
17360a3b32
commit
2ed6850c0b
3 changed files with 59 additions and 3 deletions
|
@ -2172,6 +2172,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
|
||||
std::vector<llama_model_kv_override> kvos; //ensure it keeps in scope until model is created
|
||||
std::vector<llama_model_tensor_buft_override> tenos; //ensure it keeps in scope until model is created
|
||||
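std::vector<std::string> temp_tensor_names; //owns the tensor name strings whose c_str() pointers are stored in tenos, so it must outlive model creation. NOTE(review): growing this vector can relocate short strings and invalidate earlier pointers - consider reserving up front or using a node-based container.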
|
||||
if(inputs.moe_experts>0)
|
||||
{
|
||||
printf("\nOverriding number of experts to %d\n",inputs.moe_experts);
|
||||
|
@ -2195,13 +2197,58 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
{
|
||||
printf("\nAttempting to apply KV override: %s...\n",override_kv.c_str());
|
||||
bool kvo_ok = string_parse_kv_override(override_kv.c_str(),kvos);
|
||||
LLAMA_LOG_INFO("\nKV override result: %s\n",(kvo_ok?"success":"failed"));
|
||||
LLAMA_LOG_INFO("\nKV override parse: %s\n",(kvo_ok?"success":"failed"));
|
||||
fflush(stdout);
|
||||
}
|
||||
if(kvos.size()>0)
|
||||
{
|
||||
kvos.emplace_back();
|
||||
kvos.back().key[0] = 0;
|
||||
model_params.kv_overrides = kvos.data();
|
||||
}
|
||||
//handle override tensor
|
||||
std::string tensoroverrides = inputs.override_tensors;
|
||||
if(tensoroverrides!="" && ggml_backend_dev_count()>1)
|
||||
{
|
||||
printf("Handling Override Tensors for backends: ");
|
||||
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
auto * buft = ggml_backend_dev_buffer_type(dev);
|
||||
if (buft) {
|
||||
std::string name = ggml_backend_buft_name(buft);
|
||||
printf("%s ", name.c_str());
|
||||
buft_list[name] = buft;
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
for (const auto & override : string_split<std::string>(tensoroverrides, ',')) {
|
||||
std::string::size_type pos = override.find('=');
|
||||
if (pos == std::string::npos) {
|
||||
printf("\nInvalid Override Tensor: %s\n",override.c_str());
|
||||
continue;
|
||||
}
|
||||
std::string tensor_name = override.substr(0, pos);
|
||||
std::string buffer_type = override.substr(pos + 1);
|
||||
|
||||
if (buft_list.find(buffer_type) == buft_list.end()) {
|
||||
printf("\nUnknown Buffer Type: %s\n",buffer_type.c_str());
|
||||
continue;
|
||||
}
|
||||
llama_model_tensor_buft_override nto;
|
||||
temp_tensor_names.push_back(tensor_name);
|
||||
nto.pattern = temp_tensor_names[temp_tensor_names.size()-1].c_str();
|
||||
nto.buft = buft_list.at(buffer_type);
|
||||
tenos.push_back(nto);
|
||||
printf("Override Tensor: %s to %s\n",tensor_name.c_str(),buffer_type.c_str());
|
||||
}
|
||||
}
|
||||
if(tenos.size()>0)
|
||||
{
|
||||
tenos.push_back({nullptr, nullptr});
|
||||
model_params.tensor_buft_overrides = tenos.data();
|
||||
}
|
||||
|
||||
llama_model * llamamodel = llama_model_load_from_file(kcpp_data->model_filename.c_str(), model_params);
|
||||
|
||||
if(overwriteRope)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue