mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
allow rwkv6 to run although its broken
This commit is contained in:
parent
b63158005f
commit
fc7fe2e7a0
3 changed files with 25 additions and 8 deletions
|
@ -194,15 +194,18 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
|
|||
if(add_bos)
|
||||
{
|
||||
llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
|
||||
if(output_tokens.size()==0)
|
||||
if(bostoadd != LLAMA_TOKEN_NULL) //if bos does not exist, do not add it
|
||||
{
|
||||
output_tokens.push_back(bostoadd);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(output_tokens[0]!=bostoadd)
|
||||
if(output_tokens.size()==0)
|
||||
{
|
||||
output_tokens.insert(output_tokens.begin(), 1, bostoadd);
|
||||
output_tokens.push_back(bostoadd);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(output_tokens[0]!=bostoadd)
|
||||
{
|
||||
output_tokens.insert(output_tokens.begin(), 1, bostoadd);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1870,6 +1873,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
}
|
||||
|
||||
if(file_format_meta.model_architecture==GGUFArch::ARCH_RWKV)
|
||||
{
|
||||
printf("\nRWKV6 Overriding EOS and BOS IDs to 0\n");
|
||||
llamamodel->vocab.special_bos_id = llamamodel->vocab.special_eos_id = 0;
|
||||
}
|
||||
|
||||
llama_ctx_params.flash_attn = kcpp_params->flash_attn;
|
||||
llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
|
||||
llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
|
||||
|
@ -3085,7 +3094,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
|
||||
{
|
||||
// set the logit of the eos token to very low to avoid sampling it
|
||||
logitsPtr[eosID] = lowestLogit;
|
||||
if(eosID!=LLAMA_TOKEN_NULL)
|
||||
{
|
||||
logitsPtr[eosID] = lowestLogit;
|
||||
}
|
||||
if(eotID!=-1)
|
||||
{
|
||||
logitsPtr[eotID] = lowestLogit;
|
||||
|
|
|
@ -314,6 +314,10 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
{
|
||||
fileformatmeta->model_architecture = GGUFArch::ARCH_QWEN2;
|
||||
}
|
||||
else if(modelarch=="rwkv6")
|
||||
{
|
||||
fileformatmeta->model_architecture = GGUFArch::ARCH_RWKV;
|
||||
}
|
||||
printf("Arch Category: %d\n",fileformatmeta->model_architecture);
|
||||
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ enum GGUFArch
|
|||
ARCH_MAMBA = 3,
|
||||
ARCH_SOLAR = 4,
|
||||
ARCH_QWEN2 = 5,
|
||||
ARCH_RWKV = 6,
|
||||
};
|
||||
|
||||
struct FileFormatExtraMeta
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue