mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-10 12:11:08 +00:00
make flash attention default in cli. added --noflashattention
This commit is contained in:
parent
d877ec7aec
commit
ef7fe1b5d4
1 changed files with 16 additions and 13 deletions
29
koboldcpp.py
29
koboldcpp.py
|
|
@ -1576,14 +1576,14 @@ def load_model(model_filename):
|
|||
inputs.use_smartcontext = args.smartcontext
|
||||
inputs.use_contextshift = (0 if args.noshift else 1)
|
||||
inputs.use_fastforward = (0 if args.nofastforward else 1)
|
||||
inputs.flash_attention = args.flashattention
|
||||
inputs.flash_attention = (False if args.noflashattention else True)
|
||||
if args.quantkv>0:
|
||||
if args.flashattention:
|
||||
inputs.quant_k = inputs.quant_v = args.quantkv
|
||||
else:
|
||||
if args.noflashattention:
|
||||
inputs.quant_k = args.quantkv
|
||||
inputs.quant_v = 0
|
||||
print("\nWarning: quantkv was used without flashattention! This is NOT RECOMMENDED!\nOnly K cache can be quantized, and performance can suffer.\nIn some cases, it might even use more VRAM when doing a full offload.\nYou are strongly encouraged to use flashattention if you want to use quantkv.")
|
||||
print("\nWarning: Quantized KV was used without flash attention! This is NOT RECOMMENDED!\nOnly K cache can be quantized, and performance can suffer.\nIn some cases, it might even use more VRAM when doing a full offload.\nYou are strongly encouraged to use flash attention if you want to use quantkv.")
|
||||
else:
|
||||
inputs.quant_k = inputs.quant_v = args.quantkv
|
||||
else:
|
||||
inputs.quant_k = inputs.quant_v = 0
|
||||
inputs.batchsize = args.batchsize
|
||||
|
|
@ -2179,7 +2179,7 @@ def tts_load_model(ttc_model_filename,cts_model_filename):
|
|||
inputs.ttc_model_filename = ttc_model_filename.encode("UTF-8") if ttc_model_filename else "".encode("UTF-8")
|
||||
inputs.cts_model_filename = cts_model_filename.encode("UTF-8") if cts_model_filename else "".encode("UTF-8")
|
||||
inputs.gpulayers = (999 if args.ttsgpu else 0)
|
||||
inputs.flash_attention = args.flashattention
|
||||
inputs.flash_attention = (False if args.noflashattention else True)
|
||||
thds = args.threads
|
||||
if args.ttsthreads and args.ttsthreads > 0:
|
||||
ttst = int(args.ttsthreads)
|
||||
|
|
@ -2259,7 +2259,7 @@ def embeddings_load_model(model_filename):
|
|||
inputs = embeddings_load_model_inputs()
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
inputs.gpulayers = (999 if args.embeddingsgpu else 0)
|
||||
inputs.flash_attention = args.flashattention
|
||||
inputs.flash_attention = (False if args.noflashattention else True)
|
||||
inputs.threads = args.threads
|
||||
inputs.use_mmap = args.usemmap
|
||||
inputs.embeddingsmaxctx = (args.embeddingsmaxctx if args.embeddingsmaxctx else args.contextsize) # for us to clamp to contextsize if embeddingsmaxctx unspecified
|
||||
|
|
@ -6173,7 +6173,7 @@ def show_gui():
|
|||
makecheckbox(context_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
|
||||
noqkvlabel = makelabel(context_tab,"(Note: QuantKV works best with flash attention)",30,0,"Only K cache can be quantized, and performance can suffer.\nIn some cases, it might even use more VRAM when doing a full offload.",padx=160)
|
||||
noqkvlabel.configure(text_color="#ff5555")
|
||||
qkvslider,qkvlabel,qkvtitle = makeslider(context_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 2, 30, set=0,tooltip="Enable quantization of KV cache.\nRequires FlashAttention for full effect, otherwise only K cache is quantized.")
|
||||
qkvslider,qkvlabel,qkvtitle = makeslider(context_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 2, 30, set=0,tooltip="Enable quantization of KV cache.\nRequires Flash Attention for full effect, otherwise only K cache is quantized.")
|
||||
quantkv_var.trace_add("write", toggleflashattn)
|
||||
makecheckbox(context_tab, "No BOS Token", nobostoken_var, 43, tooltiptxt="Prevents BOS token from being added at the start of any prompt. Usually NOT recommended for most models.")
|
||||
makecheckbox(context_tab, "Enable Guidance", enableguidance_var, 43,padx=(140), tooltiptxt="Enables the use of Classifier-Free-Guidance, which allows the use of negative prompts. Has performance and memory impact.")
|
||||
|
|
@ -6434,7 +6434,7 @@ def show_gui():
|
|||
args.highpriority = highpriority.get()==1
|
||||
args.usemmap = usemmap.get()==1
|
||||
args.smartcontext = smartcontext_var.get()==1
|
||||
args.flashattention = flashattention_var.get()==1
|
||||
args.noflashattention = flashattention_var.get()==0
|
||||
args.noshift = contextshift_var.get()==0
|
||||
args.nofastforward = fastforward_var.get()==0
|
||||
args.useswa = swa_var.get()==1
|
||||
|
|
@ -6657,7 +6657,7 @@ def show_gui():
|
|||
highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
|
||||
usemmap.set(1 if "usemmap" in dict and dict["usemmap"] else 0)
|
||||
smartcontext_var.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
|
||||
flashattention_var.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
|
||||
flashattention_var.set(0 if "noflashattention" in dict and dict["noflashattention"] else 1)
|
||||
contextshift_var.set(0 if "noshift" in dict and dict["noshift"] else 1)
|
||||
fastforward_var.set(0 if "nofastforward" in dict and dict["nofastforward"] else 1)
|
||||
swa_var.set(1 if "useswa" in dict and dict["useswa"] else 0)
|
||||
|
|
@ -7244,6 +7244,8 @@ def convert_invalid_args(args):
|
|||
dict["jinja"] = True
|
||||
if "sdgendefaults" in dict and "gendefaults" not in dict:
|
||||
dict["gendefaults"] = dict["sdgendefaults"]
|
||||
if "flashattention" in dict and "noflashattention" not in dict:
|
||||
dict["noflashattention"] = not dict["flashattention"]
|
||||
return args
|
||||
|
||||
def setuptunnel(global_memory, has_sd):
|
||||
|
|
@ -8190,7 +8192,7 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
if args.gpulayers==-1:
|
||||
if MaxMemory[0] > 0 and (not args.usecpu) and ((args.usecuda is not None) or (args.usevulkan is not None) or sys.platform=="darwin"):
|
||||
extract_modelfile_params(args.model_param,args.sdmodel,args.whispermodel,args.mmproj,args.draftmodel,args.ttsmodel if args.ttsgpu else "",args.embeddingsmodel if args.embeddingsgpu else "")
|
||||
layeramt = autoset_gpu_layers(args.contextsize,args.sdquant,args.batchsize,(args.quantkv if args.flashattention else 0))
|
||||
layeramt = autoset_gpu_layers(args.contextsize,args.sdquant,args.batchsize,(0 if args.noflashattention else args.quantkv))
|
||||
print(f"Auto Recommended GPU Layers: {layeramt}")
|
||||
args.gpulayers = layeramt
|
||||
else:
|
||||
|
|
@ -8643,7 +8645,7 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
s_pp = float(benchmaxctx-benchlen)/t_pp
|
||||
s_gen = float(benchlen)/t_gen
|
||||
datetimestamp = datetime.now(timezone.utc)
|
||||
benchflagstr = f"NoAVX2={args.noavx2} Threads={args.threads} HighPriority={args.highpriority} Cuda_Args={args.usecuda} Tensor_Split={args.tensor_split} BlasThreads={args.blasthreads} BatchSize={args.batchsize} FlashAttention={args.flashattention} KvCache={args.quantkv}"
|
||||
benchflagstr = f"NoAVX2={args.noavx2} Threads={args.threads} HighPriority={args.highpriority} Cuda_Args={args.usecuda} Tensor_Split={args.tensor_split} BlasThreads={args.blasthreads} BatchSize={args.batchsize} FlashAttention={not args.noflashattention} KvCache={args.quantkv}"
|
||||
print(f"\nBenchmark Completed - v{KcppVersion} Results:\n======")
|
||||
print(f"Flags: {benchflagstr}")
|
||||
print(f"Timestamp: {datetimestamp}")
|
||||
|
|
@ -8768,7 +8770,7 @@ if __name__ == '__main__':
|
|||
advparser.add_argument("--chatcompletionsadapter", metavar=('[filename]'), help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="AutoGuess")
|
||||
advparser.add_argument("--jinja", help="Enables using jinja chat template formatting for chat completions endpoint. Other endpoints are unaffected. Tool calls are done without jinja.", action='store_true')
|
||||
advparser.add_argument("--jinja_tools","--jinja-tools","--jinjatools", help="Enables using jinja chat template formatting for chat completions endpoint. Other endpoints are unaffected. Tool calls are done with jinja.", action='store_true')
|
||||
advparser.add_argument("--flashattention","--flash-attn","-fa", help="Enables flash attention.", action='store_true')
|
||||
advparser.add_argument("--noflashattention","--no-flash-attn","-nofa", help="Disables flash attention.", action='store_true')
|
||||
advparser.add_argument("--lowvram","-nkvo","--no-kv-offload", help="If supported by the backend, do not offload KV to GPU (lowvram mode). Not recommended, will be slow.", action='store_true')
|
||||
advparser.add_argument("--quantkv", help="Sets the KV cache data type quantization, 0=f16, 1=q8, 2=q4. Requires Flash Attention for full effect, otherwise only K cache is quantized.",metavar=('[quantization level 0/1/2]'), type=int, choices=[0,1,2], default=0)
|
||||
advparser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently. Outdated. Not recommended.", action='store_true')
|
||||
|
|
@ -8853,6 +8855,7 @@ if __name__ == '__main__':
|
|||
deprecatedgroup.add_argument("--sdnotile", help=argparse.SUPPRESS, action='store_true') # legacy option, see sdtiledvae
|
||||
deprecatedgroup.add_argument("--forceversion", help=argparse.SUPPRESS, action='store_true') #no longer used
|
||||
deprecatedgroup.add_argument("--sdgendefaults", help=argparse.SUPPRESS, action='store_true') # legacy option, see gendefaults
|
||||
deprecatedgroup.add_argument("--flashattention","--flash-attn","-fa", help=argparse.SUPPRESS, action='store_true') #flash attention now default on
|
||||
|
||||
debuggroup = parser.add_argument_group('Debug Commands')
|
||||
debuggroup.add_argument("--testmemory", help=argparse.SUPPRESS, action='store_true')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue