Mirror of https://github.com/LostRuins/koboldcpp.git

Commit c65448d17a (parent 17a24d753c): add flash attention toggle

3 changed files with 10 additions and 2 deletions
expose.h:

@@ -55,6 +55,7 @@ struct load_model_inputs
 const int gpulayers = 0;
 const float rope_freq_scale = 1.0f;
 const float rope_freq_base = 10000.0f;
+const bool flash_attention = false;
 const float tensor_split[tensor_split_max];
 };
 struct generation_inputs
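The new flash_attention field defaults to false, so callers that never set it keep the previous behaviour. Because load_model_inputs is filled in from Python and handed to the native library through ctypes, the Python-side mirror of this struct (updated later in this diff) must declare the same field, in the same order, with a matching type. A minimal, self-contained sketch of that mirroring; tensor_split_max is assumed to be 16 here, and the class is deliberately truncated, so it is an illustration rather than the library's real layout:

import ctypes

tensor_split_max = 16  # assumption for this sketch; the real constant lives alongside expose.h

# Truncated illustration only: the real load_model_inputs has many more fields, so this class
# is NOT ABI-compatible with the actual library. The point is that field order and types must
# match the C struct; after any mismatch, every later member is read from the wrong offset.
class load_model_inputs(ctypes.Structure):
    _fields_ = [("gpulayers", ctypes.c_int),
                ("rope_freq_scale", ctypes.c_float),
                ("rope_freq_base", ctypes.c_float),
                ("flash_attention", ctypes.c_bool),   # mirrors `const bool flash_attention = false;`
                ("tensor_split", ctypes.c_float * tensor_split_max)]

inputs = load_model_inputs()
inputs.flash_attention = True     # what --flashattention or the GUI checkbox ultimately sets
print(inputs.flash_attention)     # True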
@@ -785,12 +785,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 {
 kcpp_params->n_ubatch = (kcpp_params->n_batch>1024?1024:kcpp_params->n_batch);
 }
+kcpp_params->flash_attn = inputs.flash_attention;
 modelname = kcpp_params->model = inputs.model_filename;
 useSmartContext = inputs.use_smartcontext;
 useContextShift = inputs.use_contextshift;
 debugmode = inputs.debugmode;

 auto clamped_max_context_length = inputs.max_context_length;

 if(clamped_max_context_length>16384 &&
@@ -1089,6 +1089,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 }
 }

+llama_ctx_params.flash_attn = kcpp_params->flash_attn;
 llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);

 if (llama_ctx_v4 == NULL)
@@ -56,6 +56,7 @@ class load_model_inputs(ctypes.Structure):
 ("gpulayers", ctypes.c_int),
 ("rope_freq_scale", ctypes.c_float),
 ("rope_freq_base", ctypes.c_float),
+("flash_attention", ctypes.c_bool),
 ("tensor_split", ctypes.c_float * tensor_split_max)]

 class generation_inputs(ctypes.Structure):
@@ -372,6 +373,7 @@ def load_model(model_filename):
 inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
 inputs.use_smartcontext = args.smartcontext
 inputs.use_contextshift = (0 if args.noshift else 1)
+inputs.flash_attention = args.flashattention
 inputs.blasbatchsize = args.blasbatchsize
 inputs.forceversion = args.forceversion
 inputs.gpulayers = args.gpulayers
@@ -1662,6 +1664,7 @@ def show_new_gui():
 contextshift = ctk.IntVar(value=1)
 remotetunnel = ctk.IntVar(value=0)
 smartcontext = ctk.IntVar()
+flashattention = ctk.IntVar(value=0)
 context_var = ctk.IntVar()
 customrope_var = ctk.IntVar()
 customrope_scale = ctk.StringVar(value="1.0")
@@ -2112,7 +2115,6 @@ def show_new_gui():
 # context size
 makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=3,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")

 customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale,tooltip="For Linear RoPE scaling. RoPE frequency scale.")
 customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base,tooltip="For NTK Aware Scaling. RoPE frequency base.")
 def togglerope(a,b,c):
@@ -2124,6 +2126,7 @@ def show_new_gui():
 item.grid_forget()
 makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
 togglerope(1,1,1)
+makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28,tooltiptxt="Enable flash attention for GGUF models.")
 makefileentry(tokens_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 30,tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")

 # Model Tab
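makecheckbox is koboldcpp's own GUI helper; the pattern it wraps is an IntVar-backed customtkinter checkbox whose value is later collapsed into a boolean launch argument. A rough standalone sketch of that pattern, assuming customtkinter is installed and imported as ctk (as koboldcpp.py does), with a plain CTkCheckBox standing in for the makecheckbox helper:

import customtkinter as ctk

root = ctk.CTk()
flashattention = ctk.IntVar(value=0)   # unchecked by default, as in show_new_gui()

# Stand-in for makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28, ...)
box = ctk.CTkCheckBox(root, text="Use FlashAttention", variable=flashattention)
box.pack(padx=12, pady=12)

def export_args():
    # Same collapse the GUI performs on launch: args.flashattention = flashattention.get()==1
    return {"flashattention": flashattention.get() == 1}

ctk.CTkButton(root, text="Launch", command=lambda: print(export_args())).pack(pady=(0, 12))
root.mainloop()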
@@ -2202,6 +2205,7 @@ def show_new_gui():
 args.highpriority = highpriority.get()==1
 args.nommap = disablemmap.get()==1
 args.smartcontext = smartcontext.get()==1
+args.flashattention = flashattention.get()==1
 args.noshift = contextshift.get()==0
 args.remotetunnel = remotetunnel.get()==1
 args.foreground = keepforeground.get()==1
@@ -2286,6 +2290,7 @@ def show_new_gui():
 highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
 disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
 smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
+flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
 contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
 remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
 keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
@@ -3322,5 +3327,6 @@ if __name__ == '__main__':
 parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
 parser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
 parser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
+parser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')

 main(parser.parse_args(),start_server=True)
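With the parser change, the toggle can be requested from the command line (for example `python koboldcpp.py --model yourmodel.gguf --flashattention`, where the model path is only a placeholder); load_model() then copies args.flashattention into inputs.flash_attention. A quick, standard-library-only check of how the flag parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')

print(parser.parse_args([]).flashattention)                    # False (default, toggle off)
print(parser.parse_args(["--flashattention"]).flashattention)  # True; load_model() forwards this to inputs.flash_attention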