mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 09:02:04 +00:00
support loading multiple sd loras (up to 4 at once)
This commit is contained in:
parent
a089284d13
commit
bf3f2e1ba8
3 changed files with 106 additions and 49 deletions
97
koboldcpp.py
97
koboldcpp.py
|
|
@ -59,6 +59,7 @@ default_vae_tile_threshold = 768
|
|||
default_native_ctx = 16384
|
||||
overridekv_max = 4
|
||||
default_autofit_padding = 1024
|
||||
lora_filenames_max = 4
|
||||
|
||||
# abuse prevention
|
||||
stop_token_max = 256
|
||||
|
|
@ -311,7 +312,7 @@ class sd_load_model_inputs(ctypes.Structure):
|
|||
("clip1_filename", ctypes.c_char_p),
|
||||
("clip2_filename", ctypes.c_char_p),
|
||||
("vae_filename", ctypes.c_char_p),
|
||||
("lora_filename", ctypes.c_char_p),
|
||||
("lora_filenames", ctypes.c_char_p * lora_filenames_max),
|
||||
("lora_multiplier", ctypes.c_float),
|
||||
("lora_apply_mode", ctypes.c_int),
|
||||
("photomaker_filename", ctypes.c_char_p),
|
||||
|
|
@ -1931,7 +1932,7 @@ def sd_quant_option(value):
|
|||
except Exception:
|
||||
return 0
|
||||
|
||||
def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clip1_filename,clip2_filename,photomaker_filename,upscaler_filename):
|
||||
def sd_load_model(model_filename,vae_filename,lora_filenames,t5xxl_filename,clip1_filename,clip2_filename,photomaker_filename,upscaler_filename):
|
||||
global args
|
||||
inputs = sd_load_model_inputs()
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
|
|
@ -1954,7 +1955,12 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clip1
|
|||
inputs.taesd = True if args.sdvaeauto else False
|
||||
inputs.tiled_vae_threshold = args.sdtiledvae
|
||||
inputs.vae_filename = vae_filename.encode("UTF-8")
|
||||
inputs.lora_filename = lora_filename.encode("UTF-8")
|
||||
for n in range(lora_filenames_max):
|
||||
if n >= len(lora_filenames):
|
||||
inputs.lora_filenames[n] = "".encode("UTF-8")
|
||||
else:
|
||||
inputs.lora_filenames[n] = lora_filenames[n].encode("UTF-8")
|
||||
|
||||
inputs.lora_multiplier = args.sdloramult
|
||||
inputs.t5xxl_filename = t5xxl_filename.encode("UTF-8")
|
||||
inputs.clip1_filename = clip1_filename.encode("UTF-8")
|
||||
|
|
@ -5173,7 +5179,7 @@ def RunServerMultiThreaded(addr, port, server_handler):
|
|||
sys.exit(0)
|
||||
|
||||
# Based on https://github.com/mathgeniuszach/xdialog/blob/main/xdialog/zenity_dialogs.py - MIT license | - Expanded version by Henk717
|
||||
def zenity(filetypes=None, initialdir="", initialfile="", **kwargs) -> Tuple[int, str]:
|
||||
def zenity(filetypes=None, initialdir="", initialfile="", multiple=False, **kwargs) -> Tuple[int, object]:
|
||||
global zenity_recent_dir, zenity_permitted
|
||||
|
||||
if not zenity_permitted:
|
||||
|
|
@ -5238,6 +5244,10 @@ def zenity(filetypes=None, initialdir="", initialfile="", **kwargs) -> Tuple[int
|
|||
initialpath = os.path.join(initialdir, initialfile)
|
||||
args.append(f'--filename={initialpath}')
|
||||
|
||||
if multiple:
|
||||
args.append("--multiple")
|
||||
args.append("--separator=|")
|
||||
|
||||
clean_env = os.environ.copy()
|
||||
clean_env.pop("LD_LIBRARY_PATH", None)
|
||||
clean_env["PATH"] = "/usr/bin:/bin"
|
||||
|
|
@ -5252,15 +5262,18 @@ def zenity(filetypes=None, initialdir="", initialfile="", **kwargs) -> Tuple[int
|
|||
result = procres.stdout.decode('utf-8').strip()
|
||||
if procres.returncode==0 and result:
|
||||
directory = result
|
||||
if not os.path.isdir(result):
|
||||
directory = os.path.dirname(result)
|
||||
if multiple:
|
||||
result = tuple(result.split("|"))
|
||||
directory = result[0]
|
||||
if not os.path.isdir(directory):
|
||||
directory = os.path.dirname(directory)
|
||||
zenity_recent_dir = directory
|
||||
return (procres.returncode, result)
|
||||
|
||||
# note: In this section we wrap around file dialogues to allow for zenity
|
||||
def zentk_askopenfilename(**options):
|
||||
try:
|
||||
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), title=options.get("title"))[1]
|
||||
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), multiple=False, title=options.get("title"))[1]
|
||||
if result and not os.path.isfile(result):
|
||||
print("A folder was selected while we need a file, ignoring selection.")
|
||||
return ''
|
||||
|
|
@ -5269,9 +5282,21 @@ def zentk_askopenfilename(**options):
|
|||
result = askopenfilename(**options)
|
||||
return result
|
||||
|
||||
def zentk_askopenfilenames(**options):
    """Ask the user to pick one or more files, trying zenity before tkinter.

    Keyword options read for the zenity path: ``filetypes``, ``initialdir``
    and ``title``. The full ``options`` mapping is forwarded untouched to
    tkinter's ``askopenfilenames`` when the zenity path raises.

    Returns whatever the chosen dialog backend produced (the second element
    of zenity's result, or tkinter's return value). If any non-empty entry
    of the zenity selection is not an existing file, a notice is printed and
    an empty string is returned instead.
    """
    try:
        picked = zenity(
            filetypes=options.get("filetypes"),
            initialdir=options.get("initialdir"),
            multiple=True,
            title=options.get("title"),
        )[1]
        # One bad entry invalidates the whole selection; empty entries are skipped.
        if any(entry and not os.path.isfile(entry) for entry in picked):
            print("A folder was selected while we need a file, ignoring selection.")
            return ''
    except Exception:
        # Any failure on the zenity path falls back to the tkinter dialog.
        from tkinter.filedialog import askopenfilenames
        picked = askopenfilenames(**options)
    return picked
