fast forwarding for rwkv for unmodified contexts

This commit is contained in:
Concedo 2023-04-19 15:09:35 +08:00
parent f39def81d4
commit 45ec09d31b
8 changed files with 70 additions and 46 deletions

View file

@ -236,7 +236,8 @@ void print_tok_vec(std::vector<float> &embd)
}
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, bool useSmartContext)
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
bool useSmartContext, const bool requireFullSubset)
{
const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext
const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext
@ -244,13 +245,11 @@ void print_tok_vec(std::vector<float> &embd)
const float SCTruncationRatio = 0.5; //ratio for how many tokens to fast forward
const int SCTokThreshold = 32 + (nctx*0.05); //how many tokens of similarity triggers smartcontext
// printf("\nORIGINAL CTX:\n");
// print_tok_vec(current_context_tokens);
// printf("\nORIGINAL EMBD:\n");
// print_tok_vec(embd_inp);
//fast forward the past based on identical tokens, stop once a divergence is noted
int embd_inp_len = embd_inp.size();
bool fastforwardok = true;
for (int i = 0; i < current_context_tokens.size(); ++i)
{
if (current_context_tokens[i] == embd_inp[i])
@ -260,37 +259,48 @@ void print_tok_vec(std::vector<float> &embd)
}
else
{
if(requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
{
last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
n_past = 0;
fastforwardok = false;
}
break;
}
if ((i + 2) >= embd_inp_len)
if (requireFullSubset) //RWKV can only do this if embd_inp contains everything in current context
{
break;
if (i >= embd_inp_len)
{
last_n_tokens.erase(last_n_tokens.end() - n_past, last_n_tokens.end());
n_past = 0;
fastforwardok = false;
break;
}
}
else
{
if ((i + 2) >= embd_inp_len)
{
break;
}
}
}
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
embd_inp_len = embd_inp.size();
if(fastforwardok)
{
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
embd_inp_len = embd_inp.size();
}
//smart context mode, detect if we have a shifted context at max length
//requirement: previous context was at least nctx/2 longer than current,
//mode is on, and current context already maxed.
// printf("\nconds: %d %d %d\n",current_context_tokens.size() >= nctx*0.8
// ,embd_inp_len >= nctx*0.6 ,current_context_tokens.size() - n_past > nctx*0.5);
// printf("csiz:%d par:%d eilen:%d np:%d",current_context_tokens.size(), (int)(nctx*0.8),embd_inp_len,n_past);
if (useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold)
if (fastforwardok && useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold)
{
// printf("curfullcontext:\n");
// print_tok_vec(current_context_tokens);
//see if smartcontext is still usable
// printf("smartctx:\n");
// print_tok_vec(smartcontext);
// printf("embinp:\n");
// print_tok_vec(embd_inp);
//see if smartcontext is still usable
auto shared = LongestCommonSubseq(smartcontext, embd_inp);
if (shared.size() > SCTokThreshold && ArrStartWith(smartcontext, shared)) //at least 32 tokens in common
{
@ -300,8 +310,6 @@ void print_tok_vec(std::vector<float> &embd)
auto trimmed = std::vector<int>(embd_inp.begin() + found, embd_inp.end());
embd_inp = trimmed;
embd_inp_len = embd_inp.size();
// printf("trimmed:\n");
// print_tok_vec(embd_inp,&vocab.id_to_token);
printf("\n[Reusing Smart Context: %d allowance remaining]", found);
int old_n_past = n_past;
@ -313,7 +321,6 @@ void print_tok_vec(std::vector<float> &embd)
for (int i = n_past; i < current_context_tokens.size(); ++i)
{
//printf("\n%s and %s\n",vocab.id_to_token[current_context_tokens[i]].c_str(), vocab.id_to_token[embd_inp[i-offset_fix]].c_str());
if (current_context_tokens[i] == embd_inp[i-offset_fix])
{
n_past += 1;
@ -331,8 +338,7 @@ void print_tok_vec(std::vector<float> &embd)
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));
// printf("np:%d newembinp: \n",n_past);
// print_tok_vec(embd_inp);
}else{
smartcontext.clear();
}
@ -347,17 +353,16 @@ void print_tok_vec(std::vector<float> &embd)
smartcontext.clear();
}
if(useSmartContext
if(fastforwardok && useSmartContext
&& smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold
&& embd_inp_len >= SCInpLenThreshold
&& current_context_tokens.size() - n_past > SCPastLenThreshold)
{
{
//determine longest common substring after removing start part
int shiftamt = embd_inp.size() * SCTruncationRatio;
smartcontext = std::vector<int>(embd_inp.begin() + shiftamt, embd_inp.end());
printf("\n[New Smart Context Triggered! Buffered Token Allowance: %d]",shiftamt);
// printf("smartctx:\n");
// print_tok_vec(smartcontext,&vocab.id_to_token);
embd_inp = smartcontext;
//if max ctx length is exceeded, chop the prompt in half after the start part, and memorize it. The memorized part becomes LCS marker.
//when a future prompt comes in, find the LCS again. If LCS > a length and LCS starts with memorized LCS