added lora support

Concedo 2023-04-22 12:29:38 +08:00
parent c454f8b848
commit 6e908c1792
4 changed files with 45 additions and 19 deletions

@@ -16,7 +16,7 @@ class load_model_inputs(ctypes.Structure):
                 ("f16_kv", ctypes.c_bool),
                 ("executable_path", ctypes.c_char_p),
                 ("model_filename", ctypes.c_char_p),
-                ("n_parts_overwrite", ctypes.c_int),
+                ("lora_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int),
@@ -89,17 +89,17 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs

-def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False,use_smartcontext=False,blasbatchsize=512):
+def load_model(model_filename):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
-    inputs.batch_size = batch_size
-    inputs.max_context_length = max_context_length #initial value to use for ctx, can be overwritten
-    inputs.threads = threads
-    inputs.n_parts_overwrite = n_parts_overwrite
+    inputs.lora_filename = args.lora.encode("UTF-8")
+    inputs.batch_size = 8
+    inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
+    inputs.threads = args.threads
     inputs.f16_kv = True
-    inputs.use_mmap = use_mmap
-    inputs.use_smartcontext = use_smartcontext
-    inputs.blasbatchsize = blasbatchsize
+    inputs.use_mmap = (not args.nommap)
+    inputs.use_smartcontext = args.smartcontext
+    inputs.blasbatchsize = args.blasbatchsize
     clblastids = 0
     if args.useclblast:
         clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
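Note: `load_model` no longer takes per-call tuning parameters; it now reads the module-level `args` namespace and `maxctx` directly. A minimal sketch of what those globals must provide, with hypothetical values:

```python
import argparse

# Hypothetical stand-ins for the globals the new load_model reads.
args = argparse.Namespace(lora="", threads=6, nommap=False,
                          smartcontext=False, blasbatchsize=512,
                          useclblast=None)
maxctx = 2048

# Callers now pass only the model path (filename here is illustrative):
# loadok = load_model("models/ggml-model-q4_0.bin")
```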
@@ -403,7 +403,7 @@ def main(args):
     embedded_kailite = None
     ggml_selected_file = args.model_param
     if not ggml_selected_file:
-        ggml_selected_file = args.model
+        ggml_selected_file = args.model
     if not ggml_selected_file:
         #give them a chance to pick a file
         print("For command line arguments, please refer to --help")
@@ -430,10 +430,17 @@
             time.sleep(2)
             sys.exit(2)

-    mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
+    if args.lora and args.lora!="":
+        if not os.path.exists(args.lora):
+            print(f"Cannot find lora file: {args.lora}")
+            time.sleep(2)
+            sys.exit(2)
+        else:
+            args.lora = os.path.abspath(args.lora)
+
     modelname = os.path.abspath(ggml_selected_file)
-    print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}, SmartContext: {args.smartcontext}]")
-    loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap),args.smartcontext,args.blasbatchsize)
+    print(f"Loading model: {modelname} \n[Threads: {args.threads}, SmartContext: {args.smartcontext}]")
+    loadok = load_model(modelname)
     print("Load Model OK: " + str(loadok))

     if not loadok:
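Note: since `sys.exit(2)` never returns, the `else:` in the new lora check is redundant (though harmless). An equivalent, slightly flatter sketch, written here as a hypothetical helper rather than the commit's inline code:

```python
import os, sys, time

def resolve_lora(lora: str) -> str:
    """Hypothetical helper mirroring the inline check above."""
    if lora and lora != "":
        if not os.path.exists(lora):
            print(f"Cannot find lora file: {lora}")
            time.sleep(2)
            sys.exit(2)  # never returns, so no else branch is needed
        return os.path.abspath(lora)
    return lora
```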
@@ -477,7 +484,8 @@ if __name__ == '__main__':
     portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
     parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
     parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
+    parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")

     #os.environ["OMP_NUM_THREADS"] = '12'
     # psutil.cpu_count(logical=False)
     physical_core_limit = 1
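Note: with the flag registered, an omitted `--lora` leaves `args.lora` as `""`, which is passed through to `lora_filename` unchanged, so the C++ side presumably treats an empty string as "no adapter". A small sketch of the parsing behavior (the adapter filename is a made-up example):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")

args = parser.parse_args(["--lora", "ggml-adapter-model.bin"])
print(args.lora)                          # -> ggml-adapter-model.bin
print(parser.parse_args([]).lora == "")   # flag omitted -> empty string
```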