Merged the upstream updates for model loading code, and ditched the legacy llama loaders since they were no longer needed.

This commit is contained in:
Concedo 2023-04-10 12:00:34 +08:00
commit f53238f570
20 changed files with 1234 additions and 1446 deletions

View file

@ -14,7 +14,8 @@ class load_model_inputs(ctypes.Structure):
("batch_size", ctypes.c_int),
("f16_kv", ctypes.c_bool),
("model_filename", ctypes.c_char_p),
("n_parts_overwrite", ctypes.c_int)]
("n_parts_overwrite", ctypes.c_int),
("use_mmap", ctypes.c_bool)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
@ -53,7 +54,7 @@ def init_library():
handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
handle.generate.restype = generation_outputs
def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6):
def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False):
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.batch_size = batch_size
@ -61,6 +62,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwr
inputs.threads = threads
inputs.n_parts_overwrite = n_parts_overwrite
inputs.f16_kv = True
inputs.use_mmap = use_mmap
ret = handle.load_model(inputs)
return ret
@ -347,7 +349,7 @@ def main(args):
mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
modelname = os.path.abspath(ggml_selected_file)
print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}]")
loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads)
loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,args.usemmap)
print("Load Model OK: " + str(loadok))
if not loadok:
@ -369,7 +371,7 @@ def main(args):
if args.host=="":
epurl = f"http://localhost:{args.port}" + ("?streaming=1" if args.stream else "")
else:
epurl = f"http://{args.host}:{args.port}" + ("&streaming=1" if args.stream else "")
epurl = f"http://{args.host}:{args.port}" + ("?streaming=1" if args.stream else "")
print(f"Please connect to custom endpoint at {epurl}")
@ -394,5 +396,6 @@ if __name__ == '__main__':
parser.add_argument("--psutil_set_threads", help="Experimental flag. If set, uses psutils to determine thread count based on physical cores.", action='store_true')
parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
parser.add_argument("--noblas", help="Do not use OpenBLAS for accelerated prompt ingestion", action='store_true')
parser.add_argument("--usemmap", help="Use mmap to load newer models (default false)", action='store_true')
args = parser.parse_args()
main(args)