diff --git a/koboldcpp.py b/koboldcpp.py index 370933d04..36efc4b18 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -2023,19 +2023,12 @@ def sd_load_model(model_filename,vae_filename,lora_filenames,t5xxl_filename,clip lora_filenames = [lf.encode("UTF-8") for lf in lora_filenames[:lora_filenames_max] if lf] lora_len = len(lora_filenames) - lora_multipliers = args.sdloramult[:lora_len] - if len(lora_multipliers) < lora_len: - missing = lora_len - len(lora_multipliers) - if len(lora_multipliers) == 1: - # previous behavior: all get the same weight - lora_multipliers.extend(lora_multipliers * missing) - else: - lora_multipliers.extend([0.] * missing) + lora_multipliers = prepare_lora_multipliers([]) inputs.lora_len = lora_len inputs.lora_filenames = (ctypes.c_char_p * lora_len)(*lora_filenames) inputs.lora_multipliers = (ctypes.c_float * lora_len)(*lora_multipliers) # auto if no zero-weight lora, dynamic otherwise - inputs.lora_apply_mode = 3 if 0. in inputs.lora_multipliers else 0 + inputs.lora_apply_mode = 3 if 0. in lora_multipliers else 0 inputs.img_hard_limit = args.sdclamped inputs.img_soft_limit = args.sdclampedsoft @@ -8491,20 +8484,19 @@ def mk_lora_info(imgloras, multipliers): # the full filename as a path, but we don't know if we can expose it used_lora_names = set() result = [] + first_multiplier = multipliers[0] if len(multipliers) > 0 else 1. for i, lora_path in enumerate(imgloras): - multiplier = 0. if i >= len(multipliers) else multipliers[i] + multiplier = multipliers[i] if i < len(multipliers) else first_multiplier lora_file = os.path.basename(lora_path) lora_name, lora_ext = os.path.splitext(lora_file) # ensure unique names i = 1 mapped_name = lora_name - while True: - if mapped_name not in used_lora_names: - result.append((lora_path, mapped_name, mapped_name + lora_ext, multiplier)) - used_lora_names.add(mapped_name) - break + while mapped_name in used_lora_names: i += 1 mapped_name = lora_name + '_' + str(i) + used_lora_names.add(mapped_name) + result.append((lora_path, mapped_name, mapped_name + lora_ext, multiplier)) return result diff --git a/tests/test_koboldcpp.py b/tests/test_koboldcpp.py index 403007709..11f942b61 100644 --- a/tests/test_koboldcpp.py +++ b/tests/test_koboldcpp.py @@ -53,6 +53,17 @@ def extract_loras_from_prompt(*args, **kwargs): return koboldcpp.extract_loras_from_prompt(*args, **kwargs) +def mk_lora_info(*args, **kwargs): + """ + >>> mk_lora_info(['/x/lora1.safetensors', '/y/lora2.gguf'], []) + [('/x/lora1.safetensors', 'lora1', 'lora1.safetensors', 1.0), ('/y/lora2.gguf', 'lora2', 'lora2.gguf', 1.0)] + >>> mk_lora_info(['/x/lora1.safetensors', '/y/lora1.safetensors'], [0.3]) + [('/x/lora1.safetensors', 'lora1', 'lora1.safetensors', 0.3), ('/y/lora1.safetensors', 'lora1_2', 'lora1_2.safetensors', 0.3)] + >>> mk_lora_info(['./lora1.gguf', '/y/lora2.gguf', 'lora3.gguf'], [0, 0.3]) + [('./lora1.gguf', 'lora1', 'lora1.gguf', 0), ('/y/lora2.gguf', 'lora2', 'lora2.gguf', 0.3), ('lora3.gguf', 'lora3', 'lora3.gguf', 0)] + """ + return koboldcpp.mk_lora_info(*args, **kwargs) + def sanitize_lora_multipliers(*args, **kwargs): """ >>> sanitize_lora_multipliers(None)