mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
add flash attention toggle
This commit is contained in:
parent 17a24d753c
commit c65448d17a
3 changed files with 10 additions and 2 deletions
expose.h (+1)

@@ -55,6 +55,7 @@ struct load_model_inputs
     const int gpulayers = 0;
     const float rope_freq_scale = 1.0f;
     const float rope_freq_base = 10000.0f;
+    const bool flash_attention = false;
     const float tensor_split[tensor_split_max];
 };
 struct generation_inputs
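The Python launcher mirrors this struct with a ctypes.Structure, so the new field has to be declared there in the same position and with a matching type (ctypes.c_bool). A minimal, standalone sketch of that requirement, using only a reduced subset of the real fields and an assumed value for tensor_split_max:

import ctypes

tensor_split_max = 16  # assumed value for this sketch; the real constant is defined next to the struct

class load_model_inputs_sketch(ctypes.Structure):
    # field order and types must match expose.h exactly, or every later field is misread
    _fields_ = [("gpulayers", ctypes.c_int),
                ("rope_freq_scale", ctypes.c_float),
                ("rope_freq_base", ctypes.c_float),
                ("flash_attention", ctypes.c_bool),   # new field, mirrors the C-side bool added above
                ("tensor_split", ctypes.c_float * tensor_split_max)]

inputs = load_model_inputs_sketch()
print(inputs.flash_attention)   # False by default, matching the C-side default of false
inputs.flash_attention = True
print(inputs.flash_attention)   # True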
gpttype_adapter.cpp (+2)

@@ -785,12 +785,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     {
         kcpp_params->n_ubatch = (kcpp_params->n_batch>1024?1024:kcpp_params->n_batch);
     }
+    kcpp_params->flash_attn = inputs.flash_attention;
     modelname = kcpp_params->model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
     useContextShift = inputs.use_contextshift;
     debugmode = inputs.debugmode;

     auto clamped_max_context_length = inputs.max_context_length;

     if(clamped_max_context_length>16384 &&

@@ -1089,6 +1089,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
     }

+    llama_ctx_params.flash_attn = kcpp_params->flash_attn;
     llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);

     if (llama_ctx_v4 == NULL)
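These two hunks are where the flag crosses the language boundary: the launcher fills the load_model_inputs struct and hands it to the library's exported load_model() entry point, and gpttype_load_model copies inputs.flash_attention into kcpp_params->flash_attn and then into llama_ctx_params.flash_attn before the llama context is created. A rough sketch of that hand-off from the Python side, assuming the load_model_inputs class from koboldcpp.py is in scope; the library path and return-type handling here are illustrative assumptions, not the launcher's real loading logic:

import ctypes

handle = ctypes.CDLL("./koboldcpp_default.so")      # assumed library name for this sketch
handle.load_model.argtypes = [load_model_inputs]    # the ctypes.Structure defined in koboldcpp.py
handle.load_model.restype = ctypes.c_int            # assumed: treat the result code as a plain int

inputs = load_model_inputs()
inputs.model_filename = b"model.gguf"               # placeholder path
inputs.flash_attention = True                       # the new toggle
result = handle.load_model(inputs)                  # the flag now reaches llama_new_context_with_model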
koboldcpp.py (+7)

@@ -56,6 +56,7 @@ class load_model_inputs(ctypes.Structure):
                 ("gpulayers", ctypes.c_int),
                 ("rope_freq_scale", ctypes.c_float),
                 ("rope_freq_base", ctypes.c_float),
+                ("flash_attention", ctypes.c_bool),
                 ("tensor_split", ctypes.c_float * tensor_split_max)]

 class generation_inputs(ctypes.Structure):
@@ -372,6 +373,7 @@ def load_model(model_filename):
     inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
     inputs.use_smartcontext = args.smartcontext
     inputs.use_contextshift = (0 if args.noshift else 1)
+    inputs.flash_attention = args.flashattention
     inputs.blasbatchsize = args.blasbatchsize
     inputs.forceversion = args.forceversion
     inputs.gpulayers = args.gpulayers
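This assignment relies on argparse's store_true behaviour: when the flag is omitted, args.flashattention is False, so inputs.flash_attention keeps the same off-by-default behaviour as the C-side field. A standalone sketch of that default:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')

print(parser.parse_args([]).flashattention)                    # False: flag omitted, flash attention stays off
print(parser.parse_args(["--flashattention"]).flashattention)  # True: flag given, toggle is enabled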
@@ -1662,6 +1664,7 @@ def show_new_gui():
     contextshift = ctk.IntVar(value=1)
     remotetunnel = ctk.IntVar(value=0)
     smartcontext = ctk.IntVar()
+    flashattention = ctk.IntVar(value=0)
     context_var = ctk.IntVar()
     customrope_var = ctk.IntVar()
     customrope_scale = ctk.StringVar(value="1.0")

@@ -2112,7 +2115,6 @@ def show_new_gui():
     # context size
     makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 20, set=3,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")

     customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale,tooltip="For Linear RoPE scaling. RoPE frequency scale.")
     customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base,tooltip="For NTK Aware Scaling. RoPE frequency base.")
     def togglerope(a,b,c):

@@ -2124,6 +2126,7 @@ def show_new_gui():
         item.grid_forget()
     makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
     togglerope(1,1,1)
+    makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28,tooltiptxt="Enable flash attention for GGUF models.")
     makefileentry(tokens_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 30,tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")

     # Model Tab

@@ -2202,6 +2205,7 @@ def show_new_gui():
     args.highpriority = highpriority.get()==1
     args.nommap = disablemmap.get()==1
     args.smartcontext = smartcontext.get()==1
+    args.flashattention = flashattention.get()==1
     args.noshift = contextshift.get()==0
     args.remotetunnel = remotetunnel.get()==1
     args.foreground = keepforeground.get()==1

@@ -2286,6 +2290,7 @@ def show_new_gui():
     highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
     disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
     smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
+    flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
     contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
     remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
     keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
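The four GUI hunks above wire the checkbox into the same save/load cycle the other toggles use: the IntVar backs the "Use FlashAttention" checkbox, its state is exported onto args when launching, and it is restored from a saved settings dict on load. A small sketch of that round-trip without the GUI; the helper names here are made up for illustration:

import json

def export_settings(flashattention_var: int) -> dict:
    # mirrors: args.flashattention = flashattention.get()==1
    return {"flashattention": flashattention_var == 1}

def import_settings(saved: dict) -> int:
    # mirrors: flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
    return 1 if "flashattention" in saved and saved["flashattention"] else 0

saved = json.loads(json.dumps(export_settings(1)))   # stand-in for writing and re-reading a settings file
print(import_settings(saved))                         # 1, so the checkbox comes back ticked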
@@ -3322,5 +3327,6 @@ if __name__ == '__main__':
     parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
     parser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
     parser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
+    parser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')

     main(parser.parse_args(),start_server=True)
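With this change, flash attention can be enabled either by launching with the new --flashattention switch or by ticking the "Use FlashAttention" checkbox on the Tokens tab of the GUI; both routes end up setting inputs.flash_attention before the model is loaded, and the flag ultimately lands in llama_ctx_params.flash_attn when the llama context is created.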