swa options now available

This commit is contained in:
Concedo 2025-05-24 11:50:37 +08:00
parent 748dfcc2e4
commit ec04115ae9
4 changed files with 67 additions and 50 deletions

View file

@ -1927,10 +1927,16 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
kcpp_data->use_smartcontext = inputs.use_smartcontext;
kcpp_data->use_contextshift = inputs.use_contextshift;
kcpp_data->use_fastforward = inputs.use_fastforward;
kcpp_data->swa_full = !inputs.swa_support;//(inputs.use_fastforward || inputs.use_contextshift)?true:false;
if(!kcpp_data->swa_full)
{
printf("\n!!!!!!!!!!!!!!!!!!!\nExperimental FLAG - SWA SUPPORT IS ENABLED!\n!!!!!!!!!!!!!!!!!!!\n");
kcpp_data->swa_full = !inputs.swa_support;
if (!kcpp_data->swa_full) {
if (inputs.use_contextshift) {
kcpp_data->swa_full = true; //cannot use SWA
printf("\nSWA Mode IS DISABLED!\nSWA Mode Cannot be used with Context Shifting!\n");
} else if (inputs.use_fastforward) {
printf("\nSWA Mode is ENABLED!\nNote that using SWA Mode with Fast Forwarding can lead to degraded recall!\n");
} else {
printf("\nSWA Mode IS ENABLED!\n");
}
}
debugmode = inputs.debugmode;
draft_ctx = nullptr;

View file

@ -9932,7 +9932,9 @@ Current version indicated by LITEVER below.
function toggleclaudemodel()
{
if (document.getElementById("custom_claude_model").value.toLowerCase().includes("claude-3"))
if (document.getElementById("custom_claude_model").value.toLowerCase().includes("claude-3")
|| document.getElementById("custom_claude_model").value.toLowerCase().includes("claude-sonnet-4")
|| document.getElementById("custom_claude_model").value.toLowerCase().includes("claude-opus-4"))
{
document.getElementById("claudesystemprompt").classList.remove("hidden");
document.getElementById("claudejailbreakprompt").classList.remove("hidden");
@ -15815,7 +15817,9 @@ Current version indicated by LITEVER below.
}
else if (custom_claude_key != "")//handle for Claude
{
let claudev3mode = custom_claude_model.toLowerCase().includes("claude-3");
let claudev3mode = custom_claude_model.toLowerCase().includes("claude-3")
|| custom_claude_model.toLowerCase().includes("claude-sonnet-4")
|| custom_claude_model.toLowerCase().includes("claude-opus-4");
let actualep = (custom_claude_endpoint + (claudev3mode?claude_submit_endpoint_v3:claude_submit_endpoint));
let targetep = actualep;
if(custom_claude_endpoint.toLowerCase().includes("api.anthropic.com"))
@ -23892,6 +23896,8 @@ Current version indicated by LITEVER below.
<option value="claude-3-5-sonnet-latest" selected="selected">claude-3-5-sonnet-latest</option>
<option value="claude-3-5-haiku-20241022">claude-3-5-haiku-20241022</option>
<option value="claude-3-7-sonnet-20250219">claude-3-7-sonnet-20250219</option>
<option value="claude-sonnet-4-latest">claude-sonnet-4-latest</option>
<option value="claude-opus-4-latest">claude-opus-4-latest</option>
</select>
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="claudefetchlist" onclick="claude_fetch_models()">Fetch List</button>
<input type="checkbox" title="Add endpoint version" id="claudeaddversion" onchange="" checked>

View file

@ -1249,7 +1249,7 @@ def load_model(model_filename):
inputs.override_kv = args.overridekv.encode("UTF-8") if args.overridekv else "".encode("UTF-8")
inputs.override_tensors = args.overridetensors.encode("UTF-8") if args.overridetensors else "".encode("UTF-8")
inputs.check_slowness = (not args.highpriority and os.name == 'nt' and 'Intel' in platform.processor())
inputs.swa_support = args.experiment_swa
inputs.swa_support = args.useswa
inputs = set_backend_props(inputs)
ret = handle.load_model(inputs)
return ret
@ -2208,7 +2208,7 @@ ws ::= | " " | "\n" [ \t]{0,20}
user_end = assistant_message_start
if chosen_tool=="auto":
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool? (One word response: yes or no):")
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool above? (One word response: yes or no):")
# note: message string already contains the instruct start tag!
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
temp_poll = {
@ -4088,11 +4088,12 @@ def show_gui():
tensor_split_str_vars = ctk.StringVar(value="")
rowsplit_var = ctk.IntVar()
contextshift = ctk.IntVar(value=1)
fastforward = ctk.IntVar(value=1)
remotetunnel = ctk.IntVar(value=0)
smartcontext = ctk.IntVar()
flashattention = ctk.IntVar(value=0)
contextshift_var = ctk.IntVar(value=1)
fastforward_var = ctk.IntVar(value=1)
swa_var = ctk.IntVar(value=0)
remotetunnel_var = ctk.IntVar(value=0)
smartcontext_var = ctk.IntVar()
flashattention_var = ctk.IntVar(value=0)
context_var = ctk.IntVar()
customrope_var = ctk.IntVar()
customrope_scale = ctk.StringVar(value="1.0")
@ -4459,7 +4460,7 @@ def show_gui():
pass
def changed_gpulayers_estimate(*args):
predicted_gpu_layers = autoset_gpu_layers(int(contextsize_text[context_var.get()]),(sd_quant_var.get()==1),int(blasbatchsize_values[int(blas_size_var.get())]),(quantkv_var.get() if flashattention.get()==1 else 0))
predicted_gpu_layers = autoset_gpu_layers(int(contextsize_text[context_var.get()]),(sd_quant_var.get()==1),int(blasbatchsize_values[int(blas_size_var.get())]),(quantkv_var.get() if flashattention_var.get()==1 else 0))
max_gpu_layers = (f"/{modelfile_extracted_meta[1][0]+3}" if (modelfile_extracted_meta and modelfile_extracted_meta[1] and modelfile_extracted_meta[1][0]!=0) else "")
index = runopts_var.get()
gpu_be = (index == "Use Vulkan" or index == "Use Vulkan (Old CPU)" or index == "Use CLBlast" or index == "Use CLBlast (Old CPU)" or index == "Use CLBlast (Older CPU)" or index == "Use CuBLAS" or index == "Use hipBLAS (ROCm)")
@ -4507,21 +4508,25 @@ def show_gui():
gpu_choice_var.trace("w", changed_gpu_choice_var)
gpulayers_var.trace("w", changed_gpulayers_estimate)
def toggleswa(a,b,c):
if swa_var.get()==1:
contextshift_var.set(0)
def togglefastforward(a,b,c):
if fastforward.get()==0:
contextshift.set(0)
smartcontext.set(0)
togglectxshift(1,1,1)
if fastforward_var.get()==0:
contextshift_var.set(0)
smartcontext_var.set(0)
def togglectxshift(a,b,c):
if contextshift.get()==0:
if contextshift_var.get()==0:
smartcontextbox.grid()
else:
fastforward.set(1)
fastforward_var.set(1)
swa_var.set(0)
smartcontextbox.grid_remove()
qkvslider.grid()
qkvlabel.grid()
if flashattention.get()==0 and quantkv_var.get()>0:
if flashattention_var.get()==0 and quantkv_var.get()>0:
noqkvlabel.grid()
else:
noqkvlabel.grid_remove()
@ -4530,7 +4535,7 @@ def show_gui():
def toggleflashattn(a,b,c):
qkvslider.grid()
qkvlabel.grid()
if flashattention.get()==0 and quantkv_var.get()>0:
if flashattention_var.get()==0 and quantkv_var.get()>0:
noqkvlabel.grid()
else:
noqkvlabel.grid_remove()
@ -4636,15 +4641,15 @@ def show_gui():
quick_boxes = {
"Launch Browser": [launchbrowser, "Launches your default browser after model loading is complete"],
"Use MMAP": [usemmap, "Use mmap to load models if enabled, model will not be unloadable"],
"Use ContextShift": [contextshift, "Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info."],
"Remote Tunnel": [remotetunnel, "Creates a trycloudflare tunnel.\nAllows you to access koboldcpp from other devices over an internet URL."],
"Use ContextShift": [contextshift_var, "Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info."],
"Remote Tunnel": [remotetunnel_var, "Creates a trycloudflare tunnel.\nAllows you to access koboldcpp from other devices over an internet URL."],
"Quiet Mode": [quietmode, "Prevents all generation related terminal output from being displayed."]
}
for idx, (name, properties) in enumerate(quick_boxes.items()):
makecheckbox(quick_tab, name, properties[0], int(idx/2) + 20, idx % 2, tooltiptxt=properties[1])
makecheckbox(quick_tab, "Use FlashAttention", flashattention, 22, 1, tooltiptxt="Enable flash attention for GGUF models.")
makecheckbox(quick_tab, "Use FlashAttention", flashattention_var, 22, 1, tooltiptxt="Enable flash attention for GGUF models.")
# context size
makeslider(quick_tab, "Context Size:", contextsize_text, context_var, 0, len(contextsize_text)-1, 30, width=280, set=5,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")
@ -4713,9 +4718,10 @@ def show_gui():
# Tokens Tab
tokens_tab = tabcontent["Tokens"]
# tokens checkboxes
smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext, 1,tooltiptxt="Uses SmartContext. Now considered outdated and not recommended.\nCheck the wiki for more info.")
makecheckbox(tokens_tab, "Use ContextShift", contextshift, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
makecheckbox(tokens_tab, "Use FastForwarding", fastforward, 3,tooltiptxt="Use fast forwarding to recycle previous context (always reprocess if disabled).\nRecommended.", command=togglefastforward)
smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext_var, 1,tooltiptxt="Uses SmartContext. Now considered outdated and not recommended.\nCheck the wiki for more info.")
makecheckbox(tokens_tab, "Use ContextShift", contextshift_var, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
makecheckbox(tokens_tab, "Use FastForwarding", fastforward_var, 3,tooltiptxt="Use fast forwarding to recycle previous context (always reprocess if disabled).\nRecommended.", command=togglefastforward)
makecheckbox(tokens_tab, "Use Sliding Window Attention (SWA)", swa_var, 4,tooltiptxt="Allows Sliding Window Attention (SWA) KV Cache, which saves memory but cannot be used with context shifting.", command=toggleswa)
# context size
makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, len(contextsize_text)-1, 18, width=280, set=5,tooltip="What is the maximum context size to support. Model specific. You cannot exceed it.\nLarger contexts require more memory, and not all models support it.")
@ -4732,7 +4738,7 @@ def show_gui():
else:
item.grid_remove()
makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28, command=toggleflashattn, tooltiptxt="Enable flash attention for GGUF models.")
makecheckbox(tokens_tab, "Use FlashAttention", flashattention_var, 28, command=toggleflashattn, tooltiptxt="Enable flash attention for GGUF models.")
noqkvlabel = makelabel(tokens_tab,"(Note: QuantKV works best with flash attention)",28,0,"Only K cache can be quantized, and performance can suffer.\nIn some cases, it might even use more VRAM when doing a full offload.",padx=160)
noqkvlabel.configure(text_color="#ff5555")
qkvslider,qkvlabel,qkvtitle = makeslider(tokens_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 2, 30, set=0,tooltip="Enable quantization of KV cache.\nRequires FlashAttention for full effect, otherwise only K cache is quantized.")
@ -4781,7 +4787,7 @@ def show_gui():
makelabelentry(network_tab, "Host: ", host_var, 2, 150,tooltip="Select a specific host interface to bind to.\n(Defaults to all)")
makecheckbox(network_tab, "Multiuser Mode", multiuser_var, 3,tooltiptxt="Allows requests by multiple different clients to be queued and handled in sequence.")
makecheckbox(network_tab, "Remote Tunnel", remotetunnel, 3, 1,tooltiptxt="Creates a trycloudflare tunnel.\nAllows you to access koboldcpp from other devices over an internet URL.")
makecheckbox(network_tab, "Remote Tunnel", remotetunnel_var, 3, 1,tooltiptxt="Creates a trycloudflare tunnel.\nAllows you to access koboldcpp from other devices over an internet URL.")
makecheckbox(network_tab, "Quiet Mode", quietmode, 4,tooltiptxt="Prevents all generation related terminal output from being displayed.")
makecheckbox(network_tab, "NoCertify Mode (Insecure)", nocertifymode, 4, 1,tooltiptxt="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.")
makecheckbox(network_tab, "Shared Multiplayer", multiplayer_var, 5,tooltiptxt="Hosts a shared multiplayer session that others can join.")
@ -4893,13 +4899,12 @@ def show_gui():
# extra tab
extra_tab = tabcontent["Extra"]
makelabel(extra_tab, "Unpack KoboldCpp to a local directory to modify its files.", 1, 0)
makelabel(extra_tab, "You can also launch via koboldcpp.py for faster startup.", 2, 0)
ctk.CTkButton(extra_tab , text = "Unpack KoboldCpp To Folder", command = unpack_to_dir ).grid(row=3,column=0, stick="w", padx= 8, pady=2)
makelabel(extra_tab, "Export as launcher .kcppt template (Expert Only)", 4, 0,tooltiptxt="Creates a KoboldCpp launch template for others to use.\nEmbeds JSON files directly into exported file when saving.\nWhen loaded, forces the backend to be automatically determined.\nWarning! Not recommended for beginners!")
ctk.CTkButton(extra_tab , text = "Generate LaunchTemplate", command = kcpp_export_template ).grid(row=5,column=0, stick="w", padx= 8, pady=2)
makelabel(extra_tab, "Extract KoboldCpp Files", 3, 0,tooltiptxt="Unpack KoboldCpp to a local directory to modify its files. You can also launch via koboldcpp.py for faster startup.")
ctk.CTkButton(extra_tab , text = "Unpack KoboldCpp To Folder", command = unpack_to_dir ).grid(row=3,column=0, stick="w", padx= 170, pady=2)
makelabel(extra_tab, "Export as .kcppt template", 4, 0,tooltiptxt="Creates a KoboldCpp launch template for others to use.\nEmbeds JSON files directly into exported file when saving.\nWhen loaded, forces the backend to be automatically determined.\nWarning! Not recommended for beginners!")
ctk.CTkButton(extra_tab , text = "Generate LaunchTemplate", command = kcpp_export_template ).grid(row=4,column=0, stick="w", padx= 170, pady=2)
makelabel(extra_tab, "Analyze GGUF Metadata", 6, 0,tooltiptxt="Reads the metadata, weight types and tensor names in any GGUF file.")
ctk.CTkButton(extra_tab , text = "Analyze GGUF", command = analyze_gguf_model_wrapper ).grid(row=7,column=0, stick="w", padx= 8, pady=2)
ctk.CTkButton(extra_tab , text = "Analyze GGUF", command = analyze_gguf_model_wrapper ).grid(row=6,column=0, stick="w", padx= 170, pady=2)
if sys.platform == "linux":
def togglezenity(a,b,c):
global zenity_permitted
@ -4936,11 +4941,12 @@ def show_gui():
args.launch = launchbrowser.get()==1
args.highpriority = highpriority.get()==1
args.usemmap = usemmap.get()==1
args.smartcontext = smartcontext.get()==1
args.flashattention = flashattention.get()==1
args.noshift = contextshift.get()==0
args.nofastforward = fastforward.get()==0
args.remotetunnel = remotetunnel.get()==1
args.smartcontext = smartcontext_var.get()==1
args.flashattention = flashattention_var.get()==1
args.noshift = contextshift_var.get()==0
args.nofastforward = fastforward_var.get()==0
args.useswa = swa_var.get()==1
args.remotetunnel = remotetunnel_var.get()==1
args.foreground = keepforeground.get()==1
args.cli = terminalonly.get()==1
args.quiet = quietmode.get()==1
@ -5123,11 +5129,12 @@ def show_gui():
launchbrowser.set(1 if "launch" in dict and dict["launch"] else 0)
highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
usemmap.set(1 if "usemmap" in dict and dict["usemmap"] else 0)
smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
fastforward.set(0 if "nofastforward" in dict and dict["nofastforward"] else 1)
remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
smartcontext_var.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
flashattention_var.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
contextshift_var.set(0 if "noshift" in dict and dict["noshift"] else 1)
fastforward_var.set(0 if "nofastforward" in dict and dict["nofastforward"] else 1)
swa_var.set(1 if "useswa" in dict and dict["useswa"] else 0)
remotetunnel_var.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
keepforeground.set(1 if "foreground" in dict and dict["foreground"] else 0)
terminalonly.set(1 if "cli" in dict and dict["cli"] else 0)
quietmode.set(1 if "quiet" in dict and dict["quiet"] else 0)
@ -6876,6 +6883,7 @@ if __name__ == '__main__':
advparser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
advparser.add_argument("--noshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
advparser.add_argument("--nofastforward", help="If set, do not attempt to fast forward GGUF context (always reprocess). Will also enable noshift", action='store_true')
advparser.add_argument("--useswa", help="If set, allows Sliding Window Attention (SWA) KV Cache, which saves memory but cannot be used with context shifting.", action='store_true')
compatgroup3 = advparser.add_mutually_exclusive_group()
compatgroup3.add_argument("--usemmap", help="If set, uses mmap to load model.", action='store_true')
advparser.add_argument("--usemlock", help="Enables mlock, preventing the RAM used to load the model from being paged out. Not usually recommended.", action='store_true')
@ -6968,9 +6976,6 @@ if __name__ == '__main__':
admingroup.add_argument("--adminpassword", metavar=('[password]'), help="Require a password to access admin functions. You are strongly advised to use one for publically accessible instances!", default=None)
admingroup.add_argument("--admindir", metavar=('[directory]'), help="Specify a directory to look for .kcpps configs in, which can be used to swap models.", default="")
experimentgroup = parser.add_argument_group('Experimental Commands, can change or break any time!')
experimentgroup.add_argument("--experiment_swa", help="Enables SWA mode. There are no safety checks.", action='store_true')
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')

View file

@ -689,7 +689,7 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
}
if (n_attended < std::min<int>(n_swa, pmin)) {
LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
//LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
}
}