allow quantkv with contextshift

This commit is contained in:
Concedo 2025-03-16 21:48:42 +08:00
parent e466ce65e2
commit 6888f5495d
2 changed files with 6 additions and 7 deletions

View file

@ -1092,7 +1092,6 @@ def load_model(model_filename):
if args.quantkv>0:
inputs.quant_k = inputs.quant_v = args.quantkv
inputs.flash_attention = True
inputs.use_contextshift = 0
else:
inputs.quant_k = inputs.quant_v = 0
inputs.blasbatchsize = args.blasbatchsize
@ -3682,7 +3681,7 @@ def show_gui():
fastforward.set(1)
smartcontextbox.grid_remove()
if contextshift.get()==0 and flashattention.get()==1:
if flashattention.get()==1:
qkvslider.grid()
qkvlabel.grid()
noqkvlabel.grid_remove()
@ -3692,7 +3691,7 @@ def show_gui():
noqkvlabel.grid()
def toggleflashattn(a,b,c):
if contextshift.get()==0 and flashattention.get()==1:
if flashattention.get()==1:
qkvslider.grid()
qkvlabel.grid()
noqkvlabel.grid_remove()
@ -3906,7 +3905,7 @@ def show_gui():
item.grid_remove()
makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28, command=toggleflashattn, tooltiptxt="Enable flash attention for GGUF models.")
noqkvlabel = makelabel(tokens_tab,"Requirments Not Met",31,0,"Requires FlashAttention ENABLED and ContextShift DISABLED.")
noqkvlabel = makelabel(tokens_tab,"Requirments Not Met",31,0,"Requires FlashAttention ENABLED.")
noqkvlabel.configure(text_color="#ff5555")
qkvslider,qkvlabel,qkvtitle = makeslider(tokens_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 2, 30, set=0,tooltip="Enable quantization of KV cache.\nRequires FlashAttention and disables ContextShift.")
makecheckbox(tokens_tab, "No BOS Token", nobostoken_var, 33, tooltiptxt="Prevents BOS token from being added at the start of any prompt. Usually NOT recommended for most models.")
@ -4109,7 +4108,7 @@ def show_gui():
args.quiet = quietmode.get()==1
args.nocertify = nocertifymode.get()==1
args.nomodel = nomodel.get()==1
if contextshift.get()==0 and flashattention.get()==1:
if flashattention.get()==1:
args.quantkv = quantkv_var.get()
else:
args.quantkv = 0