diff --git a/expose.cpp b/expose.cpp
index aa0daaaea..67d634d56 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -32,6 +32,7 @@ extern "C"
     {
         std::string model = inputs.model_filename;
         lora_filename = inputs.lora_filename;
+        lora_base = inputs.lora_base;
         int forceversion = inputs.forceversion;
diff --git a/expose.h b/expose.h
index 8ec9fa42a..b9e97eafa 100644
--- a/expose.h
+++ b/expose.h
@@ -11,6 +11,7 @@ struct load_model_inputs
     const char * executable_path;
     const char * model_filename;
     const char * lora_filename;
+    const char * lora_base;
     const bool use_mmap;
     const bool use_mlock;
     const bool use_smartcontext;
@@ -49,5 +50,6 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
+extern std::string lora_base;
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index adb30672a..44997ff2f 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -31,6 +31,7 @@
 //shared
 std::string executable_path = "";
 std::string lora_filename = "";
+std::string lora_base = "";
 bool generation_finished;
 std::vector<std::string> generated_tokens;
@@ -341,9 +342,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     {
         printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());

+        const char * lora_base_arg = NULL;
+        if (lora_base != "") {
+            printf("Using LORA base model: %s\n", lora_base.c_str());
+            lora_base_arg = lora_base.c_str();
+        }
+
         int err = llama_v2_apply_lora_from_file(llama_ctx_v2,
                                                 lora_filename.c_str(),
-                                                NULL,
+                                                lora_base_arg,
                                                 n_threads);
         if (err != 0)
         {
diff --git a/koboldcpp.py b/koboldcpp.py
index 851296531..6d4540ab1 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
                 ("executable_path", ctypes.c_char_p),
                 ("model_filename", ctypes.c_char_p),
                 ("lora_filename", ctypes.c_char_p),
+                ("lora_base", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_mlock", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
@@ -146,7 +147,6 @@ def init_library():
 def load_model(model_filename):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
-    inputs.lora_filename = args.lora.encode("UTF-8")
     inputs.batch_size = 8
     inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
     inputs.threads = args.threads
@@ -154,8 +154,13 @@ def load_model(model_filename):
     inputs.f16_kv = True
     inputs.use_mmap = (not args.nommap)
     inputs.use_mlock = args.usemlock
-    if args.lora and args.lora!="":
+    inputs.lora_filename = "".encode("UTF-8")
+    inputs.lora_base = "".encode("UTF-8")
+    if args.lora:
+        inputs.lora_filename = args.lora[0].encode("UTF-8")
         inputs.use_mmap = False
+        if len(args.lora) > 1:
+            inputs.lora_base = args.lora[1].encode("UTF-8")
     inputs.use_smartcontext = args.smartcontext
     inputs.unban_tokens = args.unbantokens
     inputs.blasbatchsize = args.blasbatchsize
@@ -744,13 +749,20 @@ def main(args):
         time.sleep(2)
         sys.exit(2)

-    if args.lora and args.lora!="":
-        if not os.path.exists(args.lora):
-            print(f"Cannot find lora file: {args.lora}")
+    if args.lora and args.lora[0]!="":
+        if not os.path.exists(args.lora[0]):
+            print(f"Cannot find lora file: {args.lora[0]}")
             time.sleep(2)
             sys.exit(2)
         else:
-            args.lora = os.path.abspath(args.lora)
+            args.lora[0] = os.path.abspath(args.lora[0])
+        if len(args.lora) > 1:
+            if not os.path.exists(args.lora[1]):
+                print(f"Cannot find lora base: {args.lora[1]}")
+                time.sleep(2)
+                sys.exit(2)
+            else:
+                args.lora[1] = os.path.abspath(args.lora[1])

     if args.psutil_set_threads:
         import psutil
@@ -807,7 +819,7 @@ if __name__ == '__main__':
     portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
     parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
     parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
-    parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
+    parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
     physical_core_limit = 1
     if os.cpu_count()!=None and os.cpu_count()>1:
         physical_core_limit = int(os.cpu_count()/2)
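
Usage sketch (not part of the patch): with this change, --lora accepts one or two positional values instead of a single string, so the first value is the adapter file and the optional second value is the base model. The model and adapter file names below are hypothetical:

    python koboldcpp.py model_q4_0.bin --lora my_adapter.bin
    python koboldcpp.py model_q4_0.bin --lora my_adapter.bin model_f16.bin

In the second form, the extra value is forwarded to llama_v2_apply_lora_from_file as the base-model path; as with --lora-base in upstream llama.cpp, this is typically a higher-precision (e.g. f16) copy of the model used as the source of original tensors when the adapter is applied on top of a quantized model. In either form, passing --lora still forces use_mmap off, since applying the adapter modifies the loaded weights.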