Removed junk, fixed some bugs and support dynamic number of sharded files

Merge remote-tracking branch 'origin/master' into concedo

# Conflicts:
#	README.md
This commit is contained in:
Concedo 2023-03-19 11:13:00 +08:00
commit f952b7c613
14 changed files with 40 additions and 312 deletions

View file

@ -10,7 +10,8 @@ class load_model_inputs(ctypes.Structure):
_fields_ = [("threads", ctypes.c_int),
("max_context_length", ctypes.c_int),
("batch_size", ctypes.c_int),
("model_filename", ctypes.c_char_p)]
("model_filename", ctypes.c_char_p),
("n_parts_overwrite", ctypes.c_int)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
@ -27,19 +28,20 @@ class generation_outputs(ctypes.Structure):
("text", ctypes.c_char * 16384)]
dir_path = os.path.dirname(os.path.realpath(__file__))
handle = ctypes.CDLL(dir_path + "/llamalib.dll")
handle = ctypes.CDLL(dir_path + "/llamacpp.dll")
handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool
handle.generate.argtypes = [generation_inputs]
handle.generate.restype = generation_outputs
def load_model(model_filename,batch_size=8,max_context_length=512,threads=4):
def load_model(model_filename,batch_size=8,max_context_length=512,threads=4,n_parts_overwrite=-1):
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.batch_size = batch_size
inputs.max_context_length = max_context_length
inputs.threads = threads
inputs.n_parts_overwrite = n_parts_overwrite
ret = handle.load_model(inputs)
return ret
@ -233,9 +235,13 @@ if __name__ == '__main__':
print("Cannot find model file: " + sys.argv[1])
exit()
mdl_nparts = 1
for n in range(1,9):
if os.path.exists(sys.argv[1]+"."+str(n)):
mdl_nparts += 1
modelname = os.path.abspath(sys.argv[1])
print("Loading model: " + modelname)
loadok = load_model(modelname,128,maxctx,4)
loadok = load_model(modelname,128,maxctx,4,mdl_nparts)
print("Load Model OK: " + str(loadok))
if loadok: