diff --git a/expose.h b/expose.h index c27ecebbd..d34f06bfd 100644 --- a/expose.h +++ b/expose.h @@ -166,6 +166,8 @@ struct sd_load_model_inputs const int threads = 0; const int quant = 0; const bool flash_attention = false; + const bool diffusion_conv_direct = false; + const bool vae_conv_direct = false; const bool taesd = false; const int tiled_vae_threshold = 0; const char * t5xxl_filename = nullptr; diff --git a/koboldcpp.py b/koboldcpp.py index c2e32faf4..69ace6768 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure): ("threads", ctypes.c_int), ("quant", ctypes.c_int), ("flash_attention", ctypes.c_bool), + ("diffusion_conv_direct", ctypes.c_bool), + ("vae_conv_direct", ctypes.c_bool), ("taesd", ctypes.c_bool), ("tiled_vae_threshold", ctypes.c_int), ("t5xxl_filename", ctypes.c_char_p), @@ -1637,6 +1639,19 @@ def generate(genparams, stream_flag=False): outstr = outstr[:sindex] return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens} +sd_convdirect_choices = ['off', 'vaeonly', 'full'] + +def sd_convdirect_option(value): + if not value: + value = '' + value = value.lower() + if value in ['disabled', 'disable', 'none', 'off', '0', '']: + return 'off' + elif value in ['vae', 'vaeonly']: + return 'vaeonly' + elif value in ['enabled', 'enable', 'on', 'full']: + return 'full' + raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.") def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename): global args @@ -1654,7 +1669,10 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl inputs.threads = thds inputs.quant = quant - inputs.flash_attention = args.flashattention + inputs.flash_attention = args.sdflashattention + sdconvdirect = sd_convdirect_option(args.sdconvdirect) + inputs.diffusion_conv_direct = sdconvdirect == 'full' + inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'full'] inputs.taesd = True if args.sdvaeauto else False inputs.tiled_vae_threshold = args.sdtiledvae inputs.vae_filename = vae_filename.encode("UTF-8") @@ -4568,8 +4586,10 @@ def show_gui(): sd_clipl_var = ctk.StringVar() sd_clipg_var = ctk.StringVar() sd_photomaker_var = ctk.StringVar() + sd_flash_attention_var = ctk.IntVar(value=0) sd_vaeauto_var = ctk.IntVar(value=0) sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold)) + sd_convdirect_var = ctk.StringVar(value='disabled') sd_clamped_var = ctk.StringVar(value="0") sd_clamped_soft_var = ctk.StringVar(value="0") sd_threads_var = ctk.StringVar(value=str(default_threads)) @@ -4634,6 +4654,18 @@ def show_gui(): temp.bind("", hide_tooltip) return temp + def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8): + label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx) + label=None + combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly") + if command is not None and variable is not None: + variable.trace_add("write", command) + combo.grid(row=row,column=0, padx=padx, sticky="nw") + if tooltiptxt!="": + combo.bind("", lambda event: show_tooltip(event, tooltiptxt)) + combo.bind("", hide_tooltip) + return combo, label + def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8): temp = ctk.CTkLabel(parent, text=text) temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan) @@ -5328,8 +5360,10 @@ def show_gui(): sdvaeitem1.grid() sdvaeitem2.grid() sdvaeitem3.grid() - makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.") + makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.") + makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices) makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.") + makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.") # audio tab audio_tab = tabcontent["Audio"] @@ -5566,6 +5600,8 @@ def show_gui(): if sd_model_var.get() != "": args.sdmodel = sd_model_var.get() + if sd_flash_attention_var.get()==1: + args.sdflashattention = True args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get())) args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get())) args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get())) @@ -5578,6 +5614,7 @@ def show_gui(): args.sdvae = "" if sd_vae_var.get() != "": args.sdvae = sd_vae_var.get() + args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get()) if sd_t5xxl_var.get() != "": args.sdt5xxl = sd_t5xxl_var.get() if sd_clipl_var.get() != "": @@ -5798,6 +5835,8 @@ def show_gui(): sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0) sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads)) sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0) + sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0) + sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect"))) sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "") sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "") sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "") @@ -7600,6 +7639,8 @@ if __name__ == '__main__': sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="") sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="") sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="") + sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true') + sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'off' (default) to disable, 'full' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0]) sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group() sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="") sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true') diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp index d967bbf0e..ba10e881a 100644 --- a/otherarch/sdcpp/sdtype_adapter.cpp +++ b/otherarch/sdcpp/sdtype_adapter.cpp @@ -99,6 +99,8 @@ struct SDParams { bool clip_on_cpu = false; bool vae_on_cpu = false; bool diffusion_flash_attn = false; + bool diffusion_conv_direct = false; + bool vae_conv_direct = false; bool canny_preprocess = false; bool color = false; int upscale_repeats = 1; @@ -211,6 +213,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { { printf("Flash Attention is enabled\n"); } + if(inputs.diffusion_conv_direct) + { + printf("Conv2D Direct for diffusion model is enabled\n"); + } + if(inputs.vae_conv_direct) + { + printf("Conv2D Direct for VAE model is enabled\n"); + } if(inputs.quant) { printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n"); @@ -246,6 +256,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0); sd_params->n_threads = inputs.threads; //if -1 use physical cores sd_params->diffusion_flash_attn = inputs.flash_attention; + sd_params->diffusion_conv_direct = inputs.diffusion_conv_direct; + sd_params->vae_conv_direct = inputs.vae_conv_direct; sd_params->input_path = ""; //unused sd_params->batch_count = 1; sd_params->vae_path = vaefilename; @@ -316,6 +328,8 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { params.keep_control_net_on_cpu = sd_params->control_net_cpu; params.keep_vae_on_cpu = sd_params->vae_on_cpu; params.diffusion_flash_attn = sd_params->diffusion_flash_attn; + params.diffusion_conv_direct = sd_params->diffusion_conv_direct; + params.vae_conv_direct = sd_params->vae_conv_direct; params.chroma_use_dit_mask = sd_params->chroma_use_dit_mask; params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask; params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;