Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-12 18:09:42 +00:00)
fix for jamba models - they have recurrent layers like rwkv, so context shifting and forwarding won't work on them.
parent e9473305d0 · commit 5a3b2e3921
3 changed files with 12 additions and 7 deletions
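Background for the change: a transformer's KV cache keeps one entry per processed token, which is what allows trimming the cache back to a shared prefix (context shifting) or resuming from a matching prefix (fast-forwarding). A recurrent layer instead folds the entire history into a single fixed-size state, so there is nothing token-granular left to trim. A minimal sketch of the distinction (illustrative only, not code from this repository):

#include <cstddef>
#include <vector>

// Transformer-style cache: one KV entry per token, so rewinding to a
// shared prefix is just dropping the tail entries.
struct TransformerCache {
    std::vector<float> kv_per_token;
    void rewind(std::size_t n_keep) { kv_per_token.resize(n_keep); }
};

// Recurrent-style state: the whole history is fused into one vector.
// Nothing per-token survives, so the only safe "rewind" is a full
// reset followed by reprocessing the prompt from scratch.
struct RecurrentState {
    std::vector<float> state;
    void rewind(std::size_t /*n_keep*/) { state.clear(); }
};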
@@ -480,9 +480,9 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
         printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
         return;
     }
-    bool is_mamba = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
-    bool is_rwkv_new = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
-    if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
+    bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA
+    || file_format_meta.model_architecture==GGUFArch::ARCH_RWKV || file_format_meta.model_architecture==GGUFArch::ARCH_JAMBA));
+    if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
     {
         printf("\nWARNING: RNN models do not support context rewind!\n");
         return;
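The same is_recurrent expression is duplicated in gpttype_generate below. A possible follow-up refactor (not part of this patch; IsRecurrentArch is a hypothetical name) would fold the legacy RWKV file formats and the GGUF architecture checks into one helper using the repo's FileFormat, FileFormatExtraMeta, and GGUFArch types:

// Hypothetical consolidation of the duplicated recurrence check.
static bool IsRecurrentArch(FileFormat ff, const FileFormatExtraMeta &meta)
{
    if (ff == FileFormat::RWKV_1 || ff == FileFormat::RWKV_2) { return true; }
    return ff == FileFormat::GGUF_GENERIC &&
           (meta.model_architecture == GGUFArch::ARCH_MAMBA ||
            meta.model_architecture == GGUFArch::ARCH_RWKV ||
            meta.model_architecture == GGUFArch::ARCH_JAMBA);
}

Call sites would then collapse to if (IsRecurrentArch(file_format, file_format_meta)).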
@@ -3644,11 +3644,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("%s\n", RemoveBell(outstr).c_str());
     }

-    bool is_mamba = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
-    bool is_rwkv_new = (file_format == FileFormat::GGUF_GENERIC && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
+    bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA
+    || file_format_meta.model_architecture==GGUFArch::ARCH_RWKV || file_format_meta.model_architecture==GGUFArch::ARCH_JAMBA));
     bool blank_prompt = (addedmemory=="" && kcpp_data->prompt=="");

-    if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
+    if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
     {
         if(!blank_prompt)
         {
@@ -3657,7 +3657,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true);
         }
     }
-    if(is_mamba || is_rwkv_new)
+    if(is_recurrent)
     {
         if(n_past==0)
         {
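Taken together, the two hunks above make jamba follow the path mamba and rwkv already take through gpttype_generate: the recurrent branch constrains how ContextFastForward is applied, and the n_past==0 handling covers the fresh-state case, consistent with the commit message that context shifting and forwarding cannot be applied to a recurrent state.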
@@ -324,6 +324,10 @@ void print_tok_vec(std::vector<float> &embd)
     {
         fileformatmeta->model_architecture = GGUFArch::ARCH_MAMBA;
     }
+    else if(modelarch=="jamba")
+    {
+        fileformatmeta->model_architecture = GGUFArch::ARCH_JAMBA;
+    }
     else if(modelarch=="llama" && freq_base_train==10000.0f && (n_tensors==435 || n_tensors==611))
     {
         fileformatmeta->model_architecture = GGUFArch::ARCH_SOLAR;
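The detection above extends a chain of exact string comparisons on the GGUF architecture name. A table-driven alternative is one way to keep the exact-match cases together as the list grows (a sketch, not this patch; only keys visible in this diff are listed, and heuristic cases like the llama/SOLAR tensor-count check would still need explicit branches):

#include <string>
#include <unordered_map>

// Maps exact GGUF architecture strings to the GGUFArch enum extended in
// the header hunk below. Keys beyond "mamba" and "jamba" are omitted
// because their exact strings are not visible in this diff.
static const std::unordered_map<std::string, GGUFArch> kArchByName = {
    {"mamba", GGUFArch::ARCH_MAMBA},
    {"jamba", GGUFArch::ARCH_JAMBA},
};

A lookup such as kArchByName.find(modelarch) would then replace the exact-match branches, keeping the heuristics as a fallback.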
@@ -63,6 +63,7 @@ enum GGUFArch
     ARCH_GEMMA3 = 8,
     ARCH_GLM4 = 9,
     ARCH_GEMMA3N = 10,
+    ARCH_JAMBA = 11,
 };

 struct FileFormatExtraMeta