allow overriding the devices directly

This commit is contained in:
Concedo 2026-01-17 19:08:06 +08:00
parent 21e6ccb8cb
commit 62bea5ef4f
7 changed files with 92 additions and 5 deletions

View file

@ -80,6 +80,7 @@ struct load_model_inputs
const int smartcacheslots = 0;
const bool pipelineparallel = false;
const float lora_multiplier = 1.0f;
const char * devices_override = nullptr;
const bool quiet = false;
const int debugmode = 0;
};
@ -192,6 +193,7 @@ struct sd_load_model_inputs
const char * photomaker_filename = nullptr;
const int img_hard_limit = 0;
const int img_soft_limit = 0;
const char * devices_override = nullptr;
const bool quiet = false;
const int debugmode = 0;
};
@ -241,6 +243,7 @@ struct whisper_load_model_inputs
const int clblast_info = 0;
const int kcpp_main_gpu = 0;
const char * vulkan_info = nullptr;
const char * devices_override = nullptr;
const bool quiet = false;
const int debugmode = 0;
};
@ -269,6 +272,7 @@ struct tts_load_model_inputs
const int gpulayers = 0;
const bool flash_attention = false;
const int ttsmaxlen = 4096;
const char * devices_override = nullptr;
const bool quiet = false;
const int debugmode = 0;
};
@ -299,6 +303,7 @@ struct embeddings_load_model_inputs
const bool flash_attention = false;
const bool use_mmap = false;
const int embeddingsmaxctx = 0;
const char * devices_override = nullptr;
const bool quiet = false;
const int debugmode = 0;
};

View file

@ -643,6 +643,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
draft_model_params.use_mlock = base_model_params.use_mlock;
draft_model_params.use_direct_io = base_model_params.use_direct_io;
draft_model_params.n_gpu_layers = draft_gpulayers; //layers offload the speculative model.
draft_model_params.devices = base_model_params.devices;
draft_ctx_params.n_ctx = base_ctx_params.n_ctx;
draft_ctx_params.offload_kqv = base_ctx_params.offload_kqv;
draft_model_params.main_gpu = base_model_params.main_gpu;
@ -2375,6 +2376,19 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
model_params.use_direct_io = false; //no direct io for now until stable
model_params.n_gpu_layers = inputs.gpulayers;
//set device overrides if needed
std::vector<ggml_backend_dev_t> devices_override;
std::string dev_override_str = inputs.devices_override;
if(dev_override_str!="")
{
devices_override = kcpp_parse_device_list(dev_override_str);
if(devices_override.size()>0)
{
printf("\nOverriding with %d devices...\n",devices_override.size()-1);
model_params.devices = devices_override.data();
}
}
#if defined(GGML_USE_CLBLAST)
if(file_format==FileFormat::GGUF_GENERIC && model_params.n_gpu_layers>0)
{
@ -2483,11 +2497,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
//handle override tensor
std::string tensoroverrides = inputs.override_tensors;
// if(file_format_meta.model_architecture==GGUFArch::ARCH_GEMMA3N)
// {
// std::string forced = "per_layer_token_embd.weight=CPU"; //this tensor on gpu is problematic on unsloth q4_0
// tensoroverrides = (tensoroverrides=="" ? forced: (forced+","+tensoroverrides));
// }
if(ggml_backend_dev_count()>1 && inputs.moecpu>0)
{
std::string toadd = "";

View file

@ -226,6 +226,7 @@ class load_model_inputs(ctypes.Structure):
("smartcacheslots", ctypes.c_int),
("pipelineparallel", ctypes.c_bool),
("lora_multiplier", ctypes.c_float),
("devices_override", ctypes.c_char_p),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
@ -317,6 +318,7 @@ class sd_load_model_inputs(ctypes.Structure):
("photomaker_filename", ctypes.c_char_p),
("img_hard_limit", ctypes.c_int),
("img_soft_limit", ctypes.c_int),
("devices_override", ctypes.c_char_p),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
@ -361,6 +363,7 @@ class whisper_load_model_inputs(ctypes.Structure):
("clblast_info", ctypes.c_int),
("kcpp_main_gpu", ctypes.c_int),
("vulkan_info", ctypes.c_char_p),
("devices_override", ctypes.c_char_p),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
@ -385,6 +388,7 @@ class tts_load_model_inputs(ctypes.Structure):
("gpulayers", ctypes.c_int),
("flash_attention", ctypes.c_bool),
("ttsmaxlen", ctypes.c_int),
("devices_override", ctypes.c_char_p),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
@ -411,6 +415,7 @@ class embeddings_load_model_inputs(ctypes.Structure):
("flash_attention", ctypes.c_bool),
("use_mmap", ctypes.c_bool),
("embeddingsmaxctx", ctypes.c_int),
("devices_override", ctypes.c_char_p),
("quiet", ctypes.c_bool),
("debugmode", ctypes.c_int)]
@ -833,6 +838,7 @@ def set_backend_props(inputs):
inputs.vulkan_info = "".encode("UTF-8")
# set universal flags
inputs.devices_override = (args.device if args.device else "").encode("UTF-8")
inputs.quiet = args.quiet
inputs.debugmode = args.debugmode
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
@ -8832,6 +8838,7 @@ if __name__ == '__main__':
advparser.add_argument("--gendefaults", metavar=('{"parameter":"value",...}'), help="Sets extra default parameters for some fields in API requests, as a JSON string.", default="")
advparser.add_argument("--gendefaultsoverwrite", help="Allow the gendefaults parameters to overwrite the original value in API payloads.", action='store_true')
advparser.add_argument("--mcpfile", metavar=('[mcp json file]'), help="Specify path to mcp.json which contains the Cladue Desktop compatible MCP server config.", default="")
advparser.add_argument("--device", "-dev", metavar=('<dev1,dev2,..>'), help="Set llama.cpp compatible device selection override. Comma separated. Overrides normal device choices.", default="")
hordeparsergroup = parser.add_argument_group('Horde Worker Commands')
hordeparsergroup.add_argument("--hordemodelname", metavar=('[name]'), help="Sets your AI Horde display model name.", default="")

View file

@ -102,6 +102,12 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs)
}
}
const char* existingenv = getenv("GGML_VK_VISIBLE_DEVICES");
std::vector<ggml_backend_dev_t> devices_override;
std::string dev_override_str = inputs.devices_override;
if(dev_override_str!="")
{
devices_override = kcpp_parse_device_list(dev_override_str);
}
if(!existingenv && vulkan_info_str!="")
{
ttsvulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str;
@ -125,6 +131,12 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs)
model_params.main_gpu = kcpp_parseinfo_maindevice;
model_params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER;
if(devices_override.size()>0)
{
printf("\nOverriding with %d devices...\n",devices_override.size()-1);
model_params.devices = devices_override.data();
}
llama_model * embeddingsmodel = llama_model_load_from_file(modelfile.c_str(), model_params);
const int n_ctx_train = llama_model_n_ctx_train(embeddingsmodel);

