added support for lora base

This commit is contained in:
Concedo 2023-06-10 19:29:45 +08:00
parent 375540837e
commit 66a3f4e421
4 changed files with 30 additions and 8 deletions

View file

@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
("executable_path", ctypes.c_char_p),
("model_filename", ctypes.c_char_p),
("lora_filename", ctypes.c_char_p),
("lora_base", ctypes.c_char_p),
("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool),
("use_smartcontext", ctypes.c_bool),
@ -146,7 +147,6 @@ def init_library():
def load_model(model_filename):
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.lora_filename = args.lora.encode("UTF-8")
inputs.batch_size = 8
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
inputs.threads = args.threads
@ -154,8 +154,13 @@ def load_model(model_filename):
inputs.f16_kv = True
inputs.use_mmap = (not args.nommap)
inputs.use_mlock = args.usemlock
if args.lora and args.lora!="":
inputs.lora_filename = ""
inputs.lora_base = ""
if args.lora:
inputs.lora_filename = args.lora[0].encode("UTF-8")
inputs.use_mmap = False
if len(args.lora) > 1:
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.unban_tokens = args.unbantokens
inputs.blasbatchsize = args.blasbatchsize
@ -744,13 +749,20 @@ def main(args):
time.sleep(2)
sys.exit(2)
if args.lora and args.lora!="":
if not os.path.exists(args.lora):
print(f"Cannot find lora file: {args.lora}")
if args.lora and args.lora[0]!="":
if not os.path.exists(args.lora[0]):
print(f"Cannot find lora file: {args.lora[0]}")
time.sleep(2)
sys.exit(2)
else:
args.lora = os.path.abspath(args.lora)
args.lora[0] = os.path.abspath(args.lora[0])
if len(args.lora) > 1:
if not os.path.exists(args.lora[1]):
print(f"Cannot find lora base: {args.lora[1]}")
time.sleep(2)
sys.exit(2)
else:
args.lora[1] = os.path.abspath(args.lora[1])
if args.psutil_set_threads:
import psutil
@ -807,7 +819,7 @@ if __name__ == '__main__':
portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
physical_core_limit = 1
if os.cpu_count()!=None and os.cpu_count()>1:
physical_core_limit = int(os.cpu_count()/2)