Fix for Jamba models - they have recurrent layers like RWKV, so context shifting and fast-forwarding won't work on them.

This commit is contained in:
Concedo 2025-07-12 10:05:15 +08:00
parent e9473305d0
commit 5a3b2e3921
3 changed files with 12 additions and 7 deletions

View file

@ -480,9 +480,9 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
return;
}
bool is_mamba = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
bool is_rwkv_new = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV || file_format_meta.model_architecture==GGUFArch::ARCH_JAMBA));
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
{
printf("\nWARNING: RNN models do not support context rewind!\n");
return;
@ -3644,11 +3644,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("%s\n", RemoveBell(outstr).c_str());
}
bool is_mamba = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
bool is_rwkv_new = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV || file_format_meta.model_architecture==GGUFArch::ARCH_JAMBA));
bool blank_prompt = (addedmemory=="" && kcpp_data->prompt=="");
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
{
if(!blank_prompt)
{
@ -3657,7 +3657,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
}
}
if(is_mamba || is_rwkv_new)
if(is_recurrent)
{
if(n_past==0)
{