From 62bea5ef4ff08c0a118731fac3f084e7afcab316 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sat, 17 Jan 2026 19:08:06 +0800 Subject: [PATCH] allow overriding the devices directly --- expose.h | 5 ++++ gpttype_adapter.cpp | 20 ++++++++++++---- koboldcpp.py | 7 ++++++ otherarch/embeddings_adapter.cpp | 12 ++++++++++ otherarch/tts_adapter.cpp | 12 ++++++++++ otherarch/utils.cpp | 39 ++++++++++++++++++++++++++++++++ otherarch/utils.h | 2 ++ 7 files changed, 92 insertions(+), 5 deletions(-) diff --git a/expose.h b/expose.h index 349da5856..27bc391cd 100644 --- a/expose.h +++ b/expose.h @@ -80,6 +80,7 @@ struct load_model_inputs const int smartcacheslots = 0; const bool pipelineparallel = false; const float lora_multiplier = 1.0f; + const char * devices_override = nullptr; const bool quiet = false; const int debugmode = 0; }; @@ -192,6 +193,7 @@ struct sd_load_model_inputs const char * photomaker_filename = nullptr; const int img_hard_limit = 0; const int img_soft_limit = 0; + const char * devices_override = nullptr; const bool quiet = false; const int debugmode = 0; }; @@ -241,6 +243,7 @@ struct whisper_load_model_inputs const int clblast_info = 0; const int kcpp_main_gpu = 0; const char * vulkan_info = nullptr; + const char * devices_override = nullptr; const bool quiet = false; const int debugmode = 0; }; @@ -269,6 +272,7 @@ struct tts_load_model_inputs const int gpulayers = 0; const bool flash_attention = false; const int ttsmaxlen = 4096; + const char * devices_override = nullptr; const bool quiet = false; const int debugmode = 0; }; @@ -299,6 +303,7 @@ struct embeddings_load_model_inputs const bool flash_attention = false; const bool use_mmap = false; const int embeddingsmaxctx = 0; + const char * devices_override = nullptr; const bool quiet = false; const int debugmode = 0; }; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 5960db87b..81f22e08f 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ 
-643,6 +643,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll draft_model_params.use_mlock = base_model_params.use_mlock; draft_model_params.use_direct_io = base_model_params.use_direct_io; draft_model_params.n_gpu_layers = draft_gpulayers; //layers offload the speculative model. + draft_model_params.devices = base_model_params.devices; draft_ctx_params.n_ctx = base_ctx_params.n_ctx; draft_ctx_params.offload_kqv = base_ctx_params.offload_kqv; draft_model_params.main_gpu = base_model_params.main_gpu; @@ -2375,6 +2376,19 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in model_params.use_direct_io = false; //no direct io for now until stable model_params.n_gpu_layers = inputs.gpulayers; + //set device overrides if needed + std::vector<ggml_backend_dev_t> devices_override; + std::string dev_override_str = inputs.devices_override ? inputs.devices_override : ""; + if(dev_override_str!="") + { + devices_override = kcpp_parse_device_list(dev_override_str); + if(devices_override.size()>0) + { + printf("\nOverriding with %d devices...\n",(int)(devices_override.size()-1)); + model_params.devices = devices_override.data(); + } + } + #if defined(GGML_USE_CLBLAST) if(file_format==FileFormat::GGUF_GENERIC && model_params.n_gpu_layers>0) { @@ -2483,11 +2497,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } //handle override tensor std::string tensoroverrides = inputs.override_tensors; - // if(file_format_meta.model_architecture==GGUFArch::ARCH_GEMMA3N) - // { - // std::string forced = "per_layer_token_embd.weight=CPU"; //this tensor on gpu is problematic on unsloth q4_0 - // tensoroverrides = (tensoroverrides=="" ?
forced: (forced+","+tensoroverrides)); - // } + if(ggml_backend_dev_count()>1 && inputs.moecpu>0) { std::string toadd = ""; diff --git a/koboldcpp.py b/koboldcpp.py index 221d50187..5bee66c08 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -226,6 +226,7 @@ class load_model_inputs(ctypes.Structure): ("smartcacheslots", ctypes.c_int), ("pipelineparallel", ctypes.c_bool), ("lora_multiplier", ctypes.c_float), + ("devices_override", ctypes.c_char_p), ("quiet", ctypes.c_bool), ("debugmode", ctypes.c_int)] @@ -317,6 +318,7 @@ class sd_load_model_inputs(ctypes.Structure): ("photomaker_filename", ctypes.c_char_p), ("img_hard_limit", ctypes.c_int), ("img_soft_limit", ctypes.c_int), + ("devices_override", ctypes.c_char_p), ("quiet", ctypes.c_bool), ("debugmode", ctypes.c_int)] @@ -361,6 +363,7 @@ class whisper_load_model_inputs(ctypes.Structure): ("clblast_info", ctypes.c_int), ("kcpp_main_gpu", ctypes.c_int), ("vulkan_info", ctypes.c_char_p), + ("devices_override", ctypes.c_char_p), ("quiet", ctypes.c_bool), ("debugmode", ctypes.c_int)] @@ -385,6 +388,7 @@ class tts_load_model_inputs(ctypes.Structure): ("gpulayers", ctypes.c_int), ("flash_attention", ctypes.c_bool), ("ttsmaxlen", ctypes.c_int), + ("devices_override", ctypes.c_char_p), ("quiet", ctypes.c_bool), ("debugmode", ctypes.c_int)] @@ -411,6 +415,7 @@ class embeddings_load_model_inputs(ctypes.Structure): ("flash_attention", ctypes.c_bool), ("use_mmap", ctypes.c_bool), ("embeddingsmaxctx", ctypes.c_int), + ("devices_override", ctypes.c_char_p), ("quiet", ctypes.c_bool), ("debugmode", ctypes.c_int)] @@ -833,6 +838,7 @@ def set_backend_props(inputs): inputs.vulkan_info = "".encode("UTF-8") # set universal flags + inputs.devices_override = (args.device if args.device else "").encode("UTF-8") inputs.quiet = args.quiet inputs.debugmode = args.debugmode inputs.executable_path = (getdirpath()+"/").encode("UTF-8") @@ -8832,6 +8838,7 @@ if __name__ == '__main__': advparser.add_argument("--gendefaults", 
metavar=('{"parameter":"value",...}'), help="Sets extra default parameters for some fields in API requests, as a JSON string.", default="") advparser.add_argument("--gendefaultsoverwrite", help="Allow the gendefaults parameters to overwrite the original value in API payloads.", action='store_true') advparser.add_argument("--mcpfile", metavar=('[mcp json file]'), help="Specify path to mcp.json which contains the Cladue Desktop compatible MCP server config.", default="") + advparser.add_argument("--device", "-dev", metavar=(''), help="Set llama.cpp compatible device selection override. Comma separated. Overrides normal device choices.", default="") hordeparsergroup = parser.add_argument_group('Horde Worker Commands') hordeparsergroup.add_argument("--hordemodelname", metavar=('[name]'), help="Sets your AI Horde display model name.", default="") diff --git a/otherarch/embeddings_adapter.cpp b/otherarch/embeddings_adapter.cpp index 71b1d05e3..218724e1a 100644 --- a/otherarch/embeddings_adapter.cpp +++ b/otherarch/embeddings_adapter.cpp @@ -102,6 +102,12 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs) } } const char* existingenv = getenv("GGML_VK_VISIBLE_DEVICES"); + std::vector<ggml_backend_dev_t> devices_override; + std::string dev_override_str = inputs.devices_override ? inputs.devices_override : ""; + if(dev_override_str!="") + { + devices_override = kcpp_parse_device_list(dev_override_str); + } if(!existingenv && vulkan_info_str!="") { ttsvulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str; @@ -125,6 +131,12 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs) model_params.main_gpu = kcpp_parseinfo_maindevice; model_params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; + if(devices_override.size()>0) + { + printf("\nOverriding with %d devices...\n",(int)(devices_override.size()-1)); + model_params.devices = devices_override.data(); + } + llama_model * embeddingsmodel = llama_model_load_from_file(modelfile.c_str(), model_params); const int n_ctx_train =
llama_model_n_ctx_train(embeddingsmodel); diff --git a/otherarch/tts_adapter.cpp b/otherarch/tts_adapter.cpp index 42051b216..029b026b8 100644 --- a/otherarch/tts_adapter.cpp +++ b/otherarch/tts_adapter.cpp @@ -637,6 +637,12 @@ bool ttstype_load_model(const tts_load_model_inputs inputs) } } const char* existingenv = getenv("GGML_VK_VISIBLE_DEVICES"); + std::vector<ggml_backend_dev_t> devices_override; + std::string dev_override_str = inputs.devices_override ? inputs.devices_override : ""; + if(dev_override_str!="") + { + devices_override = kcpp_parse_device_list(dev_override_str); + } if(!existingenv && vulkan_info_str!="") { ttsvulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str; @@ -699,6 +705,12 @@ bool ttstype_load_model(const tts_load_model_inputs inputs) tts_ctx_params.flash_attn_type = (inputs.flash_attention?LLAMA_FLASH_ATTN_TYPE_ENABLED:LLAMA_FLASH_ATTN_TYPE_DISABLED); tts_ctx_params.kv_unified = true; + if(devices_override.size()>0) + { + printf("\nOverriding with %d devices...\n",(int)(devices_override.size()-1)); + tts_model_params.devices = devices_override.data(); + } + llama_model * ttcmodel = llama_model_load_from_file(modelfile_ttc.c_str(), tts_model_params); ttc_ctx = llama_init_from_model(ttcmodel, tts_ctx_params); diff --git a/otherarch/utils.cpp b/otherarch/utils.cpp index 60df611f0..df73f8e83 100644 --- a/otherarch/utils.cpp +++ b/otherarch/utils.cpp @@ -709,4 +709,43 @@ bool kcpp_decode_audio_from_buf(const unsigned char * buf_in, size_t len, int ta ma_decoder_uninit(&decoder); return true; +} + +static std::vector<std::string> kcpp_string_split(const std::string & input, char separator) +{ + std::vector<std::string> parts; + size_t begin_pos = 0; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(begin_pos, separator_pos - begin_pos); + parts.emplace_back(part); + begin_pos = separator_pos + 1; + separator_pos = input.find(separator, begin_pos); + } + parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + return parts;
+} + +//for llama.cpp style device overrides e.g. --device Vulkan0,Vulkan1 +std::vector<ggml_backend_dev_t> kcpp_parse_device_list(const std::string & value) { + std::vector<ggml_backend_dev_t> devices; + auto dev_names = kcpp_string_split(value, ','); + if (dev_names.empty()) { + printf("\nkcpp_parse_device_list error: no devices specified\n"); + return std::vector<ggml_backend_dev_t>(); + } + if (dev_names.size() == 1 && dev_names[0] == "none") { + return std::vector<ggml_backend_dev_t>(); + } else { + for (const auto & device : dev_names) { + auto * dev = ggml_backend_dev_by_name(device.c_str()); + if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + printf("\nkcpp_parse_device_list error: invalid device: %s\n",device.c_str()); + return std::vector<ggml_backend_dev_t>(); + } + devices.push_back(dev); + } + devices.push_back(nullptr); + } + return devices; } \ No newline at end of file diff --git a/otherarch/utils.h b/otherarch/utils.h index bf3b0ab83..0309c9840 100644 --- a/otherarch/utils.h +++ b/otherarch/utils.h @@ -69,6 +69,8 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector< std::vector<std::string> split_string(const std::string& input, const std::string& separator); bool kcpp_decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono); +std::vector<ggml_backend_dev_t> kcpp_parse_device_list(const std::string & value); + //duplcated and modified from llava_embd_batch struct kcpp_embd_batch { std::vector<llama_pos> pos;