diff --git a/expose.h b/expose.h index e922d3f07..f471cc417 100644 --- a/expose.h +++ b/expose.h @@ -45,6 +45,7 @@ struct load_model_inputs const bool use_mlock = false; const bool use_smartcontext = false; const bool use_contextshift = false; + const bool use_fastforward = false; const int clblast_info = 0; const int cublas_info = 0; const char * vulkan_info = nullptr; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 5c20b2509..e7bb8351b 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1688,6 +1688,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in kcpp_data->model_filename = inputs.model_filename; kcpp_data->use_smartcontext = inputs.use_smartcontext; kcpp_data->use_contextshift = inputs.use_contextshift; + kcpp_data->use_fastforward = inputs.use_fastforward; debugmode = inputs.debugmode; auto clamped_max_context_length = inputs.max_context_length; @@ -2951,7 +2952,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) { if(!blank_prompt) { - ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true); + if(kcpp_data->use_fastforward) + { + ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true); + } } if(is_mamba || is_rwkv_new) { @@ -2971,12 +2975,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs) bool triggersc = kcpp_data->use_smartcontext; if(!blank_prompt) //special case for blank prompts, no fast forward or shifts { - if(kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC)) + if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC)) { PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx); triggersc = false; } - ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false); + 
if(kcpp_data->use_fastforward) + { + ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false); + } } if(file_format == FileFormat::GGUF_GENERIC) { diff --git a/koboldcpp.py b/koboldcpp.py index ca9992d58..a4fa47bf8 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -129,6 +129,7 @@ class load_model_inputs(ctypes.Structure): ("use_mlock", ctypes.c_bool), ("use_smartcontext", ctypes.c_bool), ("use_contextshift", ctypes.c_bool), + ("use_fastforward", ctypes.c_bool), ("clblast_info", ctypes.c_int), ("cublas_info", ctypes.c_int), ("vulkan_info", ctypes.c_char_p), @@ -869,6 +870,7 @@ def load_model(model_filename): inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8") inputs.use_smartcontext = args.smartcontext inputs.use_contextshift = (0 if args.noshift else 1) + inputs.use_fastforward = (0 if args.nofastforward else 1) inputs.flash_attention = args.flashattention if args.quantkv>0: inputs.quant_k = inputs.quant_v = args.quantkv @@ -2494,6 +2496,7 @@ def show_gui(): rowsplit_var = ctk.IntVar() contextshift = ctk.IntVar(value=1) + fastforward = ctk.IntVar(value=1) remotetunnel = ctk.IntVar(value=0) smartcontext = ctk.IntVar() flashattention = ctk.IntVar(value=0) @@ -2739,10 +2742,16 @@ def show_gui(): gpu_choice_var.trace("w", changed_gpu_choice_var) gpulayers_var.trace("w", changed_gpulayers_estimate) + def togglefastforward(a,b,c): + if fastforward.get()==0: + contextshift.set(0) + togglectxshift(1,1,1) + def togglectxshift(a,b,c): if contextshift.get()==0: smartcontextbox.grid() else: + fastforward.set(1) smartcontextbox.grid_remove() if contextshift.get()==0 and flashattention.get()==1: @@ -2951,7 +2960,7 @@ def show_gui(): # tokens checkboxes smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext, 1,tooltiptxt="Uses SmartContext. 
Now considered outdated and not recommended.\nCheck the wiki for more info.") makecheckbox(tokens_tab, "Use ContextShift", contextshift, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift) - + makecheckbox(tokens_tab, "Use FastForwarding", fastforward, 3,tooltiptxt="Use fast forwarding to recycle previous context (always reprocess if disabled).\nRecommended.", command=togglefastforward) # context size makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, width=280, set=5,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.") @@ -3151,6 +3160,7 @@ def show_gui(): args.smartcontext = smartcontext.get()==1 args.flashattention = flashattention.get()==1 args.noshift = contextshift.get()==0 + args.nofastforward = fastforward.get()==0 args.remotetunnel = remotetunnel.get()==1 args.foreground = keepforeground.get()==1 args.quiet = quietmode.get()==1 @@ -3300,6 +3310,7 @@ def show_gui(): smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0) flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0) contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1) + fastforward.set(0 if "nofastforward" in dict and dict["nofastforward"] else 1) remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0) keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0) quietmode.set(1 if "quiet" in dict and dict["quiet"] else 0) @@ -4645,6 +4656,7 @@ if __name__ == '__main__': advparser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0) advparser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. 
Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+') advparser.add_argument("--noshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true') + advparser.add_argument("--nofastforward", help="If set, do not attempt to fast forward GGUF context (always reprocess). Will also enable --noshift.", action='store_true') advparser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true') advparser.add_argument("--usemlock", help="Enables mlock, preventing the RAM used to load the model from being paged out. Not usually recommended.", action='store_true') advparser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices.", action='store_true') diff --git a/otherarch/otherarch.h b/otherarch/otherarch.h index a69669453..07edf3904 100644 --- a/otherarch/otherarch.h +++ b/otherarch/otherarch.h @@ -54,6 +54,7 @@ struct kcpp_params { bool flash_attn = false; // flash attention bool use_smartcontext = false; bool use_contextshift = false; + bool use_fastforward = false; }; // default hparams (GPT-J 6B)