add flash attention toggle

This commit is contained in:
Concedo 2024-04-30 21:29:11 +08:00
parent 17a24d753c
commit c65448d17a
3 changed files with 10 additions and 2 deletions

View file

@ -55,6 +55,7 @@ struct load_model_inputs
const int gpulayers = 0;
const float rope_freq_scale = 1.0f;
const float rope_freq_base = 10000.0f;
const bool flash_attention = false;
const float tensor_split[tensor_split_max];
};
struct generation_inputs

View file

@ -785,12 +785,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
{
kcpp_params->n_ubatch = (kcpp_params->n_batch>1024?1024:kcpp_params->n_batch);
}
kcpp_params->flash_attn = inputs.flash_attention;
modelname = kcpp_params->model = inputs.model_filename;
useSmartContext = inputs.use_smartcontext;
useContextShift = inputs.use_contextshift;
debugmode = inputs.debugmode;
auto clamped_max_context_length = inputs.max_context_length;
if(clamped_max_context_length>16384 &&
@ -1089,6 +1089,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}
llama_ctx_params.flash_attn = kcpp_params->flash_attn;
llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);
if (llama_ctx_v4 == NULL)

View file

@ -56,6 +56,7 @@ class load_model_inputs(ctypes.Structure):
("gpulayers", ctypes.c_int),
("rope_freq_scale", ctypes.c_float),
("rope_freq_base", ctypes.c_float),
("flash_attention", ctypes.c_bool),
("tensor_split", ctypes.c_float * tensor_split_max)]
class generation_inputs(ctypes.Structure):
@ -372,6 +373,7 @@ def load_model(model_filename):
inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.flash_attention = args.flashattention
inputs.blasbatchsize = args.blasbatchsize
inputs.forceversion = args.forceversion
inputs.gpulayers = args.gpulayers
@ -1662,6 +1664,7 @@ def show_new_gui():
contextshift = ctk.IntVar(value=1)
remotetunnel = ctk.IntVar(value=0)
smartcontext = ctk.IntVar()
flashattention = ctk.IntVar(value=0)
context_var = ctk.IntVar()
customrope_var = ctk.IntVar()
customrope_scale = ctk.StringVar(value="1.0")
@ -2112,7 +2115,6 @@ def show_new_gui():
# context size
makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=3,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")
customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale,tooltip="For Linear RoPE scaling. RoPE frequency scale.")
customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base,tooltip="For NTK Aware Scaling. RoPE frequency base.")
def togglerope(a,b,c):
@ -2124,6 +2126,7 @@ def show_new_gui():
item.grid_forget()
makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
togglerope(1,1,1)
makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28,tooltiptxt="Enable flash attention for GGUF models.")
makefileentry(tokens_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 30,tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
# Model Tab
@ -2202,6 +2205,7 @@ def show_new_gui():
args.highpriority = highpriority.get()==1
args.nommap = disablemmap.get()==1
args.smartcontext = smartcontext.get()==1
args.flashattention = flashattention.get()==1
args.noshift = contextshift.get()==0
args.remotetunnel = remotetunnel.get()==1
args.foreground = keepforeground.get()==1
@ -2286,6 +2290,7 @@ def show_new_gui():
highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
@ -3322,5 +3327,6 @@ if __name__ == '__main__':
parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
parser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
parser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
parser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')
main(parser.parse_args(),start_server=True)