allow custom t5, clipl and clipg

This commit is contained in:
Concedo 2024-11-06 19:05:48 +08:00
parent 3cfc4dc581
commit ccbd630a42
3 changed files with 99 additions and 21 deletions

View file

@ -143,6 +143,9 @@ struct sd_load_model_inputs
const int threads = 0;
const int quant = 0;
const bool taesd = false;
const char * t5xxl_filename = nullptr;
const char * clipl_filename = nullptr;
const char * clipg_filename = nullptr;
const char * vae_filename = nullptr;
const char * lora_filename = nullptr;
const float lora_multiplier = 1.0f;

View file

@ -207,6 +207,9 @@ class sd_load_model_inputs(ctypes.Structure):
("threads", ctypes.c_int),
("quant", ctypes.c_int),
("taesd", ctypes.c_bool),
("t5xxl_filename", ctypes.c_char_p),
("clipl_filename", ctypes.c_char_p),
("clipg_filename", ctypes.c_char_p),
("vae_filename", ctypes.c_char_p),
("lora_filename", ctypes.c_char_p),
("lora_multiplier", ctypes.c_float),
@ -1098,7 +1101,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
def sd_load_model(model_filename,vae_filename,lora_filename):
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename):
global args
inputs = sd_load_model_inputs()
inputs.debugmode = args.debugmode
@ -1120,6 +1123,9 @@ def sd_load_model(model_filename,vae_filename,lora_filename):
inputs.vae_filename = vae_filename.encode("UTF-8")
inputs.lora_filename = lora_filename.encode("UTF-8")
inputs.lora_multiplier = args.sdloramult
inputs.t5xxl_filename = t5xxl_filename.encode("UTF-8")
inputs.clipl_filename = clipl_filename.encode("UTF-8")
inputs.clipg_filename = clipg_filename.encode("UTF-8")
inputs = set_backend_props(inputs)
ret = handle.sd_load_model(inputs)
return ret
@ -2521,6 +2527,9 @@ def show_gui():
sd_lora_var = ctk.StringVar()
sd_loramult_var = ctk.StringVar(value="1.0")
sd_vae_var = ctk.StringVar()
sd_t5xxl_var = ctk.StringVar()
sd_clipl_var = ctk.StringVar()
sd_clipg_var = ctk.StringVar()
sd_vaeauto_var = ctk.IntVar(value=0)
sd_clamped_var = ctk.StringVar(value="0")
sd_threads_var = ctk.StringVar(value=str(default_threads))
@ -2603,6 +2612,10 @@ def show_gui():
entry = ctk.CTkEntry(parent, width, textvariable=var)
button = ctk.CTkButton(parent, 50, text="Browse", command= lambda a=var,b=searchtext:getfilename(a,b))
if singlerow:
if singlecol:
entry.grid(row=row, column=0, padx=(84+8), stick="w")
button.grid(row=row, column=0, padx=(84+width+12), stick="nw")
else:
entry.grid(row=row, column=1, padx=8, stick="w")
button.grid(row=row, column=1, padx=(width+12), stick="nw")
else:
@ -2995,8 +3008,8 @@ def show_gui():
makecheckbox(network_tab, "Quiet Mode", quietmode, 4,tooltiptxt="Prevents all generation related terminal output from being displayed.")
makecheckbox(network_tab, "NoCertify Mode (Insecure)", nocertifymode, 4, 1,tooltiptxt="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.")
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=200 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=200, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=200 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True, singlecol=False,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=200, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True, singlecol=False, tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
makelabelentry(network_tab, "Password: ", password_var, 8, 200,tooltip="Enter a password required to use this instance.\nThis key will be required for all text endpoints.\nImage endpoints are not secured.")
# Horde Tab
@ -3030,13 +3043,13 @@ def show_gui():
# Image Gen Tab
images_tab = tabcontent["Image Gen"]
makefileentry(images_tab, "Stable Diffusion Model (safetensors/gguf):", "Select Stable Diffusion Model File", sd_model_var, 1, width=280, singlecol=False, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")], tooltiptxt="Select a .safetensors or .gguf Stable Diffusion model file on disk to be loaded.")
makelabelentry(images_tab, "Clamped Mode (Limit Resolution)", sd_clamped_var, 4, 50,tooltip="Limit generation steps and resolution settings for shared use.\nSet to 0 to disable, otherwise value is the size limit (min 512px).")
makelabelentry(images_tab, "Image Threads:" , sd_threads_var, 6, 50,tooltip="How many threads to use during image generation.\nIf left blank, uses same value as threads.")
makefileentry(images_tab, "Stable Diffusion Model (safetensors/gguf):", "Select Stable Diffusion Model File", sd_model_var, 1, width=280, singlecol=True, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")], tooltiptxt="Select a .safetensors or .gguf Stable Diffusion model file on disk to be loaded.")
makelabelentry(images_tab, "Clamped Mode (Limit Resolution):", sd_clamped_var, 4, 50, padx=290,singleline=True,tooltip="Limit generation steps and resolution settings for shared use.\nSet to 0 to disable, otherwise value is the size limit (min 512px).")
makelabelentry(images_tab, "Image Threads:" , sd_threads_var, 6, 50,padx=290,singleline=True,tooltip="How many threads to use during image generation.\nIf left blank, uses same value as threads.")
sd_model_var.trace("w", gui_changed_modelfile)
sdloritem1,sdloritem2,sdloritem3 = makefileentry(images_tab, "Image LoRA (Must be non-quant):", "Select SD lora file",sd_lora_var, 10, width=280, singlecol=False, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded.")
sdloritem4,sdloritem5 = makelabelentry(images_tab, "Image LoRA Multiplier:" , sd_loramult_var, 12, 50,tooltip="What mutiplier value to apply the SD LoRA with.")
sdloritem1,sdloritem2,sdloritem3 = makefileentry(images_tab, "Image LoRA (Must be non-quant):", "Select SD lora file",sd_lora_var, 10, width=280, singlecol=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded.")
sdloritem4,sdloritem5 = makelabelentry(images_tab, "Image LoRA Multiplier:" , sd_loramult_var, 12, 50,padx=290,singleline=True,tooltip="What mutiplier value to apply the SD LoRA with.")
def togglesdquant(a,b,c):
if sd_quant_var.get()==1:
sdloritem1.grid_remove()
@ -3045,6 +3058,7 @@ def show_gui():
sdloritem4.grid_remove()
sdloritem5.grid_remove()
else:
if not sdloritem1.grid_info() or not sdloritem2.grid_info() or not sdloritem3.grid_info() or not sdloritem4.grid_info() or not sdloritem5.grid_info():
sdloritem1.grid()
sdloritem2.grid()
sdloritem3.grid()
@ -3053,17 +3067,22 @@ def show_gui():
makecheckbox(images_tab, "Compress Weights (Saves Memory)", sd_quant_var, 8,command=togglesdquant,tooltiptxt="Quantizes the SD model weights to save memory. May degrade quality.")
sd_quant_var.trace("w", changed_gpulayers_estimate)
sdvaeitem1,sdvaeitem2,sdvaeitem3 = makefileentry(images_tab, "Image VAE:", "Select SD VAE file",sd_vae_var, 14, width=280, singlecol=False, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD VAE file to be loaded.")
makefileentry(images_tab, "T5-XXL File:", "Select Optional T5-XXL model file (SD3 or flux)",sd_t5xxl_var, 14, width=280, singlerow=True, filetypes=[("*.safetensors", "*.safetensors")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
makefileentry(images_tab, "Clip-L File:", "Select Optional Clip-L model file (SD3 or flux)",sd_clipl_var, 16, width=280, singlerow=True, filetypes=[("*.safetensors", "*.safetensors")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
makefileentry(images_tab, "Clip-G File:", "Select Optional Clip-G model file (SD3)",sd_clipg_var, 18, width=280, singlerow=True, filetypes=[("*.safetensors", "*.safetensors")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
sdvaeitem1,sdvaeitem2,sdvaeitem3 = makefileentry(images_tab, "Image VAE:", "Select Optional SD VAE file",sd_vae_var, 20, width=280, singlerow=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD VAE file to be loaded.")
def toggletaesd(a,b,c):
if sd_vaeauto_var.get()==1:
sdvaeitem1.grid_remove()
sdvaeitem2.grid_remove()
sdvaeitem3.grid_remove()
else:
if not sdvaeitem1.grid_info() or not sdvaeitem2.grid_info() or not sdvaeitem3.grid_info():
sdvaeitem1.grid()
sdvaeitem2.grid()
sdvaeitem3.grid()
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 16,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 22,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
# audio tab
audio_tab = tabcontent["Audio"]
@ -3246,6 +3265,12 @@ def show_gui():
args.sdvae = ""
if sd_vae_var.get() != "":
args.sdvae = sd_vae_var.get()
if sd_t5xxl_var.get() != "":
args.sdt5xxl = sd_t5xxl_var.get()
if sd_clipl_var.get() != "":
args.sdclipl = sd_clipl_var.get()
if sd_clipg_var.get() != "":
args.sdclipg = sd_clipg_var.get()
if sd_quant_var.get()==1:
args.sdquant = True
args.sdlora = ""
@ -3396,6 +3421,9 @@ def show_gui():
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
sd_clipg_var.set(dict["sdclipg"] if ("sdclipg" in dict and dict["sdclipg"]) else "")
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
@ -4307,6 +4335,9 @@ def main(launch_args,start_server=True):
else:
imglora = ""
imgvae = ""
imgt5xxl = ""
imgclipl = ""
imgclipg = ""
if args.sdlora:
if os.path.exists(args.sdlora):
imglora = os.path.abspath(args.sdlora)
@ -4317,13 +4348,28 @@ def main(launch_args,start_server=True):
imgvae = os.path.abspath(args.sdvae)
else:
print(f"Missing SD VAE model file...")
if args.sdt5xxl:
if os.path.exists(args.sdt5xxl):
imgt5xxl = os.path.abspath(args.sdt5xxl)
else:
print(f"Missing SD T5-XXL model file...")
if args.sdclipl:
if os.path.exists(args.sdclipl):
imgclipl = os.path.abspath(args.sdclipl)
else:
print(f"Missing SD Clip-L model file...")
if args.sdclipg:
if os.path.exists(args.sdclipg):
imgclipg = os.path.abspath(args.sdclipg)
else:
print(f"Missing SD Clip-G model file...")
imgmodel = os.path.abspath(imgmodel)
fullsdmodelpath = imgmodel
friendlysdmodelname = os.path.basename(imgmodel)
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
friendlysdmodelname = sanitize_string(friendlysdmodelname)
loadok = sd_load_model(imgmodel,imgvae,imglora)
loadok = sd_load_model(imgmodel,imgvae,imglora,imgt5xxl,imgclipl,imgclipg)
print("Load Image Model OK: " + str(loadok))
if not loadok:
exitcounter = 999
@ -4625,6 +4671,9 @@ if __name__ == '__main__':
sdparsergroup.add_argument("--sdmodel", metavar=('[filename]'), help="Specify a stable diffusion safetensors or gguf model to enable image generation.", default="")
sdparsergroup.add_argument("--sdthreads", metavar=('[threads]'), help="Use a different number of threads for image generation if specified. Otherwise, has the same value as --threads.", type=int, default=0)
sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use. Accepts an extra optional parameter that indicates maximum resolution (eg. 768 clamps to 768x768, min 512px, disabled if 0).", nargs='?', const=512, type=int, default=0)
sdparsergroup.add_argument("--sdt5xxl", metavar=('[filename]'), help="Specify a T5-XXL safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify a stable diffusion safetensors VAE which replaces the one in the model.", default="")
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')

View file

@ -171,6 +171,9 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
std::string taesdpath = "";
std::string lorafilename = inputs.lora_filename;
std::string vaefilename = inputs.vae_filename;
std::string t5xxl_filename = inputs.t5xxl_filename;
std::string clipl_filename = inputs.clipl_filename;
std::string clipg_filename = inputs.clipg_filename;
printf("\nImageGen Init - Load Model: %s\n",inputs.model_filename);
if(lorafilename!="")
{
@ -185,6 +188,18 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
{
printf("With Custom VAE: %s\n",vaefilename.c_str());
}
if(t5xxl_filename!="")
{
printf("With Custom T5-XXL Model: %s\n",t5xxl_filename.c_str());
}
if(clipl_filename!="")
{
printf("With Custom Clip-L Model: %s\n",clipl_filename.c_str());
}
if(clipg_filename!="")
{
printf("With Custom Clip-G Model: %s\n",clipg_filename.c_str());
}
//duplicated from expose.cpp
int cl_parseinfo = inputs.clblast_info; //first digit is whether configured, second is platform, third is devices
@ -219,6 +234,17 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
sd_params->batch_count = 1;
sd_params->vae_path = vaefilename;
sd_params->taesd_path = taesdpath;
sd_params->t5xxl_path = t5xxl_filename;
sd_params->clip_l_path = clipl_filename;
sd_params->clip_g_path = clipg_filename;
//if clip and t5 is set, and model is a gguf, load it as a diffusion model path
bool endswithgguf = (sd_params->model_path.rfind(".gguf") == sd_params->model_path.size() - 5);
if(sd_params->clip_l_path!="" && sd_params->t5xxl_path!="" && endswithgguf)
{
printf("\nSwap to Diffusion Model Path:%s",sd_params->model_path.c_str());
sd_params->diffusion_model_path = sd_params->model_path;
sd_params->model_path = "";
}
sddebugmode = inputs.debugmode;