mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
sd: additional validation for the LoRA list (#2043)
* sd: additional validation for the LoRA list * sd: sanitize LoRA list before downloading
This commit is contained in:
parent
6e7b9a1549
commit
0c66ed863d
1 changed files with 16 additions and 14 deletions
30
koboldcpp.py
30
koboldcpp.py
|
|
@ -2187,6 +2187,15 @@ def sd_upscale(genparams):
|
|||
data_main = ret.data.decode("UTF-8","ignore")
|
||||
return data_main
|
||||
|
||||
def sanitize_lora_list(sdlora):
|
||||
if not sdlora:
|
||||
sdlora = []
|
||||
elif isinstance(sdlora, str):
|
||||
sdlora = [sdlora]
|
||||
elif not isinstance(sdlora, list):
|
||||
sdlora = []
|
||||
return sdlora
|
||||
|
||||
def sanitize_lora_multipliers(sdloramult):
|
||||
if sdloramult is None:
|
||||
sdloramult = [1.0]
|
||||
|
|
@ -7357,10 +7366,7 @@ def show_gui():
|
|||
if sd_upscaler_var.get() != "":
|
||||
args.sdupscaler = sd_upscaler_var.get()
|
||||
args.sdquant = sd_quant_option(sd_quant_var.get())
|
||||
if sd_lora_var.get() != "":
|
||||
args.sdlora = [item.strip() for item in sd_lora_var.get().split("|") if item]
|
||||
else:
|
||||
args.sdlora = None
|
||||
args.sdlora = [item.strip() for item in sd_lora_var.get().split("|") if item]
|
||||
# XXX the user may have used '|' since it's used for the LoRAs
|
||||
args.sdloramult = sanitize_lora_multipliers(re.split(r"[ |]+", sd_loramult_var.get()))
|
||||
|
||||
|
|
@ -7616,13 +7622,7 @@ def show_gui():
|
|||
sd_upscaler_var.set(dict["sdupscaler"] if ("sdupscaler" in dict and dict["sdupscaler"]) else "")
|
||||
sd_vaeauto_var.set(1 if ("sdvaeauto" in dict and dict["sdvaeauto"]) else 0)
|
||||
sd_tiled_vae_var.set(str(dict["sdtiledvae"]) if ("sdtiledvae" in dict and dict["sdtiledvae"]) else str(default_vae_tile_threshold))
|
||||
if "sdlora" in dict and dict["sdlora"]:
|
||||
if isinstance((dict["sdlora"]), list):
|
||||
sd_lora_var.set("|".join(dict["sdlora"]))
|
||||
else:
|
||||
sd_lora_var.set(dict["sdlora"] if ("sdlora" in dict and dict["sdlora"]) else "")
|
||||
else:
|
||||
sd_lora_var.set("")
|
||||
sd_lora_var.set("|".join(sanitize_lora_list(dict.get('sdlora'))))
|
||||
sd_loramult_var.set(" ".join(f"{n:.3f}".rstrip('0').rstrip('.') for n in dict.get("sdloramult", [])))
|
||||
gendefaults = (dict["gendefaults"] if ("gendefaults" in dict and dict["gendefaults"]) else "")
|
||||
if isinstance(gendefaults, type({})):
|
||||
|
|
@ -8072,8 +8072,8 @@ def convert_invalid_args(args):
|
|||
dict["gendefaults"] = dict["sdgendefaults"]
|
||||
if "flashattention" in dict and "noflashattention" not in dict:
|
||||
dict["noflashattention"] = not dict["flashattention"]
|
||||
if "sdlora" in dict and isinstance(dict["sdlora"], str):
|
||||
dict["sdlora"] = ([dict["sdlora"]] if dict["sdlora"] else None)
|
||||
if "sdlora" in dict:
|
||||
dict["sdlora"] = sanitize_lora_list(dict["sdlora"])
|
||||
if "sdloramult" in dict:
|
||||
dict["sdloramult"] = sanitize_lora_multipliers(dict["sdloramult"])
|
||||
return args
|
||||
|
|
@ -8905,6 +8905,9 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
if args.model_param and (args.benchmark or args.prompt or args.cli):
|
||||
start_server = False
|
||||
|
||||
args.sdlora = sanitize_lora_list(args.sdlora)
|
||||
args.sdloramult = sanitize_lora_multipliers(args.sdloramult)
|
||||
|
||||
#try to read story if provided
|
||||
if args.preloadstory:
|
||||
global preloaded_story
|
||||
|
|
@ -9328,7 +9331,6 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
imgphotomaker = ""
|
||||
imgupscaler = ""
|
||||
global imglora_preload, imglora_bypath, imglora_name2path
|
||||
args.sdloramult = sanitize_lora_multipliers(args.sdloramult)
|
||||
imglora_preload, imglora_bypath, imglora_name2path = mk_lora_info(args.sdlora, args.sdloramult)
|
||||
if args.sdvae:
|
||||
if os.path.exists(args.sdvae):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue