Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
revamped smart context for llama models
commit 15f525c580 (parent c2f675133d)
5 changed files with 90 additions and 34 deletions
@@ -247,6 +247,17 @@ static std::string RemoveBell(const std::string & input) //removes the bell char
     return word2;
 }
 
+static std::string print_tok_vec_str(std::vector<int> &embd)
+{
+    std::string tmp = "";
+    for (auto id : embd)
+    {
+        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
+    }
+    ::utreplace(tmp, "\n", "\\n");
+    return tmp;
+}
+
 llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
 {
 
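The new print_tok_vec_str helper above just joins each token's text and ID into one printable line; the hunk at line 1603 further below switches the debug dump over to it. A rough standalone sketch of the output shape follows; it is not the commit's code, and stub_tokenize_id is a hypothetical stand-in for FileFormatTokenizeID and the global file_format:

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for FileFormatTokenizeID: maps a token id to its text.
static std::string stub_tokenize_id(int id)
{
    return id == 13 ? std::string("\n") : "tok" + std::to_string(id);
}

// Same shape as the new helper: build "'text (id)', " pieces, then escape
// raw newlines so the whole dump stays on one line (the real code uses ::utreplace).
static std::string print_tok_vec_str_sketch(const std::vector<int> &embd)
{
    std::string tmp;
    for (int id : embd)
    {
        tmp += "'" + stub_tokenize_id(id) + " (" + std::to_string(id) + ")', ";
    }
    std::string out;
    for (char c : tmp)
    {
        if (c == '\n') { out += "\\n"; } else { out += c; }
    }
    return out;
}

int main()
{
    std::vector<int> ids = {1, 13, 42};
    printf("%s\n", print_tok_vec_str_sketch(ids).c_str());
    // prints: 'tok1 (1)', '\n (13)', 'tok42 (42)',
    return 0;
}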
@@ -572,7 +583,7 @@ static void load_grammar(const std::string & gammarstr)
 
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
+void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt)
 {
     //scan from start old and new ctx, until first mismatch found, save as p0
     //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@@ -580,20 +591,63 @@ void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<in
     //if passed, save beginning of LCQ from old ctx as p1
     //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
 
-    // int trimstart = 0;
+    const int ShortfallThreshold = 256; //dont trigger shifting if the distance between trimstart and currhead < this
+    const int SlackAllowance = 32; //in case the end text is slightly modified, be forgiving
 
-    // const int n_keep = 0;
-    // const int n_left = n_past - n_keep - 1;
-    // const int n_discard = n_left/2;
+    int trimstart = 0;
+    int new_tokens_len = new_context_tokens.size();
+    bool purgeneeded = true;
 
-    // printf("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-    //     n_past, n_left, nctx, n_keep, n_discard);
+    for (int i = 0; i < current_context_tokens.size(); ++i)
+    {
+        if (current_context_tokens[i] == new_context_tokens[i])
+        {
+            trimstart += 1;
+        }
+        else
+        {
+            break;
+        }
+        if ((i + 2) >= new_tokens_len)
+        {
+            purgeneeded = false;
+            break; //no surgery required
+        }
+    }
 
-    // llama_kv_cache_seq_rm (llama_ctx_v4, 0, n_keep + 1 , n_keep + n_discard + 1);
-    // llama_kv_cache_seq_shift(llama_ctx_v4, 0, n_keep + 1 + n_discard, n_past, -n_discard);
-    // n_past -= n_discard;
+    if(!purgeneeded || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < ShortfallThreshold)
+    {
+        return; //no purge is needed
+    }
 
-    // printf("after swap: n_past = %d\n", n_past);
+    //at least this many tokens need to match, otherwise don't bother trimming
+    const int LCQTokThreshold = std::max((new_tokens_len - trimstart) - (genamt+SlackAllowance), ShortfallThreshold-SlackAllowance);
+
+    auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
+    auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
+
+    auto shared = LongestCommonSubseq(curr_ctx_without_memory, new_ctx_without_memory);
+
+    if (shared.size() > LCQTokThreshold && ArrStartWith(new_ctx_without_memory, shared)) // enough tokens in common
+    {
+        int found = ArrFindIndexOf(current_context_tokens,shared);
+        if(found>=0 && found > trimstart)
+        {
+            //extract the unwanted tokens out from context and KV
+            int diff = found - trimstart;
+            llama_kv_cache_seq_rm(llama_ctx_v4, 0, trimstart + 1, trimstart + diff + 1);
+            llama_kv_cache_seq_shift(llama_ctx_v4, 0, trimstart + diff + 1, -1, -diff);
+
+            for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
+            {
+                current_context_tokens[i - diff] = current_context_tokens[i];
+            }
+
+            printf("\n[Smart Context Pro: Erased %d tokens at position %d]", diff, trimstart+1);
+
+            current_context_tokens.resize(current_context_tokens.size() - diff - 1);
+        }
+    }
 
 }
 
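The comments at the top of the revised function are the best summary of the algorithm: keep the longest matching prefix (trimstart / p0), look for a long enough common run between the two remainders, and splice out whatever the old context holds in between, both from the token vector and from the KV cache. The standalone sketch below illustrates that idea on plain token IDs; it is not the commit's code, longest_common_run is a local stand-in for KoboldCpp's LongestCommonSubseq helper, and the llama_kv_cache_seq_rm / llama_kv_cache_seq_shift calls are omitted:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Longest common *contiguous* run of two token sequences, which is what the
// ArrStartWith / ArrFindIndexOf checks in the real function rely on.
static std::vector<int> longest_common_run(const std::vector<int> &a, const std::vector<int> &b)
{
    size_t best_len = 0, best_end_a = 0;
    std::vector<size_t> prev(b.size() + 1, 0), cur(b.size() + 1, 0);
    for (size_t i = 1; i <= a.size(); ++i)
    {
        for (size_t j = 1; j <= b.size(); ++j)
        {
            cur[j] = (a[i - 1] == b[j - 1]) ? prev[j - 1] + 1 : 0;
            if (cur[j] > best_len) { best_len = cur[j]; best_end_a = i; }
        }
        std::swap(prev, cur);
    }
    return std::vector<int>(a.begin() + (best_end_a - best_len), a.begin() + best_end_a);
}

int main()
{
    // old context: [prefix][stale middle][shared tail]
    std::vector<int> old_ctx = {1, 2, 3, 90, 91, 92, 7, 8, 9, 10};
    // new context: same prefix and tail, middle trimmed by the frontend
    std::vector<int> new_ctx = {1, 2, 3, 7, 8, 9, 10, 11};

    // p0: length of the matching prefix (trimstart in the real function)
    size_t trimstart = 0;
    while (trimstart < old_ctx.size() && trimstart < new_ctx.size()
           && old_ctx[trimstart] == new_ctx[trimstart]) { ++trimstart; }

    // longest shared run of the remainders; p1 is where it begins in old_ctx
    std::vector<int> old_rest(old_ctx.begin() + trimstart, old_ctx.end());
    std::vector<int> new_rest(new_ctx.begin() + trimstart, new_ctx.end());
    std::vector<int> shared = longest_common_run(old_rest, new_rest);

    auto it = std::search(old_ctx.begin(), old_ctx.end(), shared.begin(), shared.end());
    size_t found = it - old_ctx.begin();
    if (!shared.empty() && it != old_ctx.end() && found > trimstart)
    {
        // erase the stale middle [trimstart, found); the real function also
        // removes and shifts the corresponding llama KV-cache cells
        old_ctx.erase(old_ctx.begin() + trimstart, old_ctx.begin() + found);
    }
    for (int id : old_ctx) { printf("%d ", id); } // prints: 1 2 3 7 8 9 10
    printf("\n");
    return 0;
}

On the sample vectors this prints 1 2 3 7 8 9 10: the stale middle (90 91 92) is gone while the prefix and tail survive, which is what lets the existing KV cache be reused instead of reprocessing the whole prompt.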
@@ -1398,15 +1452,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    PurgeMissingTokens(current_context_tokens, embd_inp);
-
     if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
     }
     else
     {
-        ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext, false);
+        bool triggersc = useSmartContext;
+        if(useSmartContext && file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        {
+            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length);
+            triggersc = false;
+        }
+        ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
     }
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
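The call-site hunk above changes the order of context preparation: for GGUF llama and falcon models with smart context enabled, the destructive KV purge now runs first and the old text-based smart context is then skipped (triggersc = false), so ContextFastForward only fast-forwards over the already aligned prefix. A condensed sketch of that ordering follows; both helpers are reduced to hypothetical stubs and the format check is written here with explicit parentheses:

#include <cstdio>
#include <vector>

enum class Fmt { RWKV, GGUF_LLAMA, GGUF_FALCON, OTHER };

// Hypothetical stubs standing in for the real PurgeMissingTokens / ContextFastForward.
static void purge_missing_tokens(std::vector<int> &ctx, const std::vector<int> &prompt)
{
    (void)ctx; (void)prompt;
    printf("1. purge stale middle from context + KV cache\n");
}

static void context_fast_forward(std::vector<int> &ctx, const std::vector<int> &prompt, bool smartcontext)
{
    (void)ctx; (void)prompt;
    printf("2. fast-forward over matching prefix (smartcontext=%d)\n", smartcontext ? 1 : 0);
}

static void prepare_context(Fmt fmt, bool useSmartContext, std::vector<int> &ctx, const std::vector<int> &prompt)
{
    if (fmt == Fmt::RWKV)
    {
        context_fast_forward(ctx, prompt, false); // RWKV path is unchanged
        return;
    }
    bool triggersc = useSmartContext;
    if (useSmartContext && (fmt == Fmt::GGUF_LLAMA || fmt == Fmt::GGUF_FALCON))
    {
        purge_missing_tokens(ctx, prompt); // destructive KV trim first
        triggersc = false;                 // then skip the old textual smart context
    }
    context_fast_forward(ctx, prompt, triggersc); // reuse whatever prefix now lines up
}

int main()
{
    std::vector<int> ctx = {1, 2, 3};
    std::vector<int> prompt = {1, 2, 4};
    prepare_context(Fmt::GGUF_LLAMA, true, ctx, prompt);
    return 0;
}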
@@ -1545,23 +1603,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 {
     std::string outstr = "";
     printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
 
-    std::string tmp = "";
-    for (auto id : embd_inp)
-    {
-        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
-    }
-    ::utreplace(tmp, "\n", "\\n");
-    outstr += tmp;
-
+    outstr += print_tok_vec_str(embd_inp);
     outstr += "\n\n[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
-    tmp = "";
-    for (auto id : current_context_tokens)
-    {
-        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
-    }
-    ::utreplace(tmp, "\n", "\\n");
-    outstr += tmp;
+    outstr += print_tok_vec_str(current_context_tokens);
     printf("%s\n\n", RemoveBell(outstr).c_str());
 }