Handle SWA conflicting with rewind; increase the default SWA padding.

This commit is contained in:
Concedo 2026-04-16 17:00:26 +08:00
parent 0251c6dbde
commit ae292c496e
5 changed files with 72 additions and 37 deletions

View file

@ -1943,39 +1943,6 @@ static bool kcpp_eval_media(llama_context * ctx_llama, const media_chunk & media
return true;
}
//counts the number of matching prefix tokens between two sequences, returns percentage matched 0.0 to 1.0
float ComputePrefixMatchPercent(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
{
    // Only the overlapping front region of the two sequences can be compared.
    size_t compare_len = std::min(current_context_tokens.size(), new_context_tokens.size());
    // Nothing to compare (one or both sequences empty) counts as no match,
    // and also guards against division by zero below.
    if (compare_len == 0) {
        return 0.0f;
    }
    // Walk forward from the start until the first mismatching token.
    size_t matched = 0;
    while (matched < compare_len && current_context_tokens[matched] == new_context_tokens[matched]) {
        ++matched;
    }
    return static_cast<float>(matched) / static_cast<float>(compare_len);
}
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
{
if (sequence1.size() > sequence2.size() || sequence1.size()==0 || sequence2.size()==0) {
return false;
}
for (size_t i = 0; i < sequence1.size(); ++i) {
if (sequence1[i] != sequence2[i]) {
return false;
}
}
return true;
}
//given an old GGUF context and a new context that has some middle portion removed,
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
//returns true if contextshift is doable, executes it if dryrun is false
@ -4325,14 +4292,30 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
else
{
bool triggersc = kcpp_data->use_smartcontext;
bool triggerff = kcpp_data->use_fastforward;
if(!blank_prompt) //special case for blank prompts, no fast forward or shifts
{
if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
if(triggerff && !kcpp_data->swa_full)
{
int goal_npast = ComputeSharedPrefixLength(current_context_tokens,embd_inp);
int last_npast = current_context_tokens.size();
int swa_limit = kcpp_active_swa_size-4;
if(last_npast-goal_npast > swa_limit)
{
triggerff = false;
if (debugmode==1 && !is_quiet)
{
printf("\n(Rewind of %d-%d=%d would exceed SWA window of %d, doing a full reprocess... to avoid this, disable SWA or increase SWA padding)\n",
last_npast,goal_npast,last_npast-goal_npast, swa_limit);
}
}
}
if(triggerff && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
{
DoContextShifting(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx, false);
triggersc = false;
}
if(kcpp_data->use_fastforward)
if(triggerff)
{
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false, 4);
}

View file

@ -8306,7 +8306,7 @@ def show_gui():
args.noshift = contextshift_var.get()==0
args.nofastforward = fastforward_var.get()==0
args.useswa = swa_var.get()==1
args.swapadding = int(swa_padding_var.get()) if swa_padding_var.get()!="" else 0
args.swapadding = int(swa_padding_var.get()) if swa_padding_var.get()!="" else swa_padding_default
args.smartcache = (0 if smartcache_var.get()!=1 else int(smartcacheslots_var.get()))
args.remotetunnel = remotetunnel_var.get()==1
args.foreground = keepforeground.get()==1
@ -8556,7 +8556,7 @@ def show_gui():
contextshift_var.set(0 if "noshift" in mydict and mydict["noshift"] else 1)
fastforward_var.set(0 if "nofastforward" in mydict and mydict["nofastforward"] else 1)
swa_var.set(1 if "useswa" in mydict and mydict["useswa"] else 0)
swa_padding_var.set(mydict["swapadding"] if ("swapadding" in mydict and mydict["swapadding"]) else 0)
swa_padding_var.set(mydict["swapadding"] if ("swapadding" in mydict) else swa_padding_default)
smartcache_var.set(1 if "smartcache" in mydict and mydict["smartcache"] else 0)
smartcacheslots_var.set(mydict["smartcache"] if ("smartcache" in mydict and mydict["smartcache"] and int(mydict["smartcache"])>1) else savestate_limit_default)
remotetunnel_var.set(1 if "remotetunnel" in mydict and mydict["remotetunnel"] else 0)

View file

@ -1086,4 +1086,47 @@ bool kcpp_string_ends_with(const std::string& str, const std::string& suffix) {
// Returns a copy of s with all trailing whitespace (spaces, tabs, newlines,
// carriage returns, form feeds, vertical tabs) removed.
std::string kcpp_rstrip(const std::string& s) {
    size_t last_kept = s.find_last_not_of(" \t\n\r\f\v");
    if (last_kept == std::string::npos) {
        return ""; // string is empty or all whitespace
    }
    return s.substr(0, last_kept + 1);
}
//counts the number of matching prefix tokens between two sequences
int ComputeSharedPrefixLength(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b)
{
    // Only the overlapping front region can contribute to a shared prefix.
    const size_t limit = std::min(tokens_a.size(), tokens_b.size());
    size_t shared = 0;
    // Advance until the first mismatch or the end of the shorter sequence.
    while (shared < limit && tokens_a[shared] == tokens_b[shared]) {
        ++shared;
    }
    return static_cast<int>(shared);
}
//counts the number of matching prefix tokens between two sequences, returns percentage matched 0.0 to 1.0
float ComputePrefixMatchPercent(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b)
{
    // The comparison window is bounded by the shorter of the two sequences.
    const size_t overlap = std::min(tokens_a.size(), tokens_b.size());
    if (overlap == 0) {
        return 0.0f; // nothing to compare — treated as no match (also avoids div by zero)
    }
    // Count consecutive matching tokens from the front.
    size_t same = 0;
    while (same < overlap && tokens_a[same] == tokens_b[same]) {
        ++same;
    }
    return static_cast<float>(same) / static_cast<float>(overlap);
}
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
{
    // Reject empty inputs outright, and a sequence longer than the candidate
    // container can never be its prefix.
    const bool sizes_ok = !sequence1.empty() && !sequence2.empty() && sequence1.size() <= sequence2.size();
    if (!sizes_ok) {
        return false;
    }
    // Pairwise-compare the front of sequence2 against all of sequence1.
    auto it2 = sequence2.begin();
    for (auto it1 = sequence1.begin(); it1 != sequence1.end(); ++it1, ++it2) {
        if (*it1 != *it2) {
            return false;
        }
    }
    return true;
}

View file

@ -77,6 +77,9 @@ std::vector<ggml_backend_dev_t> kcpp_parse_device_list(const std::string & value
bool kcpp_string_ends_with(const std::string& str, const std::string& suffix);
std::string kcpp_rstrip(const std::string& s);
int ComputeSharedPrefixLength(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b);
float ComputePrefixMatchPercent(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b);
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2);
//duplicated and modified from llava_embd_batch
struct kcpp_embd_batch {

View file

@ -12,6 +12,7 @@
//
//kcpp: use a global flag to adjust swa padding
static int kcpp_extra_swa_padding = 0;
static int kcpp_active_swa_size = 0;
llama_kv_cache_iswa::llama_kv_cache_iswa(
const llama_model & model,
@ -55,6 +56,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
size_swa += 128;
size_swa += kcpp_extra_swa_padding;
size_swa = GGML_PAD(size_swa, n_pad);
if (size_swa > size_base) {
size_swa = size_base;
}
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) {
@ -64,6 +68,8 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
size_swa = size_base;
}
kcpp_active_swa_size = size_swa;
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache>(