revamped smart context for llama models

Concedo 2023-10-28 12:59:08 +08:00
parent c2f675133d
commit 15f525c580
5 changed files with 90 additions and 34 deletions

@@ -247,6 +247,17 @@ static std::string RemoveBell(const std::string & input) //removes the bell char
return word2;
}
static std::string print_tok_vec_str(std::vector<int> &embd)
{
    std::string tmp = "";
    for (auto id : embd)
    {
        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
    }
    ::utreplace(tmp, "\n", "\\n");
    return tmp;
}
llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
{
@@ -572,7 +583,7 @@ static void load_grammar(const std::string & gammarstr)
//given an old GGUF context and a new context that has some middle portion removed,
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens)
void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt)
{
    //scan from start old and new ctx, until first mismatch found, save as p0
    //check remaining old and new ctx for longest common subseq, which needs to be at least 256 tokens
@@ -580,20 +591,63 @@ void PurgeMissingTokens(std::vector<int> &current_context_tokens, std::vector<in
    //if passed, save beginning of LCQ from old ctx as p1
    //remove all tokens from old ctx between p0 and p1, updating both arrays and kv, then continue as normal
    // int trimstart = 0;
    const int ShortfallThreshold = 256; //dont trigger shifting if the distance between trimstart and currhead < this
    const int SlackAllowance = 32; //in case the end text is slightly modified, be forgiving
    // const int n_keep = 0;
    // const int n_left = n_past - n_keep - 1;
    // const int n_discard = n_left/2;
    int trimstart = 0;
    int new_tokens_len = new_context_tokens.size();
    bool purgeneeded = true;
    // printf("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
    //     n_past, n_left, nctx, n_keep, n_discard);
    for (int i = 0; i < current_context_tokens.size(); ++i)
    {
        if (current_context_tokens[i] == new_context_tokens[i])
        {
            trimstart += 1;
        }
        else
        {
            break;
        }
        if ((i + 2) >= new_tokens_len)
        {
            purgeneeded = false;
            break; //no surgery required
        }
    }
    // llama_kv_cache_seq_rm   (llama_ctx_v4, 0, n_keep + 1, n_keep + n_discard + 1);
    // llama_kv_cache_seq_shift(llama_ctx_v4, 0, n_keep + 1 + n_discard, n_past, -n_discard);
    // n_past -= n_discard;
    if (!purgeneeded || new_tokens_len < 6 || current_context_tokens.size() < 6 || new_tokens_len - trimstart < ShortfallThreshold)
    {
        return; //no purge is needed
    }
    // printf("after swap: n_past = %d\n", n_past);
    //at least this many tokens need to match, otherwise don't bother trimming
    const int LCQTokThreshold = std::max((new_tokens_len - trimstart) - (genamt + SlackAllowance), ShortfallThreshold - SlackAllowance);
    auto curr_ctx_without_memory = std::vector<int>(current_context_tokens.begin() + trimstart, current_context_tokens.end());
    auto new_ctx_without_memory = std::vector<int>(new_context_tokens.begin() + trimstart, new_context_tokens.end());
    auto shared = LongestCommonSubseq(curr_ctx_without_memory, new_ctx_without_memory);
    if (shared.size() > LCQTokThreshold && ArrStartWith(new_ctx_without_memory, shared)) // enough tokens in common
    {
        int found = ArrFindIndexOf(current_context_tokens, shared);
        if (found >= 0 && found > trimstart)
        {
            //extract the unwanted tokens out from context and KV
            int diff = found - trimstart;
            llama_kv_cache_seq_rm(llama_ctx_v4, 0, trimstart + 1, trimstart + diff + 1);
            llama_kv_cache_seq_shift(llama_ctx_v4, 0, trimstart + diff + 1, -1, -diff);
            for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
            {
                current_context_tokens[i - diff] = current_context_tokens[i];
            }
            printf("\n[Smart Context Pro: Erased %d tokens at position %d]", diff, trimstart + 1);
            current_context_tokens.resize(current_context_tokens.size() - diff - 1);
        }
    }
}
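For reference (not part of the patch), here is a minimal standalone sketch of the bookkeeping the reworked PurgeMissingTokens performs on the token vector: find the matched prefix (trimstart), locate where the shared remainder resumes in the old context, and drop the stale span in between. It uses std::search as a simplified stand-in for the LongestCommonSubseq / ArrStartWith / ArrFindIndexOf helpers, and it omits the ShortfallThreshold / SlackAllowance checks and the real llama_kv_cache_* calls.

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // old context kept in the KV cache: [prefix][stale middle][shared tail]
    std::vector<int> current_ctx = {1, 2, 3, 90, 91, 92, 7, 8, 9, 10};
    // new prompt: same prefix and tail, but the middle was trimmed away upstream
    std::vector<int> new_ctx     = {1, 2, 3, 7, 8, 9, 10};

    // p0: length of the matching prefix (trimstart in the patch)
    size_t trimstart = 0;
    while (trimstart < current_ctx.size() && trimstart < new_ctx.size() &&
           current_ctx[trimstart] == new_ctx[trimstart])
    {
        ++trimstart;
    }

    // p1: where the remainder of the new context resumes inside the old one
    // (simplified stand-in for LongestCommonSubseq + ArrStartWith + ArrFindIndexOf)
    auto hit = std::search(current_ctx.begin() + trimstart, current_ctx.end(),
                           new_ctx.begin() + trimstart, new_ctx.end());
    if (hit == current_ctx.end())
    {
        printf("no shared span found, no purge\n");
        return 0;
    }

    int found = (int)(hit - current_ctx.begin());
    int diff = found - (int)trimstart;

    // drop the stale middle and shift the tail down; the patch mirrors this same
    // bookkeeping on the KV cache via llama_kv_cache_seq_rm + llama_kv_cache_seq_shift
    current_ctx.erase(current_ctx.begin() + trimstart, current_ctx.begin() + found);

    printf("erased %d tokens at position %zu, %zu tokens kept\n",
           diff, trimstart, current_ctx.size());
    return 0;
}

Under the thresholds above the purge is deliberately conservative: for example, if 1000 tokens of the new prompt remain after the matched prefix and genamt is 128, the longest common subsequence must exceed max(1000 - (128 + 32), 256 - 32) = 840 tokens before any trimming happens.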
@@ -1398,15 +1452,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
n_past = 0;
PurgeMissingTokens(current_context_tokens, embd_inp);
if (file_format == FileFormat::RWKV_1 || file_format == FileFormat::RWKV_2)
{
    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
}
else
{
    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext, false);
    bool triggersc = useSmartContext;
    if (useSmartContext && (file_format == FileFormat::GGUF_LLAMA || file_format == FileFormat::GGUF_FALCON))
    {
        PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length);
        triggersc = false;
    }
    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
}
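In effect (reading the hunk above), the KV-level purge is only attempted for GGUF llama and falcon models when smart context is enabled; once it has run, triggersc is cleared so the older token-copying smart context path is skipped and ContextFastForward only fast-forwards over the prefix that still matches. Other formats, including RWKV, keep the previous behaviour.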
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
@@ -1545,23 +1603,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
    std::string outstr = "";
    printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
    std::string tmp = "";
    for (auto id : embd_inp)
    {
        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
    }
    ::utreplace(tmp, "\n", "\\n");
    outstr += tmp;
    outstr += print_tok_vec_str(embd_inp);
    outstr += "\n\n[Debug: n_past=" + std::to_string(n_past) + " Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
    tmp = "";
    for (auto id : current_context_tokens)
    {
        tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
    }
    ::utreplace(tmp, "\n", "\\n");
    outstr += tmp;
    outstr += print_tok_vec_str(current_context_tokens);
    printf("%s\n\n", RemoveBell(outstr).c_str());
}