View file

@ -637,6 +637,12 @@ bool ttstype_load_model(const tts_load_model_inputs inputs)
}
}
const char* existingenv = getenv("GGML_VK_VISIBLE_DEVICES");
std::vector<ggml_backend_dev_t> devices_override;
std::string dev_override_str = inputs.devices_override;
if(dev_override_str!="")
{
devices_override = kcpp_parse_device_list(dev_override_str);
}
if(!existingenv && vulkan_info_str!="")
{
ttsvulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str;
@ -699,6 +705,12 @@ bool ttstype_load_model(const tts_load_model_inputs inputs)
tts_ctx_params.flash_attn_type = (inputs.flash_attention?LLAMA_FLASH_ATTN_TYPE_ENABLED:LLAMA_FLASH_ATTN_TYPE_DISABLED);
tts_ctx_params.kv_unified = true;
if(devices_override.size()>0)
{
printf("\nOverriding with %d devices...\n",devices_override.size()-1);
tts_model_params.devices = devices_override.data();
}
llama_model * ttcmodel = llama_model_load_from_file(modelfile_ttc.c_str(), tts_model_params);
ttc_ctx = llama_init_from_model(ttcmodel, tts_ctx_params);

View file

@ -709,4 +709,43 @@ bool kcpp_decode_audio_from_buf(const unsigned char * buf_in, size_t len, int ta
ma_decoder_uninit(&decoder);
return true;
}
static std::vector<std::string> kcpp_string_split(const std::string & input, char separator)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
return parts;
}
//for llama.cpp style device overrides e.g. --device Vulkan0,Vulkan1
std::vector<ggml_backend_dev_t> kcpp_parse_device_list(const std::string & value) {
std::vector<ggml_backend_dev_t> devices;
auto dev_names = kcpp_string_split(value, ',');
if (dev_names.empty()) {
printf("\nkcpp_parse_device_list error: no devices specified\n");
return std::vector<ggml_backend_dev_t>();
}
if (dev_names.size() == 1 && dev_names[0] == "none") {
return std::vector<ggml_backend_dev_t>();
} else {
for (const auto & device : dev_names) {
auto * dev = ggml_backend_dev_by_name(device.c_str());
if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
printf("\nkcpp_parse_device_list error: invalid device: %s\n",device.c_str());
return std::vector<ggml_backend_dev_t>();
}
devices.push_back(dev);
}
devices.push_back(nullptr);
}
return devices;
}

View file

@ -69,6 +69,8 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<
std::vector<std::string> split_string(const std::string& input, const std::string& separator);
bool kcpp_decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono);
std::vector<ggml_backend_dev_t> kcpp_parse_device_list(const std::string & value);
//duplcated and modified from llava_embd_batch
struct kcpp_embd_batch {
std::vector<llama_pos> pos;