cursed hack for RNN models

Concedo 2025-10-11 23:14:55 +08:00
parent 0cc0ea4cf9
commit e92f9fd422


@@ -630,7 +630,7 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
    {
        const llama_vocab * tmpvocab = llama_model_get_vocab(draftmodel);
        int draftvocab = llama_vocab_n_tokens(tmpvocab);
-       if(llama_model_is_recurrent(draftmodel))
+       if(llama_model_is_recurrent(draftmodel) || llama_model_is_hybrid(draftmodel))
        {
            printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
            llama_free(draft_ctx);
@@ -2523,7 +2523,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
        if(draftmodel_filename !="" && file_format==FileFormat::GGUF_GENERIC)
        {
-           if(llama_model_is_recurrent(llamamodel))
+           if(llama_model_is_recurrent(llamamodel) || llama_model_is_hybrid(llamamodel))
            {
                printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
            }
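Background for both guards above: speculative decoding evaluates a batch of drafted tokens and then has to rewind the cache past any the target model rejects. A transformer KV cache handles that as a range delete, but a recurrent (or hybrid) model keeps a single running state with no per-position entries, so rejected draft tokens cannot be undone. A minimal sketch of the rollback this would require, assuming llama.cpp's llama_memory_seq_rm reports failure for a partial removal on recurrent memory; the helper name is hypothetical, not part of this commit:

#include "llama.h"

// Hypothetical illustration, not koboldcpp code: try to drop every cached
// position after the accepted prefix. A plain KV cache can do this; for
// recurrent/hybrid memory the partial removal is assumed to fail, which is
// why the guards above refuse drafting outright.
static bool can_rollback_draft(llama_context * ctx, llama_pos n_accepted) {
    llama_memory_t mem = llama_get_memory(ctx);
    return llama_memory_seq_rm(mem, 0, n_accepted, -1);
}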
@@ -3758,7 +3758,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
        if(file_format==FileFormat::GGUF_GENERIC)
        {
            const llama_model * mdl = llama_get_model(llama_ctx_v4);
-           if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl))
+           if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl) || file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE || file_format_meta.model_architecture==GGUFArch::ARCH_RWKV)
            {
                is_recurrent = true;
            }
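The four conditions above all ask the same question, so a small helper could name it. A sketch (hypothetical, not in the commit, built only from identifiers that appear in this diff):

// Hypothetical helper, not part of this commit: the "treat as recurrent"
// test from the hunk above, combining the llama.cpp API checks with
// koboldcpp's own architecture metadata as a fallback for models the API
// may not flag.
static bool kcpp_is_recurrent_like(const llama_model * mdl, GGUFArch arch) {
    return llama_model_is_recurrent(mdl)
        || llama_model_is_hybrid(mdl)
        || arch == GGUFArch::ARCH_MAMBALIKE
        || arch == GGUFArch::ARCH_RWKV;
}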
@@ -3789,6 +3789,22 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
                n_past -= 1;
            }
+           else if(embd_inp.size()>0 && current_context_tokens.size()>0 && last_n_tokens.size()>0)
+           {
+               int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4), 0);
+               if(maxedpos+2==n_past)
+               {
+                   //kcpp: a very dirty hack for RNN models. The very last token of the previous turn never
+                   //actually gets processed, but is still added to current_context_tokens. If the instruct
+                   //start tag begins with that same token, it can get wrongly fast-forwarded and we end up
+                   //with an off-by-one error. todo: figure out a better way to solve this rubbish
+                   int tail = last_n_tokens[last_n_tokens.size()-1];
+                   last_n_tokens.pop_back();
+                   current_context_tokens.pop_back();
+                   n_past -= 1;
+                   embd_inp.insert(embd_inp.begin(), 1, tail);
+               }
+           }
            }
        }
        else
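How the detection works: positions are zero-indexed, so if every one of the n_past stored tokens had actually been evaluated, llama_memory_seq_pos_max would return n_past-1. A result of n_past-2 therefore means exactly one trailing token was recorded but never processed. A self-contained walkthrough of the failure and the fix, using illustrative values rather than koboldcpp's real state:

#include <cstdio>
#include <vector>

int main() {
    // History from the previous turn. Token 42 was appended to the record
    // but never evaluated, so the model state only covers positions 0..3.
    std::vector<int> current_context_tokens = {7, 9, 13, 21, 42};
    int maxedpos = 3;                                   // what llama_memory_seq_pos_max would report
    int n_past   = (int)current_context_tokens.size();  // 5, so maxedpos+2 == n_past

    // The new prompt happens to begin with that same token 42 (e.g. an
    // instruct start tag). Naive fast-forwarding would match it against the
    // stored history and skip it, though the model never actually saw it.
    std::vector<int> embd_inp = {42, 99, 100};

    if (maxedpos + 2 == n_past) {
        // The hack: drop the phantom tail from the history and re-queue it
        // at the front of the new input so it really gets evaluated.
        int tail = current_context_tokens.back();
        current_context_tokens.pop_back();
        n_past -= 1;
        embd_inp.insert(embd_inp.begin(), 1, tail);
    }

    printf("n_past=%d first_input=%d\n", n_past, embd_inp[0]); // n_past=4 first_input=42
    return 0;
}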