KCPP SD: add warn and step restriction, updated lite, handle quant mode

This commit is contained in:
Concedo 2024-03-01 16:41:19 +08:00
parent 3463688a0e
commit 80011ed8aa
4 changed files with 136 additions and 45 deletions

View file

@ -99,6 +99,7 @@ class sd_load_model_inputs(ctypes.Structure):
("cublas_info", ctypes.c_int),
("vulkan_info", ctypes.c_char_p),
("threads", ctypes.c_int),
("quant", ctypes.c_int),
("debugmode", ctypes.c_int)]
class sd_generation_inputs(ctypes.Structure):
@ -484,11 +485,16 @@ def sd_load_model(model_filename):
inputs.debugmode = args.debugmode
inputs.model_filename = model_filename.encode("UTF-8")
thds = args.threads
quant = 0
if len(args.sdconfig) > 2:
sdt = int(args.sdconfig[2])
if sdt > 0:
thds = sdt
if len(args.sdconfig) > 3:
quant = (1 if args.sdconfig[3]=="quant" else 0)
inputs.threads = thds
inputs.quant = quant
inputs = set_backend_props(inputs)
ret = handle.sd_load_model(inputs)
return ret
@ -502,11 +508,16 @@ def sd_generate(genparams):
seed = genparams.get("seed", -1)
sample_method = genparams.get("sampler_name", "euler a")
#clean vars
cfg_scale = (1 if cfg_scale < 1 else (20 if cfg_scale > 20 else cfg_scale))
sample_steps = (1 if sample_steps < 1 else (50 if sample_steps > 50 else sample_steps))
#quick mode
if args.sdconfig and len(args.sdconfig)>1 and args.sdconfig[1]=="quick":
cfg_scale = 1
sample_steps = 7
sample_method = "dpm++ 2m karras"
print("Image generation set to Quick Mode (Low Quality). Step counts, sampler, and cfg scale are fixed.")
inputs = sd_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
@ -1387,7 +1398,8 @@ def show_new_gui():
sd_model_var = ctk.StringVar()
sd_quick_var = ctk.IntVar(value=0)
sd_threads_var = ctk.StringVar()
sd_threads_var = ctk.StringVar(value=str(default_threads))
sd_quant_var = ctk.IntVar(value=0)
def tabbuttonaction(name):
for t in tabcontent:
@ -1866,6 +1878,7 @@ def show_new_gui():
makefileentry(images_tab, "Stable Diffusion Model (f16):", "Select Stable Diffusion Model File", sd_model_var, 1, filetypes=[("*.safetensors","*.safetensors")], tooltiptxt="Select a .safetensors Stable Diffusion model file on disk to be loaded.")
makecheckbox(images_tab, "Quick Mode (Low Quality)", sd_quick_var, 4,tooltiptxt="Force optimal generation settings for speed.")
makelabelentry(images_tab, "Image threads:" , sd_threads_var, 6, 50,"How many threads to use during image generation.\nIf left blank, uses same value as threads.")
makecheckbox(images_tab, "Compress Weights (Slight Memory Saved)", sd_quant_var, 8,tooltiptxt="Quantizes the SD model weights to save memory. May degrade quality.")
# launch
@ -1954,7 +1967,7 @@ def show_new_gui():
else:
args.hordeconfig = None if usehorde_var.get() == 0 else [horde_name_var.get(), horde_gen_var.get(), horde_context_var.get(), horde_apikey_var.get(), horde_workername_var.get()]
args.sdconfig = None if sd_model_var.get() == "" else [sd_model_var.get(), ("quick" if sd_quick_var.get()==1 else "normal"),(int(threads_var.get()) if sd_threads_var.get()=="" else int(sd_threads_var.get()))]
args.sdconfig = None if sd_model_var.get() == "" else [sd_model_var.get(), ("quick" if sd_quick_var.get()==1 else "normal"),(int(threads_var.get()) if sd_threads_var.get()=="" else int(sd_threads_var.get())),("quant" if sd_quant_var.get()==1 else "noquant")]
def import_vars(dict):
if "threads" in dict:
@ -2089,6 +2102,8 @@ def show_new_gui():
sd_quick_var.set(1 if dict["sdconfig"][1]=="quick" else 0)
if len(dict["sdconfig"]) > 2:
sd_threads_var.set(str(dict["sdconfig"][2]))
if len(dict["sdconfig"]) > 3:
    # sd_quant_var is an IntVar backing the "Compress Weights" checkbox, while
    # the saved config holds the literal "quant"/"noquant". Map it to 1/0 (as
    # is done for the "quick" flag) instead of writing the raw string into the
    # IntVar, which a later integer read via .get() could not parse.
    sd_quant_var.set(1 if dict["sdconfig"][3]=="quant" else 0)
def save_config():
file_type = [("KoboldCpp Settings", "*.kcpps")]
@ -2865,6 +2880,6 @@ if __name__ == '__main__':
parser.add_argument("--quiet", help="Enable quiet mode, which hides generation inputs and outputs in the terminal. Quiet mode is automatically enabled when running --hordeconfig.", action='store_true')
parser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick] [sd_threads]'), nargs='+')
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick] [sd_threads] [quant|noquant]'), nargs='+')
main(parser.parse_args(),start_server=True)