mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
handle SWA conflicting with rewind, increased default SWA padding.
This commit is contained in:
parent
0251c6dbde
commit
ae292c496e
5 changed files with 72 additions and 37 deletions
|
|
@ -1943,39 +1943,6 @@ static bool kcpp_eval_media(llama_context * ctx_llama, const media_chunk & media
|
|||
return true;
|
||||
}
|
||||
|
||||
//counts the number of matching prefix tokens between two sequences, returns percentage matched 0.0 to 1.0
//note: the percentage is relative to the SHORTER of the two sequences.
float ComputePrefixMatchPercent(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
{
    // Count how many leading tokens agree before the first divergence.
    int match_count = 0;
    size_t min_length = std::min(current_context_tokens.size(), new_context_tokens.size());
    for (size_t i = 0; i < min_length; ++i) {
        if (current_context_tokens[i] == new_context_tokens[i]) {
            match_count++;
        } else {
            break;
        }
    }
    // Handle case where both sequences are empty to avoid division by zero
    if (min_length == 0) {
        return 0.0f; // Both empty sequences are considered not matched
    }
    return static_cast<float>(match_count) / static_cast<float>(min_length);
}
|
||||
|
||||
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
|
||||
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
|
||||
{
|
||||
if (sequence1.size() > sequence2.size() || sequence1.size()==0 || sequence2.size()==0) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < sequence1.size(); ++i) {
|
||||
if (sequence1[i] != sequence2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//given an old GGUF context and a new context that has some middle portion removed,
|
||||
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
|
||||
//returns true if contextshift is doable, executes it if dryrun is false
|
||||
|
|
@ -4325,14 +4292,30 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
else
|
||||
{
|
||||
bool triggersc = kcpp_data->use_smartcontext;
|
||||
bool triggerff = kcpp_data->use_fastforward;
|
||||
if(!blank_prompt) //special case for blank prompts, no fast forward or shifts
|
||||
{
|
||||
if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
|
||||
if(triggerff && !kcpp_data->swa_full)
|
||||
{
|
||||
int goal_npast = ComputeSharedPrefixLength(current_context_tokens,embd_inp);
|
||||
int last_npast = current_context_tokens.size();
|
||||
int swa_limit = kcpp_active_swa_size-4;
|
||||
if(last_npast-goal_npast > swa_limit)
|
||||
{
|
||||
triggerff = false;
|
||||
if (debugmode==1 && !is_quiet)
|
||||
{
|
||||
printf("\n(Rewind of %d-%d=%d would exceed SWA window of %d, doing a full reprocess... to avoid this, disable SWA or increase SWA padding)\n",
|
||||
last_npast,goal_npast,last_npast-goal_npast, swa_limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
if(triggerff && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
|
||||
{
|
||||
DoContextShifting(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx, false);
|
||||
triggersc = false;
|
||||
}
|
||||
if(kcpp_data->use_fastforward)
|
||||
if(triggerff)
|
||||
{
|
||||
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false, 4);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8306,7 +8306,7 @@ def show_gui():
|
|||
args.noshift = contextshift_var.get()==0
|
||||
args.nofastforward = fastforward_var.get()==0
|
||||
args.useswa = swa_var.get()==1
|
||||
args.swapadding = int(swa_padding_var.get()) if swa_padding_var.get()!="" else 0
|
||||
args.swapadding = int(swa_padding_var.get()) if swa_padding_var.get()!="" else swa_padding_default
|
||||
args.smartcache = (0 if smartcache_var.get()!=1 else int(smartcacheslots_var.get()))
|
||||
args.remotetunnel = remotetunnel_var.get()==1
|
||||
args.foreground = keepforeground.get()==1
|
||||
|
|
@ -8556,7 +8556,7 @@ def show_gui():
|
|||
contextshift_var.set(0 if "noshift" in mydict and mydict["noshift"] else 1)
|
||||
fastforward_var.set(0 if "nofastforward" in mydict and mydict["nofastforward"] else 1)
|
||||
swa_var.set(1 if "useswa" in mydict and mydict["useswa"] else 0)
|
||||
swa_padding_var.set(mydict["swapadding"] if ("swapadding" in mydict and mydict["swapadding"]) else 0)
|
||||
swa_padding_var.set(mydict["swapadding"] if ("swapadding" in mydict) else swa_padding_default)
|
||||
smartcache_var.set(1 if "smartcache" in mydict and mydict["smartcache"] else 0)
|
||||
smartcacheslots_var.set(mydict["smartcache"] if ("smartcache" in mydict and mydict["smartcache"] and int(mydict["smartcache"])>1) else savestate_limit_default)
|
||||
remotetunnel_var.set(1 if "remotetunnel" in mydict and mydict["remotetunnel"] else 0)
|
||||
|
|
|
|||
|
|
@ -1086,4 +1086,47 @@ bool kcpp_string_ends_with(const std::string& str, const std::string& suffix) {
|
|||
// Strips trailing ASCII whitespace (space, tab, newline, CR, form feed, vertical tab)
// from the end of a string; leading whitespace is preserved untouched.
std::string kcpp_rstrip(const std::string& s) {
    // Find the last character that is NOT whitespace.
    const size_t last_kept = s.find_last_not_of(" \t\n\r\f\v");
    if (last_kept == std::string::npos) {
        return ""; // string was empty or entirely whitespace
    }
    return s.substr(0, last_kept + 1);
}
|
||||
|
||||
//counts the number of matching prefix tokens between two sequences
int ComputeSharedPrefixLength(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b)
{
    // Only the overlapping region can contribute to a shared prefix.
    const size_t limit = std::min(tokens_a.size(), tokens_b.size());

    // Advance while corresponding tokens agree; stop at the first divergence.
    size_t shared = 0;
    while (shared < limit && tokens_a[shared] == tokens_b[shared]) {
        ++shared;
    }

    return static_cast<int>(shared);
}
|
||||
|
||||
//counts the number of matching prefix tokens between two sequences, returns percentage matched 0.0 to 1.0
|
||||
float ComputePrefixMatchPercent(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b)
|
||||
{
|
||||
size_t min_length = std::min(tokens_a.size(), tokens_b.size());
|
||||
|
||||
if (min_length == 0) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
int match_count = ComputeSharedPrefixLength(tokens_a, tokens_b);
|
||||
return static_cast<float>(match_count) / static_cast<float>(min_length);
|
||||
}
|
||||
|
||||
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
|
||||
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
|
||||
{
|
||||
if (sequence1.size() > sequence2.size() || sequence1.size()==0 || sequence2.size()==0) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < sequence1.size(); ++i) {
|
||||
if (sequence1[i] != sequence2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
@ -77,6 +77,9 @@ std::vector<ggml_backend_dev_t> kcpp_parse_device_list(const std::string & value
|
|||
|
||||
bool kcpp_string_ends_with(const std::string& str, const std::string& suffix);
|
||||
std::string kcpp_rstrip(const std::string& s);
|
||||
int ComputeSharedPrefixLength(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b);
|
||||
float ComputePrefixMatchPercent(const std::vector<int> &tokens_a,const std::vector<int> &tokens_b);
|
||||
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2);
|
||||
|
||||
//duplicated and modified from llava_embd_batch
|
||||
struct kcpp_embd_batch {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
//
|
||||
//kcpp: use a global flag to adjust swa padding
|
||||
static int kcpp_extra_swa_padding = 0;
|
||||
static int kcpp_active_swa_size = 0;
|
||||
|
||||
llama_kv_cache_iswa::llama_kv_cache_iswa(
|
||||
const llama_model & model,
|
||||
|
|
@ -55,6 +56,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
|||
size_swa += 128;
|
||||
size_swa += kcpp_extra_swa_padding;
|
||||
size_swa = GGML_PAD(size_swa, n_pad);
|
||||
if (size_swa > size_base) {
|
||||
size_swa = size_base;
|
||||
}
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
|
|
@ -64,6 +68,8 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
|
|||
size_swa = size_base;
|
||||
}
|
||||
|
||||
kcpp_active_swa_size = size_swa;
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache>(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue