mirror of https://github.com/LostRuins/koboldcpp.git (synced 2026-04-28 03:30:20 +00:00)
cursed hack for RNN models
commit e92f9fd422
parent 0cc0ea4cf9
1 changed file with 19 additions and 3 deletions
@@ -630,7 +630,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
     {
         const llama_vocab * tmpvocab = llama_model_get_vocab(draftmodel);
         int draftvocab = llama_vocab_n_tokens(tmpvocab);
-        if(llama_model_is_recurrent(draftmodel))
+        if(llama_model_is_recurrent(draftmodel) || llama_model_is_hybrid(draftmodel))
         {
             printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
             llama_free(draft_ctx);
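
Why the draft-side check got wider: speculative decoding proposes tokens with the draft model and rolls back whatever the main model rejects. A purely recurrent draft model carries hidden state that cannot be rewound token by token, and a hybrid model (recurrent layers mixed with attention) has the same problem, so the commit now rejects both. Below is a minimal sketch of the same guard in isolation; the helper name and surrounding plumbing are hypothetical, while the llama.cpp calls are the ones the diff relies on:

```cpp
#include "llama.h"
#include <cstdio>

// Hypothetical helper: load a candidate draft model, but refuse it up front
// when it is recurrent or hybrid, mirroring the check in the hunk above.
static llama_model * load_draft_model_checked(const char * path) {
    llama_model_params mparams = llama_model_default_params();
    llama_model * draftmodel = llama_model_load_from_file(path, mparams);
    if (draftmodel == nullptr) {
        return nullptr; // missing file or unreadable GGUF
    }
    if (llama_model_is_recurrent(draftmodel) || llama_model_is_hybrid(draftmodel)) {
        printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
        llama_model_free(draftmodel);
        return nullptr;
    }
    return draftmodel;
}
```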
@@ -2523,7 +2523,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in

     if(draftmodel_filename !="" && file_format==FileFormat::GGUF_GENERIC)
     {
-        if(llama_model_is_recurrent(llamamodel))
+        if(llama_model_is_recurrent(llamamodel) || llama_model_is_hybrid(llamamodel))
         {
             printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
         }
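
The same test is repeated at model load time, this time against the target model itself, gated on a draft filename actually being supplied and the file being a modern GGUF. Note that this hunk only prints the error and frees nothing, so presumably the main model keeps loading and only the draft is dropped. A rough sketch of that control flow, where the else branch is an assumption about surrounding code the hunk does not show:

```cpp
// Identifiers match the diff; the else branch is assumed, not shown in the hunk.
if (draftmodel_filename != "" && file_format == FileFormat::GGUF_GENERIC) {
    if (llama_model_is_recurrent(llamamodel) || llama_model_is_hybrid(llamamodel)) {
        // Warn and fall through: the main model still loads, without a draft.
        printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
    } else {
        // speculative_decoding_setup(draftmodel_filename, ...); // assumed call site
    }
}
```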
@@ -3758,7 +3758,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     if(file_format==FileFormat::GGUF_GENERIC)
     {
         const llama_model * mdl = llama_get_model(llama_ctx_v4);
-        if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl))
+        if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl) || file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE || file_format_meta.model_architecture==GGUFArch::ARCH_RWKV)
         {
             is_recurrent = true;
         }
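
During generation the is_recurrent flag now has two lines of defence: the runtime query against the loaded model, plus koboldcpp's own architecture tag parsed from the GGUF header, which catches files that declare a Mamba-like or RWKV architecture even if the runtime checks were to miss them. Distilled into a standalone predicate; the helper name is hypothetical, while GGUFArch and the enum values come from the diff:

```cpp
// Hypothetical predicate condensing the widened check from the hunk above.
static bool needs_recurrent_handling(const llama_model * mdl, GGUFArch arch) {
    // First, trust llama.cpp's own view of the loaded weights...
    if (llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl)) {
        return true;
    }
    // ...then fall back to the architecture tag koboldcpp parsed from the
    // GGUF header, so Mamba-like and RWKV files are flagged either way.
    return arch == GGUFArch::ARCH_MAMBALIKE || arch == GGUFArch::ARCH_RWKV;
}
```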
@@ -3789,6 +3789,22 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
                 n_past -= 1;
             }
+            else if(embd_inp.size()>0 && current_context_tokens.size()>0 && last_n_tokens.size()>0)
+            {
+                int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0);
+                if(maxedpos+2==n_past)
+                {
+                    //kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn
+                    //does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token
+                    //it might get wrongly fast forwarded and we will get an off by 1 error.
+                    //todo: figure out a better way to solve this rubbish
+                    int tail = last_n_tokens[last_n_tokens.size()-1];
+                    last_n_tokens.pop_back();
+                    current_context_tokens.pop_back();
+                    n_past -= 1;
+                    embd_inp.insert(embd_inp.begin(), 1, tail);
+                }
+            }
         }
     }
     else
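
This new branch is the cursed hack of the commit title. llama_memory_seq_pos_max reports the highest position the model state has actually absorbed, so ordinarily n_past equals maxedpos + 1 (a count versus a maximum index). When n_past equals maxedpos + 2, the bookkeeping holds one phantom token: the final sampled token of the previous turn, which was appended to current_context_tokens but never fed back through the network. If the next turn's instruct start tag begins with that very token, the context fast-forward would treat it as already processed and skip it, leaving the recurrent state one token behind, which is the off-by-one the comment describes. The fix pops the phantom token from the history and prepends it to the new input so it really gets evaluated. A self-contained toy reproduction of the arithmetic, with made-up token IDs and no llama.cpp calls:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // After the previous turn: token 123 was sampled and recorded, but never
    // evaluated, so the model state only covers positions 0..2.
    std::vector<int> current_context_tokens = {10, 20, 50, 123};
    int n_past   = (int)current_context_tokens.size(); // claims 4 tokens processed
    int maxedpos = 2; // what llama_memory_seq_pos_max would report here

    // New turn whose instruct start tag happens to begin with the same token 123.
    std::vector<int> embd_inp = {123, 7, 8};

    if (maxedpos + 2 == n_past) {
        // Same correction as the diff: pull the phantom tail token back into
        // the input queue so the recurrent state absorbs it this turn.
        int tail = current_context_tokens.back();
        current_context_tokens.pop_back();
        n_past -= 1;
        embd_inp.insert(embd_inp.begin(), tail);
    }

    // Prints "n_past=3 first=123": bookkeeping and model state agree again,
    // and token 123 will now be processed at position 3 with the new prompt.
    printf("n_past=%d first=%d\n", n_past, embd_inp[0]);
    return 0;
}
```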