Add flash attention and conv2d direct controls for image generation (#1678)

* Add separate flash attention config for image generation

* Add config option for Conv2D Direct
This commit is contained in:
Wagner Bruna 2025-08-20 01:17:57 -03:00 committed by GitHub
parent 35707f4e97
commit 6003e90e50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 59 additions and 2 deletions

View file

@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure):
("threads", ctypes.c_int),
("quant", ctypes.c_int),
("flash_attention", ctypes.c_bool),
("diffusion_conv_direct", ctypes.c_bool),
("vae_conv_direct", ctypes.c_bool),
("taesd", ctypes.c_bool),
("tiled_vae_threshold", ctypes.c_int),
("t5xxl_filename", ctypes.c_char_p),
@ -1637,6 +1639,19 @@ def generate(genparams, stream_flag=False):
outstr = outstr[:sindex]
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
sd_convdirect_choices = ['off', 'vaeonly', 'full']
def sd_convdirect_option(value):
if not value:
value = ''
value = value.lower()
if value in ['disabled', 'disable', 'none', 'off', '0', '']:
return 'off'
elif value in ['vae', 'vaeonly']:
return 'vaeonly'
elif value in ['enabled', 'enable', 'on', 'full']:
return 'full'
raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
global args
@ -1654,7 +1669,10 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
inputs.threads = thds
inputs.quant = quant
inputs.flash_attention = args.flashattention
inputs.flash_attention = args.sdflashattention
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
inputs.diffusion_conv_direct = sdconvdirect == 'full'
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'full']
inputs.taesd = True if args.sdvaeauto else False
inputs.tiled_vae_threshold = args.sdtiledvae
inputs.vae_filename = vae_filename.encode("UTF-8")
@ -4568,8 +4586,10 @@ def show_gui():
sd_clipl_var = ctk.StringVar()
sd_clipg_var = ctk.StringVar()
sd_photomaker_var = ctk.StringVar()
sd_flash_attention_var = ctk.IntVar(value=0)
sd_vaeauto_var = ctk.IntVar(value=0)
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
sd_convdirect_var = ctk.StringVar(value='disabled')
sd_clamped_var = ctk.StringVar(value="0")
sd_clamped_soft_var = ctk.StringVar(value="0")
sd_threads_var = ctk.StringVar(value=str(default_threads))
@ -4634,6 +4654,18 @@ def show_gui():
temp.bind("<Leave>", hide_tooltip)
return temp
def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
label=None
combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
if command is not None and variable is not None:
variable.trace_add("write", command)
combo.grid(row=row,column=0, padx=padx, sticky="nw")
if tooltiptxt!="":
combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
combo.bind("<Leave>", hide_tooltip)
return combo, label
def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
temp = ctk.CTkLabel(parent, text=text)
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
@ -5328,8 +5360,10 @@ def show_gui():
sdvaeitem1.grid()
sdvaeitem2.grid()
sdvaeitem3.grid()
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
# audio tab
audio_tab = tabcontent["Audio"]
@ -5566,6 +5600,8 @@ def show_gui():
if sd_model_var.get() != "":
args.sdmodel = sd_model_var.get()
if sd_flash_attention_var.get()==1:
args.sdflashattention = True
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
@ -5578,6 +5614,7 @@ def show_gui():
args.sdvae = ""
if sd_vae_var.get() != "":
args.sdvae = sd_vae_var.get()
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
if sd_t5xxl_var.get() != "":
args.sdt5xxl = sd_t5xxl_var.get()
if sd_clipl_var.get() != "":
@ -5798,6 +5835,8 @@ def show_gui():
sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0)
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
@ -7600,6 +7639,8 @@ if __name__ == '__main__':
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'off' (default) to disable, 'full' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')