mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
allow custom t5, clipl and clipg
This commit is contained in:
parent
3cfc4dc581
commit
ccbd630a42
3 changed files with 99 additions and 21 deletions
3
expose.h
3
expose.h
|
@ -143,6 +143,9 @@ struct sd_load_model_inputs
|
|||
const int threads = 0;
|
||||
const int quant = 0;
|
||||
const bool taesd = false;
|
||||
const char * t5xxl_filename = nullptr;
|
||||
const char * clipl_filename = nullptr;
|
||||
const char * clipg_filename = nullptr;
|
||||
const char * vae_filename = nullptr;
|
||||
const char * lora_filename = nullptr;
|
||||
const float lora_multiplier = 1.0f;
|
||||
|
|
71
koboldcpp.py
71
koboldcpp.py
|
@ -207,6 +207,9 @@ class sd_load_model_inputs(ctypes.Structure):
|
|||
("threads", ctypes.c_int),
|
||||
("quant", ctypes.c_int),
|
||||
("taesd", ctypes.c_bool),
|
||||
("t5xxl_filename", ctypes.c_char_p),
|
||||
("clipl_filename", ctypes.c_char_p),
|
||||
("clipg_filename", ctypes.c_char_p),
|
||||
("vae_filename", ctypes.c_char_p),
|
||||
("lora_filename", ctypes.c_char_p),
|
||||
("lora_multiplier", ctypes.c_float),
|
||||
|
@ -1098,7 +1101,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
|
|||
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
|
||||
|
||||
|
||||
def sd_load_model(model_filename,vae_filename,lora_filename):
|
||||
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl_filename,clipg_filename):
|
||||
global args
|
||||
inputs = sd_load_model_inputs()
|
||||
inputs.debugmode = args.debugmode
|
||||
|
@ -1120,6 +1123,9 @@ def sd_load_model(model_filename,vae_filename,lora_filename):
|
|||
inputs.vae_filename = vae_filename.encode("UTF-8")
|
||||
inputs.lora_filename = lora_filename.encode("UTF-8")
|
||||
inputs.lora_multiplier = args.sdloramult
|
||||
inputs.t5xxl_filename = t5xxl_filename.encode("UTF-8")
|
||||
inputs.clipl_filename = clipl_filename.encode("UTF-8")
|
||||
inputs.clipg_filename = clipg_filename.encode("UTF-8")
|
||||
inputs = set_backend_props(inputs)
|
||||
ret = handle.sd_load_model(inputs)
|
||||
return ret
|
||||
|
@ -2521,6 +2527,9 @@ def show_gui():
|
|||
sd_lora_var = ctk.StringVar()
|
||||
sd_loramult_var = ctk.StringVar(value="1.0")
|
||||
sd_vae_var = ctk.StringVar()
|
||||
sd_t5xxl_var = ctk.StringVar()
|
||||
sd_clipl_var = ctk.StringVar()
|
||||
sd_clipg_var = ctk.StringVar()
|
||||
sd_vaeauto_var = ctk.IntVar(value=0)
|
||||
sd_clamped_var = ctk.StringVar(value="0")
|
||||
sd_threads_var = ctk.StringVar(value=str(default_threads))
|
||||
|
@ -2603,6 +2612,10 @@ def show_gui():
|
|||
entry = ctk.CTkEntry(parent, width, textvariable=var)
|
||||
button = ctk.CTkButton(parent, 50, text="Browse", command= lambda a=var,b=searchtext:getfilename(a,b))
|
||||
if singlerow:
|
||||
if singlecol:
|
||||
entry.grid(row=row, column=0, padx=(84+8), stick="w")
|
||||
button.grid(row=row, column=0, padx=(84+width+12), stick="nw")
|
||||
else:
|
||||
entry.grid(row=row, column=1, padx=8, stick="w")
|
||||
button.grid(row=row, column=1, padx=(width+12), stick="nw")
|
||||
else:
|
||||
|
@ -2995,8 +3008,8 @@ def show_gui():
|
|||
makecheckbox(network_tab, "Quiet Mode", quietmode, 4,tooltiptxt="Prevents all generation related terminal output from being displayed.")
|
||||
makecheckbox(network_tab, "NoCertify Mode (Insecure)", nocertifymode, 4, 1,tooltiptxt="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.")
|
||||
|
||||
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=200 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
|
||||
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=200, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True,tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
|
||||
makefileentry(network_tab, "SSL Cert:", "Select SSL cert.pem file",ssl_cert_var, 5, width=200 ,filetypes=[("Unencrypted Certificate PEM", "*.pem")], singlerow=True, singlecol=False,tooltiptxt="Select your unencrypted .pem SSL certificate file for https.\nCan be generated with OpenSSL.")
|
||||
makefileentry(network_tab, "SSL Key:", "Select SSL key.pem file", ssl_key_var, 7, width=200, filetypes=[("Unencrypted Key PEM", "*.pem")], singlerow=True, singlecol=False, tooltiptxt="Select your unencrypted .pem SSL key file for https.\nCan be generated with OpenSSL.")
|
||||
makelabelentry(network_tab, "Password: ", password_var, 8, 200,tooltip="Enter a password required to use this instance.\nThis key will be required for all text endpoints.\nImage endpoints are not secured.")
|
||||
|
||||
# Horde Tab
|
||||
|
@ -3030,13 +3043,13 @@ def show_gui():
|
|||
# Image Gen Tab
|
||||
|
||||
images_tab = tabcontent["Image Gen"]
|
||||
makefileentry(images_tab, "Stable Diffusion Model (safetensors/gguf):", "Select Stable Diffusion Model File", sd_model_var, 1, width=280, singlecol=False, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")], tooltiptxt="Select a .safetensors or .gguf Stable Diffusion model file on disk to be loaded.")
|
||||
makelabelentry(images_tab, "Clamped Mode (Limit Resolution)", sd_clamped_var, 4, 50,tooltip="Limit generation steps and resolution settings for shared use.\nSet to 0 to disable, otherwise value is the size limit (min 512px).")
|
||||
makelabelentry(images_tab, "Image Threads:" , sd_threads_var, 6, 50,tooltip="How many threads to use during image generation.\nIf left blank, uses same value as threads.")
|
||||
makefileentry(images_tab, "Stable Diffusion Model (safetensors/gguf):", "Select Stable Diffusion Model File", sd_model_var, 1, width=280, singlecol=True, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")], tooltiptxt="Select a .safetensors or .gguf Stable Diffusion model file on disk to be loaded.")
|
||||
makelabelentry(images_tab, "Clamped Mode (Limit Resolution):", sd_clamped_var, 4, 50, padx=290,singleline=True,tooltip="Limit generation steps and resolution settings for shared use.\nSet to 0 to disable, otherwise value is the size limit (min 512px).")
|
||||
makelabelentry(images_tab, "Image Threads:" , sd_threads_var, 6, 50,padx=290,singleline=True,tooltip="How many threads to use during image generation.\nIf left blank, uses same value as threads.")
|
||||
sd_model_var.trace("w", gui_changed_modelfile)
|
||||
|
||||
sdloritem1,sdloritem2,sdloritem3 = makefileentry(images_tab, "Image LoRA (Must be non-quant):", "Select SD lora file",sd_lora_var, 10, width=280, singlecol=False, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded.")
|
||||
sdloritem4,sdloritem5 = makelabelentry(images_tab, "Image LoRA Multiplier:" , sd_loramult_var, 12, 50,tooltip="What mutiplier value to apply the SD LoRA with.")
|
||||
sdloritem1,sdloritem2,sdloritem3 = makefileentry(images_tab, "Image LoRA (Must be non-quant):", "Select SD lora file",sd_lora_var, 10, width=280, singlecol=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded.")
|
||||
sdloritem4,sdloritem5 = makelabelentry(images_tab, "Image LoRA Multiplier:" , sd_loramult_var, 12, 50,padx=290,singleline=True,tooltip="What mutiplier value to apply the SD LoRA with.")
|
||||
def togglesdquant(a,b,c):
|
||||
if sd_quant_var.get()==1:
|
||||
sdloritem1.grid_remove()
|
||||
|
@ -3045,6 +3058,7 @@ def show_gui():
|
|||
sdloritem4.grid_remove()
|
||||
sdloritem5.grid_remove()
|
||||
else:
|
||||
if not sdloritem1.grid_info() or not sdloritem2.grid_info() or not sdloritem3.grid_info() or not sdloritem4.grid_info() or not sdloritem5.grid_info():
|
||||
sdloritem1.grid()
|
||||
sdloritem2.grid()
|
||||
sdloritem3.grid()
|
||||
|
@ -3053,17 +3067,22 @@ def show_gui():
|
|||
makecheckbox(images_tab, "Compress Weights (Saves Memory)", sd_quant_var, 8,command=togglesdquant,tooltiptxt="Quantizes the SD model weights to save memory. May degrade quality.")
|
||||
sd_quant_var.trace("w", changed_gpulayers_estimate)
|
||||
|
||||
sdvaeitem1,sdvaeitem2,sdvaeitem3 = makefileentry(images_tab, "Image VAE:", "Select SD VAE file",sd_vae_var, 14, width=280, singlecol=False, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD VAE file to be loaded.")
|
||||
makefileentry(images_tab, "T5-XXL File:", "Select Optional T5-XXL model file (SD3 or flux)",sd_t5xxl_var, 14, width=280, singlerow=True, filetypes=[("*.safetensors", "*.safetensors")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
|
||||
makefileentry(images_tab, "Clip-L File:", "Select Optional Clip-L model file (SD3 or flux)",sd_clipl_var, 16, width=280, singlerow=True, filetypes=[("*.safetensors", "*.safetensors")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
|
||||
makefileentry(images_tab, "Clip-G File:", "Select Optional Clip-G model file (SD3)",sd_clipg_var, 18, width=280, singlerow=True, filetypes=[("*.safetensors", "*.safetensors")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
|
||||
|
||||
sdvaeitem1,sdvaeitem2,sdvaeitem3 = makefileentry(images_tab, "Image VAE:", "Select Optional SD VAE file",sd_vae_var, 20, width=280, singlerow=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD VAE file to be loaded.")
|
||||
def toggletaesd(a,b,c):
    """Show or hide the three custom-VAE file-entry widgets.

    When the TAE SD checkbox variable reads 1 the VAE widgets are removed
    from the grid; otherwise they are put back if any of them is currently
    not gridded.

    The a, b and c arguments are unused here; presumably they are whatever
    the checkbox/trace callback mechanism passes in -- confirm against the
    caller.
    """
    vae_widgets = (sdvaeitem1, sdvaeitem2, sdvaeitem3)
    if sd_vaeauto_var.get()==1:
        # TAE SD is selected, so the custom VAE picker is hidden.
        for widget in vae_widgets:
            widget.grid_remove()
    elif not all(widget.grid_info() for widget in vae_widgets):
        # At least one widget is not gridded: re-grid all three.
        for widget in vae_widgets:
            widget.grid()
|
||||
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 16,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
|
||||
makecheckbox(images_tab, "Use TAE SD (AutoFix Broken VAE)", sd_vaeauto_var, 22,command=toggletaesd,tooltiptxt="Replace VAE with TAESD. May fix bad VAE.")
|
||||
|
||||
# audio tab
|
||||
audio_tab = tabcontent["Audio"]
|
||||
|
@ -3246,6 +3265,12 @@ def show_gui():
|
|||
args.sdvae = ""
|
||||
if sd_vae_var.get() != "":
|
||||
args.sdvae = sd_vae_var.get()
|
||||
if sd_t5xxl_var.get() != "":
|
||||
args.sdt5xxl = sd_t5xxl_var.get()
|
||||
if sd_clipl_var.get() != "":
|
||||
args.sdclipl = sd_clipl_var.get()
|
||||
if sd_clipg_var.get() != "":
|
||||
args.sdclipg = sd_clipg_var.get()
|
||||
if sd_quant_var.get()==1:
|
||||
args.sdquant = True
|
||||
args.sdlora = ""
|
||||
|
@ -3396,6 +3421,9 @@ def show_gui():
|
|||
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
|
||||
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
|
||||
sd_vae_var.set(dict["sdvae"] if ("sdvae" in dict and dict["sdvae"]) else "")
|
||||
sd_t5xxl_var.set(dict["sdt5xxl"] if ("sdt5xxl" in dict and dict["sdt5xxl"]) else "")
|
||||
sd_clipl_var.set(dict["sdclipl"] if ("sdclipl" in dict and dict["sdclipl"]) else "")
|
||||
sd_clipg_var.set(dict["sdclipg"] if ("sdclipg" in dict and dict["sdclipg"]) else "")
|
||||
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
|
||||
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
|
||||
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
|
||||
|
@ -4307,6 +4335,9 @@ def main(launch_args,start_server=True):
|
|||
else:
|
||||
imglora = ""
|
||||
imgvae = ""
|
||||
imgt5xxl = ""
|
||||
imgclipl = ""
|
||||
imgclipg = ""
|
||||
if args.sdlora:
|
||||
if os.path.exists(args.sdlora):
|
||||
imglora = os.path.abspath(args.sdlora)
|
||||
|
@ -4317,13 +4348,28 @@ def main(launch_args,start_server=True):
|
|||
imgvae = os.path.abspath(args.sdvae)
|
||||
else:
|
||||
print(f"Missing SD VAE model file...")
|
||||
if args.sdt5xxl:
|
||||
if os.path.exists(args.sdt5xxl):
|
||||
imgt5xxl = os.path.abspath(args.sdt5xxl)
|
||||
else:
|
||||
print(f"Missing SD T5-XXL model file...")
|
||||
if args.sdclipl:
|
||||
if os.path.exists(args.sdclipl):
|
||||
imgclipl = os.path.abspath(args.sdclipl)
|
||||
else:
|
||||
print(f"Missing SD Clip-L model file...")
|
||||
if args.sdclipg:
|
||||
if os.path.exists(args.sdclipg):
|
||||
imgclipg = os.path.abspath(args.sdclipg)
|
||||
else:
|
||||
print(f"Missing SD Clip-G model file...")
|
||||
|
||||
imgmodel = os.path.abspath(imgmodel)
|
||||
fullsdmodelpath = imgmodel
|
||||
friendlysdmodelname = os.path.basename(imgmodel)
|
||||
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
|
||||
friendlysdmodelname = sanitize_string(friendlysdmodelname)
|
||||
loadok = sd_load_model(imgmodel,imgvae,imglora)
|
||||
loadok = sd_load_model(imgmodel,imgvae,imglora,imgt5xxl,imgclipl,imgclipg)
|
||||
print("Load Image Model OK: " + str(loadok))
|
||||
if not loadok:
|
||||
exitcounter = 999
|
||||
|
@ -4625,6 +4671,9 @@ if __name__ == '__main__':
|
|||
sdparsergroup.add_argument("--sdmodel", metavar=('[filename]'), help="Specify a stable diffusion safetensors or gguf model to enable image generation.", default="")
|
||||
sdparsergroup.add_argument("--sdthreads", metavar=('[threads]'), help="Use a different number of threads for image generation if specified. Otherwise, has the same value as --threads.", type=int, default=0)
|
||||
sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use. Accepts an extra optional parameter that indicates maximum resolution (eg. 768 clamps to 768x768, min 512px, disabled if 0).", nargs='?', const=512, type=int, default=0)
|
||||
sdparsergroup.add_argument("--sdt5xxl", metavar=('[filename]'), help="Specify a T5-XXL safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
|
||||
sdparsergroup.add_argument("--sdclipl", metavar=('[filename]'), help="Specify a Clip-L safetensors model for use in SD3 or Flux. Leave blank if prebaked or unused.", default="")
|
||||
sdparsergroup.add_argument("--sdclipg", metavar=('[filename]'), help="Specify a Clip-G safetensors model for use in SD3. Leave blank if prebaked or unused.", default="")
|
||||
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
|
||||
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify a stable diffusion safetensors VAE which replaces the one in the model.", default="")
|
||||
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
|
||||
|
|
|
@ -171,6 +171,9 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
std::string taesdpath = "";
|
||||
std::string lorafilename = inputs.lora_filename;
|
||||
std::string vaefilename = inputs.vae_filename;
|
||||
std::string t5xxl_filename = inputs.t5xxl_filename;
|
||||
std::string clipl_filename = inputs.clipl_filename;
|
||||
std::string clipg_filename = inputs.clipg_filename;
|
||||
printf("\nImageGen Init - Load Model: %s\n",inputs.model_filename);
|
||||
if(lorafilename!="")
|
||||
{
|
||||
|
@ -185,6 +188,18 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
{
|
||||
printf("With Custom VAE: %s\n",vaefilename.c_str());
|
||||
}
|
||||
if(t5xxl_filename!="")
|
||||
{
|
||||
printf("With Custom T5-XXL Model: %s\n",t5xxl_filename.c_str());
|
||||
}
|
||||
if(clipl_filename!="")
|
||||
{
|
||||
printf("With Custom Clip-L Model: %s\n",clipl_filename.c_str());
|
||||
}
|
||||
if(clipg_filename!="")
|
||||
{
|
||||
printf("With Custom Clip-G Model: %s\n",clipg_filename.c_str());
|
||||
}
|
||||
|
||||
//duplicated from expose.cpp
|
||||
int cl_parseinfo = inputs.clblast_info; //first digit is whether configured, second is platform, third is devices
|
||||
|
@ -219,6 +234,17 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
sd_params->batch_count = 1;
|
||||
sd_params->vae_path = vaefilename;
|
||||
sd_params->taesd_path = taesdpath;
|
||||
sd_params->t5xxl_path = t5xxl_filename;
|
||||
sd_params->clip_l_path = clipl_filename;
|
||||
sd_params->clip_g_path = clipg_filename;
|
||||
//if clip and t5 is set, and model is a gguf, load it as a diffusion model path
|
||||
bool endswithgguf = (sd_params->model_path.rfind(".gguf") == sd_params->model_path.size() - 5);
|
||||
if(sd_params->clip_l_path!="" && sd_params->t5xxl_path!="" && endswithgguf)
|
||||
{
|
||||
printf("\nSwap to Diffusion Model Path:%s",sd_params->model_path.c_str());
|
||||
sd_params->diffusion_model_path = sd_params->model_path;
|
||||
sd_params->model_path = "";
|
||||
}
|
||||
|
||||
sddebugmode = inputs.debugmode;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue