fast forwarding for rwkv for unmodified contexts

This commit is contained in:
Concedo 2023-04-19 15:09:35 +08:00
parent f39def81d4
commit 45ec09d31b
8 changed files with 70 additions and 46 deletions

View file

@ -251,9 +251,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
n_past = 0;
if(file_format!=FileFormat::RWKV_1)
if (file_format == FileFormat::RWKV_1)
{
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
}
else
{
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext, false);
}
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
@ -303,7 +307,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
else if(file_format == FileFormat::RWKV_1)
{
n_vocab = vocab.id_to_token.size(); //handled seperately
rwkv_context_v1->state_in = nullptr;
if(n_past==0)
{
rwkv_context_v1->state_in = nullptr;
}
else
{
rwkv_context_v1->state_in = rwkv_context_v1->state_out;
//if it's empty, push in the final previous token
if(embd_inp.size()==0 && current_context_tokens.size()>0)
{
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
}
}
}
else
{