|
||||
|
||||
def zentk_askdirectory(**options):
|
||||
try:
|
||||
result = zenity(initialdir=options.get("initialdir"), title=options.get("title"), directory=True)[1]
|
||||
result = zenity(initialdir=options.get("initialdir"), multiple=False, title=options.get("title"), directory=True)[1]
|
||||
except Exception:
|
||||
from tkinter.filedialog import askdirectory
|
||||
result = askdirectory(**options)
|
||||
|
|
@ -5279,7 +5304,7 @@ def zentk_askdirectory(**options):
|
|||
|
||||
def zentk_asksaveasfilename(**options):
|
||||
try:
|
||||
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), initialfile=options.get("initialfile"), title=options.get("title"), save=True)[1]
|
||||
result = zenity(filetypes=options.get("filetypes"), initialdir=options.get("initialdir"), initialfile=options.get("initialfile"), multiple=False, title=options.get("title"), save=True)[1]
|
||||
except Exception:
|
||||
from tkinter.filedialog import asksaveasfilename
|
||||
result = asksaveasfilename(**options)
|
||||
|
|
@ -5724,7 +5749,7 @@ def show_gui():
|
|||
return entry, label
|
||||
|
||||
#file dialog types: 0=openfile,1=savefile,2=opendir
|
||||
def makefileentry(parent, text, searchtext, var, row=0, width=200, filetypes=[], onchoosefile=None, singlerow=False, singlecol=True, dialog_type=0, tooltiptxt=""):
|
||||
def makefileentry(parent, text, searchtext, var, row=0, width=200, filetypes=[], onchoosefile=None, singlerow=False, singlecol=True, dialog_type=0, tooltiptxt="", multiple=False):
|
||||
label = makelabel(parent, text, row,0,tooltiptxt,columnspan=3)
|
||||
def getfilename(var, text):
|
||||
initialDir = os.path.dirname(var.get())
|
||||
|
|
@ -5740,7 +5765,11 @@ def show_gui():
|
|||
fnam = str(fnam).strip()
|
||||
fnam = f"{fnam}.jsondb" if ".jsondb" not in fnam.lower() else fnam
|
||||
else:
|
||||
fnam = zentk_askopenfilename(title=text,filetypes=filetypes, initialdir=initialDir)
|
||||
if multiple:
|
||||
fnam = zentk_askopenfilenames(title=text,filetypes=filetypes, initialdir=initialDir)
|
||||
fnam = "|".join(fnam)
|
||||
else:
|
||||
fnam = zentk_askopenfilename(title=text,filetypes=filetypes, initialdir=initialDir)
|
||||
if fnam:
|
||||
var.set(fnam)
|
||||
if onchoosefile:
|
||||
|
|
@ -6383,7 +6412,7 @@ def show_gui():
|
|||
makelabelcombobox(images_tab, "Compress Weights: ", sd_quant_var, 8, width=(60), padx=(126), labelpadx=8, tooltiptxt="Quantizes the SD model weights to save memory.\nHigher levels save more memory, and cause more quality degradation.", values=sd_quant_choices)
|
||||
sd_quant_var.trace_add("write", changed_gpulayers_estimate)
|
||||
|
||||
makefileentry(images_tab, "Image LoRA:", "Select SD lora file",sd_lora_var, 20, width=160, singlerow=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded. Should be unquantized!")
|
||||
makefileentry(images_tab, "Image LoRA:", "Select SD lora file",sd_lora_var, 20, width=160, singlerow=True, filetypes=[("*.safetensors *.gguf", "*.safetensors *.gguf")],tooltiptxt="Select a .safetensors or .gguf SD LoRA model file to be loaded. Should be unquantized!", multiple=True)
|
||||
makelabelentry(images_tab, "Multiplier:" , sd_loramult_var, 20, 50,padx=(390),singleline=True,tooltip="What mutiplier value to apply the SD LoRA with.",labelpadx=(330))
|
||||
|
||||
makefileentry(images_tab, "T5-XXL File:", "Select T5-XXL model file (SD3, Flux, WAN)",sd_t5xxl_var, 24, width=280, singlerow=True, filetypes=[("*.safetensors *.gguf","*.safetensors *.gguf")],tooltiptxt="Select a .safetensors t5xxl file to be loaded.")
|
||||
|
|
@ -6711,10 +6740,10 @@ def show_gui():
|
|||
args.sdupscaler = sd_upscaler_var.get()
|
||||
args.sdquant = sd_quant_option(sd_quant_var.get())
|
||||
if sd_lora_var.get() != "":
|
||||
args.sdlora = sd_lora_var.get()
|
||||
args.sdlora = [item.strip() for item in sd_lora_var.get().split("|") if item]
|
||||
args.sdloramult = float(sd_loramult_var.get())
|
||||
else:
|
||||
args.sdlora = ""
|
||||
args.sdlora = None
|
||||
|
||||
if gen_defaults_var.get() != "":
|
||||
args.gendefaults = gen_defaults_var.get()
|
||||
|
|
@ -6959,8 +6988,13 @@ def show_gui():
|
|||
sd_upscaler_var.set(dict["sdupscaler"] if ("sdupscaler" in dict and dict["sdupscaler"]) else "")
|
||||
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
|
||||
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
|
||||
|
||||
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
|
||||
if "sdlora" in dict and dict["sdlora"]:
|
||||
if isinstance((dict["sdlora"]), list):
|
||||
sd_lora_var.set("|".join(dict["sdlora"]))
|
||||
else:
|
||||
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
|
||||
else:
|
||||
sd_lora_var.set("")
|
||||
sd_loramult_var.set(str(dict["sdloramult"]) if ("sdloramult" in dict and dict["sdloramult"]) else "1.0")
|
||||
gen_defaults_var.set(dict["gendefaults"] if ("gendefaults" in dict and dict["gendefaults"]) else "")
|
||||
gen_defaults_overwrite_var.set(1 if "gendefaultsoverwrite" in dict and dict["gendefaultsoverwrite"] else 0)
|
||||
|
|
@ -7401,6 +7435,8 @@ def convert_invalid_args(args):
|
|||
dict["gendefaults"] = dict["sdgendefaults"]
|
||||
if "flashattention" in dict and "noflashattention" not in dict:
|
||||
dict["noflashattention"] = not dict["flashattention"]
|
||||
if "sdlora" in dict and isinstance(dict["sdlora"], str):
|
||||
dict["sdlora"] = ([dict["sdlora"]] if dict["sdlora"] else None)
|
||||
return args
|
||||
|
||||
def setuptunnel(global_memory, has_sd):
|
||||
|
|
@ -8220,10 +8256,11 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
dlfile = download_model_from_url(args.sdvae,[".gguf",".safetensors"],min_file_size=500000)
|
||||
if dlfile:
|
||||
args.sdvae = dlfile
|
||||
if args.sdlora and args.sdlora!="":
|
||||
dlfile = download_model_from_url(args.sdlora,[".gguf",".safetensors"],min_file_size=500000)
|
||||
if dlfile:
|
||||
args.sdlora = dlfile
|
||||
if args.sdlora and len(args.sdlora)>0:
|
||||
for i in range(0,len(args.sdlora)):
|
||||
dlfile = download_model_from_url(args.sdlora[i],[".gguf",".safetensors"],min_file_size=500000)
|
||||
if dlfile:
|
||||
args.sdlora[i] = dlfile
|
||||
if args.mmproj and args.mmproj!="":
|
||||
dlfile = download_model_from_url(args.mmproj,[".gguf"],min_file_size=500000)
|
||||
if dlfile:
|
||||
|
|
@ -8499,18 +8536,20 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
exitcounter = 999
|
||||
exit_with_error(2,f"Cannot find image model file: {imgmodel}")
|
||||
else:
|
||||
imglora = ""
|
||||
imgloras = []
|
||||
imgvae = ""
|
||||
imgt5xxl = ""
|
||||
imgclip1 = ""
|
||||
imgclip2 = ""
|
||||
imgphotomaker = ""
|
||||
imgupscaler = ""
|
||||
if args.sdlora:
|
||||
if os.path.exists(args.sdlora):
|
||||
imglora = os.path.abspath(args.sdlora)
|
||||
else:
|
||||
print("Missing SD LORA model file...")
|
||||
if args.sdlora and len(args.sdlora)>0:
|
||||
for i in range (0,len(args.sdlora)):
|
||||
curr = args.sdlora[i]
|
||||
if os.path.exists(curr):
|
||||
imgloras.append(os.path.abspath(curr))
|
||||
else:
|
||||
print(f"Missing SD LORA model file {curr}...")
|
||||
if args.sdvae:
|
||||
if os.path.exists(args.sdvae):
|
||||
imgvae = os.path.abspath(args.sdvae)
|
||||
|
|
@ -8547,7 +8586,7 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
friendlysdmodelname = os.path.basename(imgmodel)
|
||||
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
|
||||
friendlysdmodelname = sanitize_string(friendlysdmodelname)
|
||||
loadok = sd_load_model(imgmodel,imgvae,imglora,imgt5xxl,imgclip1,imgclip2,imgphotomaker,imgupscaler)
|
||||
loadok = sd_load_model(imgmodel,imgvae,imgloras,imgt5xxl,imgclip1,imgclip2,imgphotomaker,imgupscaler)
|
||||
print("Load Image Model OK: " + str(loadok))
|
||||
if not loadok:
|
||||
exitcounter = 999
|
||||
|
|
@ -9008,8 +9047,8 @@ if __name__ == '__main__':
|
|||
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in tiny VAE via TAE SD, which is very fast, and fixed bad VAEs.", action='store_true')
|
||||
sdparsergrouplora = sdparsergroup.add_mutually_exclusive_group()
|
||||
sdparsergrouplora.add_argument("--sdquant", metavar=('[quantization level 0/1/2]'), help="If specified, loads the model quantized to save memory. 0=off, 1=q8, 2=q4", type=int, choices=[0,1,2], nargs="?", const=2, default=0)
|
||||
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify an image generation LORA safetensors model to be applied.", default="")
|
||||
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LORA model to be applied.", type=float, default=1.0)
|
||||
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify image generation LoRAs safetensors models to be applied. Multiple LoRAs are accepted.", nargs='+')
|
||||
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the image LoRA model to be applied.", type=float, default=1.0)
|
||||
sdparsergroup.add_argument("--sdtiledvae", metavar=('[maxres]'), help="Adjust the automatic VAE tiling trigger for images above this size. 0 disables vae tiling.", type=int, default=default_vae_tile_threshold)
|
||||
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
|
||||
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper .bin model to enable Speech-To-Text transcription.", default="")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue