diff --git a/expose.h b/expose.h
index 1ef865e73..04565a357 100644
--- a/expose.h
+++ b/expose.h
@@ -55,6 +55,7 @@ struct load_model_inputs
     const int gpulayers = 0;
     const float rope_freq_scale = 1.0f;
     const float rope_freq_base = 10000.0f;
+    const bool flash_attention = false;
     const float tensor_split[tensor_split_max];
 };
 struct generation_inputs
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 422abcb42..90bd2dd05 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -785,12 +785,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     {
         kcpp_params->n_ubatch = (kcpp_params->n_batch>1024?1024:kcpp_params->n_batch);
     }
+    kcpp_params->flash_attn = inputs.flash_attention;
     modelname = kcpp_params->model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
     useContextShift = inputs.use_contextshift;
     debugmode = inputs.debugmode;
 
-
     auto clamped_max_context_length = inputs.max_context_length;
 
     if(clamped_max_context_length>16384 &&
@@ -1089,6 +1089,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
     }
 
+    llama_ctx_params.flash_attn = kcpp_params->flash_attn;
     llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);
 
     if (llama_ctx_v4 == NULL)
diff --git a/koboldcpp.py b/koboldcpp.py
index 3f82bd3c4..0c404a08f 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -56,6 +56,7 @@ class load_model_inputs(ctypes.Structure):
                 ("gpulayers", ctypes.c_int),
                 ("rope_freq_scale", ctypes.c_float),
                 ("rope_freq_base", ctypes.c_float),
+                ("flash_attention", ctypes.c_bool),
                 ("tensor_split", ctypes.c_float * tensor_split_max)]
 
 class generation_inputs(ctypes.Structure):
@@ -372,6 +373,7 @@ def load_model(model_filename):
     inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
     inputs.use_smartcontext = args.smartcontext
     inputs.use_contextshift = (0 if args.noshift else 1)
+    inputs.flash_attention = args.flashattention
     inputs.blasbatchsize = args.blasbatchsize
     inputs.forceversion = args.forceversion
     inputs.gpulayers = args.gpulayers
@@ -1662,6 +1664,7 @@ def show_new_gui():
     contextshift = ctk.IntVar(value=1)
     remotetunnel = ctk.IntVar(value=0)
     smartcontext = ctk.IntVar()
+    flashattention = ctk.IntVar(value=0)
     context_var = ctk.IntVar()
     customrope_var = ctk.IntVar()
     customrope_scale = ctk.StringVar(value="1.0")
@@ -2112,7 +2115,6 @@ def show_new_gui():
 
     # context size
     makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=3,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")
-
     customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale,tooltip="For Linear RoPE scaling. RoPE frequency scale.")
     customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base,tooltip="For NTK Aware Scaling. RoPE frequency base.")
     def togglerope(a,b,c):
@@ -2124,6 +2126,7 @@ def show_new_gui():
             item.grid_forget()
     makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
     togglerope(1,1,1)
+    makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28,tooltiptxt="Enable flash attention for GGUF models.")
     makefileentry(tokens_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 30,tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
 
     # Model Tab
@@ -2202,6 +2205,7 @@ def show_new_gui():
         args.highpriority = highpriority.get()==1
         args.nommap = disablemmap.get()==1
         args.smartcontext = smartcontext.get()==1
+        args.flashattention = flashattention.get()==1
         args.noshift = contextshift.get()==0
         args.remotetunnel = remotetunnel.get()==1
         args.foreground = keepforeground.get()==1
@@ -2286,6 +2290,7 @@ def show_new_gui():
         highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
         disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
         smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
+        flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
         contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
         remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
         keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
@@ -3322,5 +3327,6 @@ if __name__ == '__main__':
     parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
     parser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
     parser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
+    parser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')
 
     main(parser.parse_args(),start_server=True)
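
Usage sketch (assumes the patch above is applied; the model path is a placeholder): the new setting flows from the --flashattention CLI flag through load_model_inputs.flash_attention into llama_ctx_params.flash_attn, so it can be enabled at launch with:

    python koboldcpp.py --model /path/to/model.gguf --flashattention

The same toggle is exposed in the GUI as the "Use FlashAttention" checkbox on the Tokens tab, and it round-trips through saved settings files under the "flashattention" key.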