Added SmartContext mode, a prompt-context manipulation scheme that avoids having to recalculate the full context so frequently.

Concedo 2023-04-14 21:24:16 +08:00
parent ca297c190f
commit adb4df78d6
6 changed files with 254 additions and 51 deletions
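
The commit message is terse, so here is a rough sketch of the idea (a simplified, hypothetical illustration only; the actual implementation lives mostly on the native side and is not shown in this diff): keep the tokens the model has already evaluated, re-evaluate only the part of the new prompt that diverges from them, and on overflow fall back to the most recent half of the window so the next few requests can again just append.

    # Hypothetical sketch of a "smart context" scheme. All names here
    # (SmartContext, tokens_to_process, ...) are invented for illustration
    # and do not appear in the commit.
    def common_prefix_len(a, b):
        # Number of leading tokens the two sequences share.
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    class SmartContext:
        def __init__(self, max_ctx):
            self.max_ctx = max_ctx
            self.cached = []  # tokens the model has already evaluated

        def tokens_to_process(self, prompt):
            if len(prompt) > self.max_ctx:
                # On overflow, keep only the most recent half so the next
                # few requests can append without another full rebuild.
                prompt = prompt[-(self.max_ctx // 2):]
            # Reuse whatever prefix the model has already evaluated; only
            # the divergent suffix needs to be processed again.
            keep = common_prefix_len(self.cached, prompt)
            self.cached = prompt
            return prompt[keep:]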


@@ -16,6 +16,7 @@ class load_model_inputs(ctypes.Structure):
("model_filename", ctypes.c_char_p),
("n_parts_overwrite", ctypes.c_int),
("use_mmap", ctypes.c_bool),
("use_smartcontext", ctypes.c_bool),
("clblast_info", ctypes.c_int)]
class generation_inputs(ctypes.Structure):
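
(Because ctypes maps _fields_ onto the C struct by declaration order, the native-side struct receiving load_model_inputs must gain a matching bool in the same position, which is why use_smartcontext is slotted in before clblast_info rather than appended at the end.)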
@@ -65,7 +66,7 @@ def init_library():
handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
handle.generate.restype = generation_outputs
-def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False):
+def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False,use_smartcontext=False):
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.batch_size = batch_size
@@ -74,6 +75,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwr
inputs.n_parts_overwrite = n_parts_overwrite
inputs.f16_kv = True
inputs.use_mmap = use_mmap
+inputs.use_smartcontext = use_smartcontext
clblastids = 0
if args.useclblast:
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
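
For illustration, a Python-side caller opts in through the new keyword argument; a hypothetical call (the model path and parameter values are placeholders, not taken from the commit) might look like:

    # Hypothetical invocation of the extended loader.
    loadok = load_model("model.bin", threads=6, use_mmap=True, use_smartcontext=True)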
@@ -383,8 +385,8 @@ def main(args):
mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
modelname = os.path.abspath(ggml_selected_file)
print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}]")
loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap))
print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}, SmartContext: {args.smartcontext}]")
loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap),args.smartcontext)
print("Load Model OK: " + str(loadok))
if not loadok:
@@ -413,7 +415,7 @@
RunServerMultiThreaded(args.host, args.port, embedded_kailite)
if __name__ == '__main__':
print("Welcome to KoboldCpp - Version 1.6") # just update version manually
print("Welcome to KoboldCpp - Version 1.7") # just update version manually
parser = argparse.ArgumentParser(description='Kobold llama.cpp server')
parser.add_argument("model_file", help="Model file to load", nargs="?")
portgroup = parser.add_mutually_exclusive_group() #we want to be backwards compatible with the unnamed positional args
@@ -430,6 +432,7 @@ if __name__ == '__main__':
parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
parser.add_argument("--psutil_set_threads", help="Experimental flag. If set, uses psutils to determine thread count based on physical cores.", action='store_true')
parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
compatgroup = parser.add_mutually_exclusive_group()
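
With the switch wired through argparse, enabling the feature from the command line is just a matter of passing the new flag, e.g. python koboldcpp.py model.bin --smartcontext (the model path here is a placeholder, not taken from the commit).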