diff --git a/expose.h b/expose.h
index 418224dd6..2242f58a0 100644
--- a/expose.h
+++ b/expose.h
@@ -44,6 +44,7 @@ struct load_model_inputs
     const char * mmproj_filename;
     const bool use_mmap;
     const bool use_mlock;
+    const bool use_smartcontext;
     const bool use_contextshift;
     const int clblast_info = 0;
     const int cublas_info = 0;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 7293124a9..91bf28aa0 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -92,6 +92,7 @@ static int current_llava_identifier = LLAVA_TOKEN_IDENTIFIER_A;
 static gpt_params * kcpp_params = nullptr;
 static int max_context_limit_at_load = 0;
 static int n_past = 0;
+static bool useSmartContext = false;
 static bool useContextShift = false;
 static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
 static std::string modelname;
@@ -786,6 +787,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }
     kcpp_params->flash_attn = inputs.flash_attention;
     modelname = kcpp_params->model = inputs.model_filename;
+    useSmartContext = inputs.use_smartcontext;
     useContextShift = inputs.use_contextshift;
     debugmode = inputs.debugmode;
 
@@ -1939,7 +1941,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
         else
         {
-            bool triggersc = useContextShift;
+            bool triggersc = useSmartContext;
             if(useContextShift && (file_format == FileFormat::GGUF_GENERIC))
             {
                 PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
diff --git a/koboldcpp.py b/koboldcpp.py
index 9d318e9ac..3b7e377bb 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -45,6 +45,7 @@ class load_model_inputs(ctypes.Structure):
                 ("mmproj_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_mlock", ctypes.c_bool),
+                ("use_smartcontext", ctypes.c_bool),
                 ("use_contextshift", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int),
                 ("cublas_info", ctypes.c_int),
@@ -371,6 +372,7 @@ def load_model(model_filename):
         inputs.lora_base = args.lora[1].encode("UTF-8")
 
     inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
+    inputs.use_smartcontext = args.smartcontext
    inputs.use_contextshift = (0 if args.noshift else 1)
    inputs.flash_attention = args.flashattention
    inputs.blasbatchsize = args.blasbatchsize
@@ -1673,6 +1675,7 @@ def show_new_gui():
 
     contextshift = ctk.IntVar(value=1)
     remotetunnel = ctk.IntVar(value=0)
+    smartcontext = ctk.IntVar()
     flashattention = ctk.IntVar(value=0)
     context_var = ctk.IntVar()
     customrope_var = ctk.IntVar()
@@ -1956,7 +1959,10 @@ def show_new_gui():
     gpulayers_var.trace("w", changed_gpulayers)
 
     def togglectxshift(a,b,c):
-        pass
+        if contextshift.get()==0:
+            smartcontextbox.grid(row=1, column=0, padx=8, pady=1, stick="nw")
+        else:
+            smartcontextbox.grid_forget()
 
     def guibench():
         args.benchmark = "stdout"
@@ -2110,6 +2116,7 @@ def show_new_gui():
     # Tokens Tab
     tokens_tab = tabcontent["Tokens"]
     # tokens checkboxes
+    smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext, 1,tooltiptxt="Uses SmartContext. Now considered outdated and not recommended.\nCheck the wiki for more info.")
     makecheckbox(tokens_tab, "Use ContextShift", contextshift, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
     togglectxshift(1,1,1)
 
@@ -2206,6 +2213,7 @@ def show_new_gui():
         args.launch = launchbrowser.get()==1
         args.highpriority = highpriority.get()==1
         args.nommap = disablemmap.get()==1
+        args.smartcontext = smartcontext.get()==1
         args.flashattention = flashattention.get()==1
         args.noshift = contextshift.get()==0
         args.remotetunnel = remotetunnel.get()==1
@@ -2301,6 +2309,7 @@ def show_new_gui():
         launchbrowser.set(1 if "launch" in dict and dict["launch"] else 0)
         highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
         disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
+        smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
         flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
         contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
         remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
@@ -2730,7 +2739,7 @@ def check_deprecation_warning():
     if using_outdated_flags:
         print(f"\n=== !!! IMPORTANT WARNING !!! ===")
         print("You are using one or more OUTDATED config files or launch flags!")
-        print("The flags --smartcontext, --hordeconfig and --sdconfig have been DEPRECATED, and MAY be REMOVED in future!")
+        print("The flags --hordeconfig and --sdconfig have been DEPRECATED, and MAY be REMOVED in future!")
         print("They will still work for now, but you SHOULD switch to the updated flags instead, to avoid future issues!")
         print("New flags are: --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen --sdmodel --sdthreads --sdquant --sdclamped")
         print("For more information on these flags, please check --help")
@@ -3393,11 +3402,7 @@ if __name__ == '__main__':
     advparser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
     advparser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')
     advparser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
-
-    deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
-    deprecatedgroup.add_argument("--smartcontext", help="Command is DEPRECATED and should NOT be used! Instead, use --noshift instead to toggle smartcontext off on old GGML models.", action='store_true')
-    deprecatedgroup.add_argument("--hordeconfig", help="Command is DEPRECATED and should NOT be used! Instead, use non-positional flags --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen instead.", nargs='+')
-    deprecatedgroup.add_argument("--sdconfig", help="Command is DEPRECATED and should NOT be used! Instead, use non-positional flags --sdmodel --sdthreads --sdquant --sdclamped instead.", nargs='+')
+    advparser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently. Not recommended.", action='store_true')
 
     hordeparsergroup = parser.add_argument_group('Horde Worker Commands')
     hordeparsergroup.add_argument("--hordemodelname", metavar=('[name]'), help="Sets your AI Horde display model name.", default="")
@@ -3412,4 +3417,8 @@ if __name__ == '__main__':
     sdparsergroup.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
     sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true')
 
+    deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
+    deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
+    deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')
+
     main(parser.parse_args(),start_server=True)