mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 00:54:41 +00:00
fast forwarding for rwkv for unmodified contexts
This commit is contained in:
parent
f39def81d4
commit
45ec09d31b
8 changed files with 70 additions and 46 deletions
|
@ -251,9 +251,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
n_past = 0;
|
||||
|
||||
if(file_format!=FileFormat::RWKV_1)
|
||||
if (file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
|
||||
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext, false);
|
||||
}
|
||||
|
||||
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
|
||||
|
@ -303,7 +307,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
else if(file_format == FileFormat::RWKV_1)
|
||||
{
|
||||
n_vocab = vocab.id_to_token.size(); //handled separately
|
||||
rwkv_context_v1->state_in = nullptr;
|
||||
if(n_past==0)
|
||||
{
|
||||
rwkv_context_v1->state_in = nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
rwkv_context_v1->state_in = rwkv_context_v1->state_out;
|
||||
//if embd_inp is empty, push in the last token of current_context_tokens
|
||||
if(embd_inp.size()==0 && current_context_tokens.size()>0)
|
||||
{
|
||||
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue