context shift feature done

This commit is contained in:
Concedo 2023-10-29 18:21:39 +08:00
parent 338d6c265d
commit 7924592a83
4 changed files with 41 additions and 18 deletions

View file

@@ -34,6 +34,7 @@ class load_model_inputs(ctypes.Structure):
("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool),
("use_smartcontext", ctypes.c_bool),
("use_contextshift", ctypes.c_bool),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
("blasbatchsize", ctypes.c_int),
@@ -227,6 +228,7 @@ def load_model(model_filename):
if len(args.lora) > 1:
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (not args.nocontextshift)
inputs.blasbatchsize = args.blasbatchsize
inputs.forceversion = args.forceversion
inputs.gpulayers = args.gpulayers
@@ -1045,6 +1047,7 @@ def show_new_gui():
version_var = ctk.StringVar(value="0")
tensor_split_str_vars = ctk.StringVar(value="")
contextshift = ctk.IntVar(value=1)
smartcontext = ctk.IntVar()
context_var = ctk.IntVar()
customrope_var = ctk.IntVar()
@@ -1142,7 +1145,7 @@ def show_new_gui():
makeslider(quick_tab, "BLAS Batch Size:", blasbatchsize_text, blas_size_var, 0, 7, 12, set=5)
# quick boxes
quick_boxes = {"Launch Browser": launchbrowser , "High Priority" : highpriority, "Use SmartContext":smartcontext, "Disable MMAP":disablemmap,}
quick_boxes = {"Launch Browser": launchbrowser , "High Priority" : highpriority, "Use SmartContext":smartcontext, "Disable MMAP":disablemmap,"Use ContextShift":contextshift}
for idx, name, in enumerate(quick_boxes):
makecheckbox(quick_tab, name, quick_boxes[name], int(idx/2) +20, idx%2)
# context size
@@ -1194,7 +1197,7 @@ def show_new_gui():
# Tokens Tab
tokens_tab = tabcontent["Tokens"]
# tokens checkboxes
token_boxes = {"Use SmartContext":smartcontext}
token_boxes = {"Use SmartContext":smartcontext, "Use ContextShift":contextshift}
for idx, name, in enumerate(token_boxes):
makecheckbox(tokens_tab, name, token_boxes[name], idx + 1)
@@ -1273,6 +1276,7 @@ def show_new_gui():
args.highpriority = highpriority.get()==1
args.nommap = disablemmap.get()==1
args.smartcontext = smartcontext.get()==1
args.nocontextshift = contextshift.get()==0
args.foreground = keepforeground.get()==1
gpuchoiceidx = 0
@@ -1336,6 +1340,7 @@ def show_new_gui():
highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
contextshift.set(0 if "nocontextshift" in dict and dict["nocontextshift"] else 1)
keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
if "useclblast" in dict and dict["useclblast"]:
if clblast_option is not None:
@@ -1833,7 +1838,7 @@ def main(launch_args,start_server=True):
modelname = os.path.abspath(args.model_param)
print(args)
print(f"==========\nLoading model: {modelname} \n[Threads: {args.threads}, BlasThreads: {args.blasthreads}, SmartContext: {args.smartcontext}]")
print(f"==========\nLoading model: {modelname} \n[Threads: {args.threads}, BlasThreads: {args.blasthreads}, SmartContext: {args.smartcontext}, ContextShift: {not (args.nocontextshift)}]")
loadok = load_model(modelname)
print("Load Model OK: " + str(loadok))
@@ -1917,6 +1922,7 @@ if __name__ == '__main__':
parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512). Setting it to -1 disables BLAS mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,32,64,128,256,512,1024,2048], default=512)
parser.add_argument("--ropeconfig", help="If set, uses customized RoPE scaling from configured frequency scale and frequency base (e.g. --ropeconfig 0.25 10000). Otherwise, uses NTK-Aware scaling set automatically based on context size. For linear rope, simply set the freq-scale and ignore the freq-base",metavar=('[rope-freq-scale]', '[rope-freq-base]'), default=[0.0, 10000.0], type=float, nargs='+')
parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
parser.add_argument("--nocontextshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')