mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
Removed junk, fixed some bugs, and added support for a dynamic number of sharded files
Merge remote-tracking branch 'origin/master' into concedo # Conflicts: # README.md
This commit is contained in:
commit
f952b7c613
14 changed files with 40 additions and 312 deletions
|
@@ -10,7 +10,8 @@ class load_model_inputs(ctypes.Structure):
|
|||
_fields_ = [("threads", ctypes.c_int),
|
||||
("max_context_length", ctypes.c_int),
|
||||
("batch_size", ctypes.c_int),
|
||||
("model_filename", ctypes.c_char_p)]
|
||||
("model_filename", ctypes.c_char_p),
|
||||
("n_parts_overwrite", ctypes.c_int)]
|
||||
|
||||
class generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("seed", ctypes.c_int),
|
||||
|
@@ -27,19 +28,20 @@ class generation_outputs(ctypes.Structure):
|
|||
("text", ctypes.c_char * 16384)]
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
handle = ctypes.CDLL(dir_path + "/llamalib.dll")
|
||||
handle = ctypes.CDLL(dir_path + "/llamacpp.dll")
|
||||
|
||||
handle.load_model.argtypes = [load_model_inputs]
|
||||
handle.load_model.restype = ctypes.c_bool
|
||||
handle.generate.argtypes = [generation_inputs]
|
||||
handle.generate.restype = generation_outputs
|
||||
|
||||
def load_model(model_filename,batch_size=8,max_context_length=512,threads=4):
|
||||
def load_model(model_filename,batch_size=8,max_context_length=512,threads=4,n_parts_overwrite=-1):
|
||||
inputs = load_model_inputs()
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
inputs.batch_size = batch_size
|
||||
inputs.max_context_length = max_context_length
|
||||
inputs.threads = threads
|
||||
inputs.n_parts_overwrite = n_parts_overwrite
|
||||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
|
@@ -233,9 +235,13 @@ if __name__ == '__main__':
|
|||
print("Cannot find model file: " + sys.argv[1])
|
||||
exit()
|
||||
|
||||
mdl_nparts = 1
|
||||
for n in range(1,9):
|
||||
if os.path.exists(sys.argv[1]+"."+str(n)):
|
||||
mdl_nparts += 1
|
||||
modelname = os.path.abspath(sys.argv[1])
|
||||
print("Loading model: " + modelname)
|
||||
loadok = load_model(modelname,128,maxctx,4)
|
||||
loadok = load_model(modelname,128,maxctx,4,mdl_nparts)
|
||||
print("Load Model OK: " + str(loadok))
|
||||
|
||||
if loadok:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue