sd: support for CLIP and VAE on different devices (#2184)

* sd: generalize internal interfaces to place generation on CPU

* sd: backend support for multi-device selection

* sd: frontend support for multi-device selection

* add deprecated flags to avoid breaking old cli args

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
Wagner Bruna 2026-05-19 10:51:23 -03:00 committed by GitHub
parent 7232096c11
commit 592d12d0a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 99 additions and 41 deletions

View file

@ -183,14 +183,14 @@ struct sd_load_model_inputs
{
const char * model_filename = nullptr;
const char * executable_path = nullptr;
const int kcpp_main_gpu = -1;
const int kcpp_main_device = -1;
const int threads = 0;
const int quant = 0;
const bool flash_attention = false;
const bool offload_cpu = false;
const bool use_mmap = false;
const bool vae_cpu = false;
const bool clip_cpu = false;
const int kcpp_vae_device = -1;
const int kcpp_clip_device = -1;
const bool diffusion_conv_direct = false;
const bool vae_conv_direct = false;
const bool taesd = false;

View file

@ -57,6 +57,8 @@ net_save_slots = 12
savestate_limit_default = 5
savestate_limit = 0 #savestate slots start at 0, only set when load model
default_vae_tile_threshold = 768
default_sdvaedevice = 'main'
default_sdclipdevice = 'CPU'
default_native_ctx = 16384
default_genlen = 1024
overridekv_max = 16
@ -363,14 +365,14 @@ class generation_outputs(ctypes.Structure):
class sd_load_model_inputs(ctypes.Structure):
_fields_ = [("model_filename", ctypes.c_char_p),
("executable_path", ctypes.c_char_p),
("kcpp_main_gpu", ctypes.c_int),
("kcpp_main_device", ctypes.c_int),
("threads", ctypes.c_int),
("quant", ctypes.c_int),
("flash_attention", ctypes.c_bool),
("offload_cpu", ctypes.c_bool),
("use_mmap", ctypes.c_bool),
("vae_cpu", ctypes.c_bool),
("clip_cpu", ctypes.c_bool),
("kcpp_vae_device", ctypes.c_int),
("kcpp_clip_device", ctypes.c_int),
("diffusion_conv_direct", ctypes.c_bool),
("vae_conv_direct", ctypes.c_bool),
("taesd", ctypes.c_bool),
@ -2399,6 +2401,33 @@ def sd_quant_option(value):
except Exception:
return 0
# Device names offered in the image generation "VAE dev" / "CLIP dev" comboboxes
sd_device_choices = ['CPU', 'main', '1', '2', '3', '4']
def sd_get_device_number(name, offset=0):
    """Translate a device name into the integer device id used by the image backend.

    name:   None, or a string such as "CPU", "main", "gpu", "default" or a
            device index ("0", "1", ...). Matching is case-insensitive and
            ignores surrounding whitespace.
    offset: added to a parsed non-negative index (e.g. to convert between a
            shifted display index and the backend index).

    Returns None when name is None, -2 for "cpu", -1 for an empty name or any
    of the main-device aliases, otherwise the parsed index plus offset.
    """
    if name is None: # default handling should be done elsewhere
        return None
    if not name:
        return -1
    name = name.strip().lower()
    aliases = {"cpu": -2, "gpu": -1, "": -1, "main": -1, "default": -1}
    if name in aliases:
        return aliases[name]
    # assumes tryparseint returns the fallback (-1) when name is not an integer
    index = tryparseint(name, -1)
    if index < 0:
        # Negative values are sentinels (-1 main, <= -2 CPU), including the
        # parse-failure fallback. Shifting them by offset would silently turn
        # an unparseable name into a different device.
        return index
    # never let a negative offset push a real index below the "main" sentinel
    return max(index + offset, -1)
def sd_get_device_name(value, offset=0):
    """Translate an integer device id into its display name.

    value:  integer device id; <= -2 means CPU, -1 means the main device,
            anything else is a device index.
    offset: added to a device index before it is rendered.

    Always returns a string: "CPU", "main", or the shifted index as decimal
    text, so every branch yields the same type as the combobox choices.
    """
    if value <= -2:
        return "CPU"
    if value == -1:
        return "main"
    return str(value + offset)
def sd_resolve_device(name, default_=-1, offset=0):
    """Resolve a device setting (name, integer id, or None) to a device number.

    A None setting is replaced by default_. Integer settings are floored at -2
    and converted to text, then everything is handed to sd_get_device_number
    together with offset.
    """
    device = default_ if name is None else name
    if isinstance(device, int):
        # anything below -2 is collapsed to -2 before being rendered as text
        device = str(device if device > -2 else -2)
    return sd_get_device_number(device, offset=offset)
def sd_load_model(model_filename,vae_filename,t5xxl_filename,clip1_filename,clip2_filename,photomaker_filename,upscaler_filename):
global args
inputs = sd_load_model_inputs()
@ -2415,8 +2444,8 @@ def sd_load_model(model_filename,vae_filename,t5xxl_filename,clip1_filename,clip
inputs.flash_attention = args.sdflashattention
inputs.offload_cpu = args.sdoffloadcpu
inputs.use_mmap = args.usemmap
inputs.vae_cpu = args.sdvaecpu
inputs.clip_cpu = False if args.sdclipgpu else True
inputs.kcpp_vae_device = sd_resolve_device(args.sdvaedevice, default_sdvaedevice)
inputs.kcpp_clip_device = sd_resolve_device(args.sdclipdevice, default_sdclipdevice)
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
inputs.diffusion_conv_direct = sdconvdirect == 'full'
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'full']
@ -2444,7 +2473,7 @@ def sd_load_model(model_filename,vae_filename,t5xxl_filename,clip1_filename,clip
inputs.img_hard_limit = args.sdclamped
inputs.img_soft_limit = args.sdclampedsoft
inputs = set_backend_props(inputs)
inputs.kcpp_main_gpu = args.sdmaingpu
inputs.kcpp_main_device = sd_resolve_device(args.sdmaingpu, 'main')
ret = handle.sd_load_model(inputs)
return ret
@ -7620,8 +7649,8 @@ def show_gui():
sd_upscaler_var = ctk.StringVar()
sd_flash_attention_var = ctk.IntVar(value=0)
sd_offload_cpu_var = ctk.IntVar(value=0)
sd_vae_cpu_var = ctk.IntVar(value=0)
sd_clip_gpu_var = ctk.IntVar(value=0)
sd_vae_device_var = ctk.StringVar(value="main")
sd_clip_device_var = ctk.StringVar(value="CPU")
sd_runtime_loras_var = ctk.IntVar(value=0)
sd_vaeauto_var = ctk.IntVar(value=0)
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
@ -8486,9 +8515,9 @@ def show_gui():
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=(220), padx=(310), width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=(144),singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
makecheckbox(images_tab, "SD Flash Attention", sd_flash_attention_var, 44,padx=(230), tooltiptxt="Enable Flash Attention for image diffusion. May save memory or improve performance.")
makecheckbox(images_tab, "Model CPU Offload", sd_offload_cpu_var, 50,padx=8, tooltiptxt="Offload image weights in RAM to save VRAM, swap into VRAM when needed.")
makecheckbox(images_tab, "VAE on CPU", sd_vae_cpu_var, 50,padx=(160), tooltiptxt="Force VAE to CPU only for image generation.")
makecheckbox(images_tab, "CLIP on GPU", sd_clip_gpu_var, 50,padx=(280), tooltiptxt="Put CLIP and T5 to GPU for image generation. Otherwise, CLIP will use CPU.")
makecheckbox(images_tab, "Model Offload", sd_offload_cpu_var, 50,padx=8, tooltiptxt="Offload image weights in RAM to save VRAM, swap into VRAM when needed.")
makelabelcombobox(images_tab, "VAE dev:", sd_vae_device_var, 50,labelpadx=(140),padx=(200), width=70, tooltiptxt="Change VAE device for image generation.", values=sd_device_choices)
makelabelcombobox(images_tab, "CLIP dev:", sd_clip_device_var, 50,labelpadx=(280),padx=340, width=70, tooltiptxt="Change CLIP / T5 / LLM device for image generation.", values=sd_device_choices)
# audio tab
audio_tab = tabcontent["Audio"]
@ -8803,8 +8832,8 @@ def show_gui():
args.sdmodel = sd_model_var.get() if sd_model_var.get() != "" else ""
args.sdflashattention = True if sd_flash_attention_var.get()==1 else False
args.sdoffloadcpu = True if sd_offload_cpu_var.get()==1 else False
args.sdvaecpu = True if sd_vae_cpu_var.get()==1 else False
args.sdclipgpu = True if sd_clip_gpu_var.get()==1 else False
args.sdvaedevice = sd_resolve_device(sd_vae_device_var.get(), -1)
args.sdclipdevice = sd_resolve_device(sd_clip_device_var.get(), -1)
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@ -9092,8 +9121,8 @@ def show_gui():
sd_quant_var.set(sd_quant_choices[(mydict["sdquant"] if ("sdquant" in mydict and mydict["sdquant"]>=0 and mydict["sdquant"]<len(sd_quant_choices)) else 0)])
sd_flash_attention_var.set(1 if ("sdflashattention" in mydict and mydict["sdflashattention"]) else 0)
sd_offload_cpu_var.set(1 if ("sdoffloadcpu" in mydict and mydict["sdoffloadcpu"]) else 0)
sd_vae_cpu_var.set(1 if ("sdvaecpu" in mydict and mydict["sdvaecpu"]) else 0)
sd_clip_gpu_var.set(1 if ("sdclipgpu" in mydict and mydict["sdclipgpu"]) else 0)
sd_vae_device_var.set(sd_get_device_name(sd_resolve_device(mydict.get("sdvaedevice"), default_sdvaedevice), 1))
sd_clip_device_var.set(sd_get_device_name(sd_resolve_device(mydict.get("sdclipdevice"), default_sdclipdevice), 1))
sd_convdirect_var.set(sd_convdirect_option(mydict.get("sdconvdirect")))
sd_vae_var.set(mydict["sdvae"] if ("sdvae" in mydict and mydict["sdvae"]) else "")
sd_t5xxl_var.set(mydict["sdt5xxl"] if ("sdt5xxl" in mydict and mydict["sdt5xxl"]) else "")
@ -9604,6 +9633,10 @@ def convert_invalid_args(args):
dict["sdlora"] = sanitize_lora_list(dict["sdlora"])
if "sdloramult" in dict:
dict["sdloramult"] = sanitize_lora_multipliers(dict["sdloramult"])
if "sdclipgpu" in dict and dict.get("sdclipdevice") is None:
dict["sdclipdevice"] = sd_get_device_number("main" if dict["sdclipgpu"] else "CPU")
if "sdvaecpu" in dict and dict.get("sdvaedevice") is None:
dict["sdvaedevice"] = sd_get_device_number("CPU" if dict["sdvaecpu"] else "main")
return args
def setuptunnel(global_memory, has_sd, has_music):
@ -11612,8 +11645,8 @@ if __name__ == '__main__':
sdparsergroup.add_argument("--sdupscaler", metavar=('[filename]'), help="You can use ESRGAN as an upscaling model to resize images. Leave blank if unused.", default="")
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
sdparsergroup.add_argument("--sdoffloadcpu", help="Offload image weights in RAM to save VRAM, swap into VRAM when needed.", action='store_true')
sdparsergroup.add_argument("--sdvaecpu", help="Force VAE to CPU only for image generation.", action='store_true')
sdparsergroup.add_argument("--sdclipgpu", help="Put CLIP and T5 to GPU for image generation. Otherwise, CLIP will use CPU.", action='store_true')
sdparsergroup.add_argument("--sdvaedevice", help=f"VAE device for image generation. GPU index, -1 or 'main' for the main GPU, or 'CPU' (default: {default_sdvaedevice}).", type=sd_get_device_number, default=None)
sdparsergroup.add_argument("--sdclipdevice", help=f"CLIP / T5 / LLM device for image generation. GPU index, -1 or 'main' for the main GPU, or 'CPU' (default: {default_sdclipdevice}).", type=sd_get_device_number, default=None)
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'off' (default) to disable, 'full' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
@ -11623,7 +11656,7 @@ if __name__ == '__main__':
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify image generation LoRAs safetensors models to be applied. Multiple LoRAs are accepted.", nargs='+')
sdparsergroup.add_argument("--sdloramult", metavar=('[amounts]'), help="Multipliers for the image LoRA model to be applied.", type=float, nargs='+', default=[1.0])
sdparsergroup.add_argument("--sdtiledvae", metavar=('[maxres]'), help="Adjust the automatic VAE tiling trigger for images above this size. 0 disables vae tiling.", type=int, default=default_vae_tile_threshold)
sdparsergroup.add_argument("--sdmaingpu", metavar=('[Device ID]'), help="If specified, Image Generation weights will be placed on the selected GPU index", type=int, default=-1)
sdparsergroup.add_argument("--sdmaingpu", metavar=('[Device ID]'), help="If specified, Image Generation weights will be placed on the selected GPU index. GPU index, -1 or 'main' for the main GPU, or 'CPU' (default: 'main')", type=sd_get_device_number, default=None)
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper .bin model to enable Speech-To-Text transcription.", default="")
@ -11665,6 +11698,8 @@ if __name__ == '__main__':
compatgroup3.add_argument("--nommap","--no-mmap", help=argparse.SUPPRESS, action='store_true')
deprecatedgroup.add_argument("--pipelineparallel", help=argparse.SUPPRESS, action='store_true') #changed to nopipelineparallel
deprecatedgroup.add_argument("--sdnotile", help=argparse.SUPPRESS, action='store_true') # legacy option, see sdtiledvae
deprecatedgroup.add_argument("--sdvaecpu", help=argparse.SUPPRESS, action='store_true') # legacy option, see sdvaedevice
deprecatedgroup.add_argument("--sdclipgpu", help=argparse.SUPPRESS, action='store_true') # legacy option, see sdclipdevice
deprecatedgroup.add_argument("--forceversion", help=argparse.SUPPRESS, action='store_true') #no longer used
deprecatedgroup.add_argument("--sdgendefaults", help=argparse.SUPPRESS, action='store_true') # legacy option, see gendefaults
deprecatedgroup.add_argument("--flashattention","--flash-attn","-fa", help=argparse.SUPPRESS, action='store_true') #flash attention now default on

View file

@ -120,7 +120,6 @@ struct SDParams {
float eta = -1.0f;
float strength = 0.75f;
int64_t seed = 42;
bool clip_on_cpu = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@ -261,19 +260,33 @@ std::string load_umt5_tokenizer_json()
return umt5str;
}
static const char * get_main_gpu_name(int value)
static std::string get_device_override(int value, const char * module = nullptr)
{
if (value < 0)
return "";
size_t gpu_index = static_cast<size_t>(value);
if (gpu_index >= ggml_backend_dev_count()) {
printf("\nWARNING: device %zu doesn't exist, falling back to the default\n", gpu_index);
return "";
std::string device_name;
if (value <= -2) {
device_name = "CPU";
} else if (value >= 0) {
size_t gpu_index = static_cast<size_t>(value);
if (gpu_index >= ggml_backend_dev_count()) {
printf("\nWARNING: device %zu doesn't exist, falling back to default for %s\n",
gpu_index,
module ? module : "the main device");
} else {
auto dev = ggml_backend_dev_get(gpu_index);
device_name = ggml_backend_dev_name(dev);
}
}
auto dev = ggml_backend_dev_get(gpu_index);
auto name = ggml_backend_dev_name(dev);
printf("Setting %s as image generation device\n", name);
return name;
std::string result;
if (device_name == "") {
result = ""; // no override: sdcpp will use the main device
} else if (module) {
printf("Selecting %s as %s image generation device\n", device_name.c_str(), module);
result = std::string{","} + module + "=" + device_name;
} else {
printf("Selecting %s as the main image generation device\n", device_name.c_str());
result = device_name;
}
return result;
}
bool sdtype_load_model(const sd_load_model_inputs inputs) {
@ -299,8 +312,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
cfg_square_limit = inputs.img_soft_limit;
printf("\nImageGen Init - Load Model: %s\n",inputs.model_filename);
//kcpp allow gpu id override
std::string main_gpu_name = get_main_gpu_name(inputs.kcpp_main_gpu);
std::string backends = get_device_override(inputs.kcpp_main_device);
int lora_apply_mode = LORA_APPLY_AT_RUNTIME;
bool lora_dynamic = false;
@ -448,11 +460,22 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
params.diffusion_conv_direct = sd_params->diffusion_conv_direct;
params.vae_conv_direct = sd_params->vae_conv_direct;
params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask;
params.offload_params_to_cpu = inputs.offload_cpu;
params.enable_mmap = inputs.use_mmap;
params.backend = main_gpu_name.c_str();
params.keep_vae_on_cpu = inputs.vae_cpu;
params.keep_clip_on_cpu = inputs.clip_cpu;
// the _cpu flags are only used if the backend string is empty, but
// we always set both for consistency
params.offload_params_to_cpu = inputs.offload_cpu;
params.params_backend = inputs.offload_cpu ? "CPU" : "";
params.keep_vae_on_cpu = (inputs.kcpp_vae_device <= -2);
backends += get_device_override(inputs.kcpp_vae_device, "VAE");
params.keep_clip_on_cpu = (inputs.kcpp_clip_device <= -2);
backends += get_device_override(inputs.kcpp_clip_device, "CLIP");
if (backends.rfind(",", 0) == 0) {
backends = "auto" + backends;
}
params.backend = backends.c_str();
if (inputs.debugmode==1) {
printf("\nSetting sd backend list to \"%s\", params backend list to \"%s\"", params.backend, params.params_backend);
}
params.lora_apply_mode = (lora_apply_mode_t)lora_apply_mode;
// also switches flash attn for the vae and conditioner
@ -516,8 +539,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
params.diffusion_conv_direct,
params.n_threads,
upscale_tile_size,
main_gpu_name.c_str(),
nullptr);
params.backend,
params.params_backend);
if (upscaler_ctx == nullptr) {
printf("\nError: KCPP failed to load upscaler!\n");