mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
Add flash attention and conv2d direct controls for image generation (#1678)
* Add separate flash attention config for image generation * Add config option for Conv2D Direct
This commit is contained in:
parent
35707f4e97
commit
6003e90e50
3 changed files with 59 additions and 2 deletions
45
koboldcpp.py
45
koboldcpp.py
|
@ -280,6 +280,8 @@ class sd_load_model_inputs(ctypes.Structure):
|
|||
("threads", ctypes.c_int),
|
||||
("quant", ctypes.c_int),
|
||||
("flash_attention", ctypes.c_bool),
|
||||
("diffusion_conv_direct", ctypes.c_bool),
|
||||
("vae_conv_direct", ctypes.c_bool),
|
||||
("taesd", ctypes.c_bool),
|
||||
("tiled_vae_threshold", ctypes.c_int),
|
||||
("t5xxl_filename", ctypes.c_char_p),
|
||||
|
@ -1637,6 +1639,19 @@ def generate(genparams, stream_flag=False):
|
|||
outstr = outstr[:sindex]
|
||||
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
|
||||
|
||||
sd_convdirect_choices = ['off', 'vaeonly', 'full']
|
||||
|
||||
def sd_convdirect_option(value):
|
||||
if not value:
|
||||
value = ''
|
||||
value = value.lower()
|
||||
if value in ['disabled', 'disable', 'none', 'off', '0', '']:
|
||||
return 'off'
|
||||
elif value in ['vae', 'vaeonly']:
|
||||
return 'vaeonly'
|
||||
elif value in ['enabled', 'enable', 'on', 'full']:
|
||||
return 'full'
|
||||
raise argparse.ArgumentTypeError(f"Invalid sdconvdirect option \"{value}\". Must be one of {sd_convdirect_choices}.")
|
||||
|
||||
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename,photomaker_filename):
|
||||
global args
|
||||
|
@ -1654,7 +1669,10 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl
|
|||
|
||||
inputs.threads = thds
|
||||
inputs.quant = quant
|
||||
inputs.flash_attention = args.flashattention
|
||||
inputs.flash_attention = args.sdflashattention
|
||||
sdconvdirect = sd_convdirect_option(args.sdconvdirect)
|
||||
inputs.diffusion_conv_direct = sdconvdirect == 'full'
|
||||
inputs.vae_conv_direct = sdconvdirect in ['vaeonly', 'full']
|
||||
inputs.taesd = True if args.sdvaeauto else False
|
||||
inputs.tiled_vae_threshold = args.sdtiledvae
|
||||
inputs.vae_filename = vae_filename.encode("UTF-8")
|
||||
|
@ -4568,8 +4586,10 @@ def show_gui():
|
|||
sd_clipl_var = ctk.StringVar()
|
||||
sd_clipg_var = ctk.StringVar()
|
||||
sd_photomaker_var = ctk.StringVar()
|
||||
sd_flash_attention_var = ctk.IntVar(value=0)
|
||||
sd_vaeauto_var = ctk.IntVar(value=0)
|
||||
sd_tiled_vae_var = ctk.StringVar(value=str(default_vae_tile_threshold))
|
||||
sd_convdirect_var = ctk.StringVar(value='disabled')
|
||||
sd_clamped_var = ctk.StringVar(value="0")
|
||||
sd_clamped_soft_var = ctk.StringVar(value="0")
|
||||
sd_threads_var = ctk.StringVar(value=str(default_threads))
|
||||
|
@ -4634,6 +4654,18 @@ def show_gui():
|
|||
temp.bind("<Leave>", hide_tooltip)
|
||||
return temp
|
||||
|
||||
def makelabelcombobox(parent, text, variable=None, row=0, width=50, command=None, padx=8,tooltiptxt="", values=[], labelpadx=8):
|
||||
label = makelabel(parent, text, row, 0, tooltiptxt, padx=labelpadx)
|
||||
label=None
|
||||
combo = ctk.CTkComboBox(parent, variable=variable, width=width, values=values, state="readonly")
|
||||
if command is not None and variable is not None:
|
||||
variable.trace_add("write", command)
|
||||
combo.grid(row=row,column=0, padx=padx, sticky="nw")
|
||||
if tooltiptxt!="":
|
||||
combo.bind("<Enter>", lambda event: show_tooltip(event, tooltiptxt))
|
||||
combo.bind("<Leave>", hide_tooltip)
|
||||
return combo, label
|
||||
|
||||
def makelabel(parent, text, row, column=0, tooltiptxt="", columnspan=1, padx=8):
|
||||
temp = ctk.CTkLabel(parent, text=text)
|
||||
temp.grid(row=row, column=column, padx=padx, pady=1, stick="nw", columnspan=columnspan)
|
||||
|
@ -5328,8 +5360,10 @@ def show_gui():
|
|||
sdvaeitem1.grid()
|
||||
sdvaeitem2.grid()
|
||||
sdvaeitem3.grid()
|
||||
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
|
||||
makecheckbox(images_tab, "TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 42,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
|
||||
makelabelcombobox(images_tab, "Conv2D Direct:", sd_convdirect_var, row=42, labelpadx=220, padx=310, width=90, tooltiptxt="Use Conv2D Direct operation. May save memory or improve performance.\nMight crash if not supported by the backend.\n", values=sd_convdirect_choices)
|
||||
makelabelentry(images_tab, "VAE Tiling Threshold:", sd_tiled_vae_var, 44, 50, padx=144,singleline=True,tooltip="Enable VAE Tiling for images above this size, to save memory.\nSet to 0 to disable VAE tiling.")
|
||||
makecheckbox(images_tab, "Flash Attention", sd_flash_attention_var, 46, tooltiptxt="Enable Flash Attention for diffusion. May save memory or improve performance.")
|
||||
|
||||
# audio tab
|
||||
audio_tab = tabcontent["Audio"]
|
||||
|
@ -5566,6 +5600,8 @@ def show_gui():
|
|||
if sd_model_var.get() != "":
|
||||
args.sdmodel = sd_model_var.get()
|
||||
|
||||
if sd_flash_attention_var.get()==1:
|
||||
args.sdflashattention = True
|
||||
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
|
||||
args.sdclamped = (0 if int(sd_clamped_var.get())<=0 else int(sd_clamped_var.get()))
|
||||
args.sdclampedsoft = (0 if int(sd_clamped_soft_var.get())<=0 else int(sd_clamped_soft_var.get()))
|
||||
|
@ -5578,6 +5614,7 @@ def show_gui():
|
|||
args.sdvae = ""
|
||||
if sd_vae_var.get() != "":
|
||||
args.sdvae = sd_vae_var.get()
|
||||
args.sdconvdirect = sd_convdirect_option(sd_convdirect_var.get())
|
||||
if sd_t5xxl_var.get() != "":
|
||||
args.sdt5xxl = sd_t5xxl_var.get()
|
||||
if sd_clipl_var.get() != "":
|
||||
|
@ -5798,6 +5835,8 @@ def show_gui():
|
|||
sd_clamped_soft_var.set(int(dict["sdclampedsoft"]) if ("sdclampedsoft" in dict and dict["sdclampedsoft"]) else 0)
|
||||
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
|
||||
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
|
||||
sd_flash_attention_var.set(1 if ("sdflashattention" in dict and dict["sdflashattention"]) else 0)
|
||||
sd_convdirect_var.set(sd_convdirect_option(dict.get("sdconvdirect")))
|
||||
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
|
||||
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
|
||||
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
|
||||
|
@ -7600,6 +7639,8 @@ if __name__ == '__main__':
|
|||
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
|
||||
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
|
||||
sdparsergroup.add_argument("--sdphotomaker", metavar=('[filename]'), help="PhotoMaker is a model that allows face cloning. Specify a PhotoMaker safetensors model which will be applied replacing img2img. SDXL models only. Leave blank if unused.", default="")
|
||||
sdparsergroup.add_argument("--sdflashattention", help="Enables Flash Attention for image generation.", action='store_true')
|
||||
sdparsergroup.add_argument("--sdconvdirect", help="Enables Conv2D Direct. May improve performance or reduce memory usage. Might crash if not supported by the backend. Can be 'off' (default) to disable, 'full' to turn it on for all operations, or 'vaeonly' to enable only for the VAE.", type=sd_convdirect_option, choices=sd_convdirect_choices, default=sd_convdirect_choices[0])
|
||||
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
|
||||
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify an image generation safetensors VAE which replaces the one in the model.", default="")
|
||||
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue