This commit is contained in:
Concedo 2024-05-14 19:17:01 +08:00
parent 4807b66907
commit 5d15f8f76a
13 changed files with 70 additions and 13 deletions

View file

@ -102,6 +102,7 @@ class generation_outputs(ctypes.Structure):
class sd_load_model_inputs(ctypes.Structure):
_fields_ = [("model_filename", ctypes.c_char_p),
("executable_path", ctypes.c_char_p),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
("vulkan_info", ctypes.c_char_p),
@ -521,6 +522,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename):
global args
inputs = sd_load_model_inputs()
inputs.debugmode = args.debugmode
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
inputs.model_filename = model_filename.encode("UTF-8")
thds = args.threads
quant = 0
@ -1714,6 +1716,10 @@ def show_new_gui():
password_var = ctk.StringVar()
sd_model_var = ctk.StringVar()
sd_lora_var = ctk.StringVar()
sd_loramult_var = ctk.StringVar(value="1.0")
sd_vae_var = ctk.StringVar()
sd_vaeauto_var = ctk.IntVar(value=0)
sd_clamped_var = ctk.IntVar(value=0)
sd_threads_var = ctk.StringVar(value=str(default_threads))
sd_quant_var = ctk.IntVar(value=0)
@ -1782,7 +1788,7 @@ def show_new_gui():
def makefileentry(parent, text, searchtext, var, row=0, width=200, filetypes=[], onchoosefile=None, singlerow=False, tooltiptxt=""):
makelabel(parent, text, row,0,tooltiptxt)
label = makelabel(parent, text, row,0,tooltiptxt)
def getfilename(var, text):
initialDir = os.path.dirname(var.get())
initialDir = initialDir if os.path.isdir(initialDir) else None
@ -1799,7 +1805,7 @@ def show_new_gui():
else:
entry.grid(row=row+1, column=0, padx=8, stick="nw")
button.grid(row=row+1, column=1, stick="nw")
return
return label, entry, button
# decided to follow yellowrose's and kalomaze's suggestions, this function will automatically try to determine GPU identifiers
# run in new thread so it doesnt block. does not return anything, instead overwrites specific values and redraws GUI
@ -2200,11 +2206,26 @@ def show_new_gui():
togglehorde(1,1,1)
# Image Gen Tab
images_tab = tabcontent["Image Gen"]
makefileentry(images_tab, "Stable Diffusion Model (safetensors/gguf):", "Select Stable Diffusion Model File", sd_model_var, 1, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")], tooltiptxt="Select a .safetensors or .gguf Stable Diffusion model file on disk to be loaded.")
makecheckbox(images_tab, "Clamped Mode (Limit Resolution)", sd_clamped_var, 4,tooltiptxt="Limit generation steps and resolution settings for shared use.")
makelabelentry(images_tab, "Image threads:" , sd_threads_var, 6, 50,"How many threads to use during image generation.\nIf left blank, uses same value as threads.")
makelabelentry(images_tab, "Image Threads:" , sd_threads_var, 6, 50,"How many threads to use during image generation.\nIf left blank, uses same value as threads.")
makecheckbox(images_tab, "Compress Weights (Saves Memory)", sd_quant_var, 8,tooltiptxt="Quantizes the SD model weights to save memory. May degrade quality.")
makefileentry(images_tab, "Image LoRA:", "Select SD lora file",sd_lora_var, 10 ,filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded.")
makelabelentry(images_tab, "Image LoRA Multiplier:" , sd_loramult_var, 12, 50,"What mutiplier value to apply the SD LoRA with.")
sdvaeitem1,sdvaeitem2,sdvaeitem3 = makefileentry(images_tab, "Image VAE:", "Select SD VAE file",sd_vae_var, 14, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD VAE file to be loaded.")
def toggletaesd(a,b,c):
if sd_vaeauto_var.get()==1:
sdvaeitem1.grid_forget()
sdvaeitem2.grid_forget()
sdvaeitem3.grid_forget()
else:
sdvaeitem1.grid(row=14,column=0,padx=8,stick="nw")
sdvaeitem2.grid(row=15,column=0,padx=8,stick="nw")
sdvaeitem3.grid(row=15,column=1,stick="nw")
makecheckbox(images_tab, "Use TAE SD VAE", sd_vaeauto_var, 16,command=toggletaesd,tooltiptxt="Replace VAE with TAESD.")
# launch
def guilaunch():
@ -2308,6 +2329,19 @@ def show_new_gui():
args.sdthreads = (0 if sd_threads_var.get()=="" else int(sd_threads_var.get()))
if sd_quant_var.get()==1:
args.sdquant = True
if sd_vaeauto_var.get()==1:
args.sdvaeauto = True
args.sdvae = ""
else:
args.sdvaeauto = False
args.sdvae = ""
if sd_vae_var.get() != "":
args.sdvae = sd_vae_var.get()
if sd_lora_var.get() != "":
args.sdlora = sd_lora_var.get()
args.sdloramult = float(sd_loramult_var.get())
else:
args.sdlora = ""
def import_vars(dict):
dict = convert_outdated_args(dict)
@ -2442,6 +2476,10 @@ def show_new_gui():
sd_clamped_var.set(1 if ("sdclamped" in dict and dict["sdclamped"]) else 0)
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
def save_config():
file_type = [("KoboldCpp Settings", "*.kcpps")]