allow rwkv6 to run although its broken

This commit is contained in:
Concedo 2024-09-09 20:50:58 +08:00
parent b63158005f
commit fc7fe2e7a0
3 changed files with 25 additions and 8 deletions

View file

@ -194,6 +194,8 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
if(add_bos) if(add_bos)
{ {
llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model)); llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
if(bostoadd != LLAMA_TOKEN_NULL) //if bos does not exist, do not add it
{
if(output_tokens.size()==0) if(output_tokens.size()==0)
{ {
output_tokens.push_back(bostoadd); output_tokens.push_back(bostoadd);
@ -208,6 +210,7 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
} }
} }
} }
}
else else
{ {
// tokenize the prompt // tokenize the prompt
@ -1870,6 +1873,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
} }
} }
if(file_format_meta.model_architecture==GGUFArch::ARCH_RWKV)
{
printf("\nRWKV6 Overriding EOS and BOS IDs to 0\n");
llamamodel->vocab.special_bos_id = llamamodel->vocab.special_eos_id = 0;
}
llama_ctx_params.flash_attn = kcpp_params->flash_attn; llama_ctx_params.flash_attn = kcpp_params->flash_attn;
llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16)); llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16)); llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
@ -3085,7 +3094,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if (!inputs.allow_eos_token && !inputs.bypass_eos_token) if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{ {
// set the logit of the eos token to very low to avoid sampling it // set the logit of the eos token to very low to avoid sampling it
if(eosID!=LLAMA_TOKEN_NULL)
{
logitsPtr[eosID] = lowestLogit; logitsPtr[eosID] = lowestLogit;
}
if(eotID!=-1) if(eotID!=-1)
{ {
logitsPtr[eotID] = lowestLogit; logitsPtr[eotID] = lowestLogit;

View file

@ -314,6 +314,10 @@ void print_tok_vec(std::vector<float> &embd)
{ {
fileformatmeta->model_architecture = GGUFArch::ARCH_QWEN2; fileformatmeta->model_architecture = GGUFArch::ARCH_QWEN2;
} }
else if(modelarch=="rwkv6")
{
fileformatmeta->model_architecture = GGUFArch::ARCH_RWKV;
}
printf("Arch Category: %d\n",fileformatmeta->model_architecture); printf("Arch Category: %d\n",fileformatmeta->model_architecture);
} }

View file

@ -58,6 +58,7 @@ enum GGUFArch
ARCH_MAMBA = 3, ARCH_MAMBA = 3,
ARCH_SOLAR = 4, ARCH_SOLAR = 4,
ARCH_QWEN2 = 5, ARCH_QWEN2 = 5,
ARCH_RWKV = 6,
}; };
struct FileFormatExtraMeta struct FileFormatExtraMeta