From fc7fe2e7a0cbdbf48c5c5e74e559183c9c4aa2d1 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:50:58 +0800 Subject: [PATCH] allow rwkv6 to run although its broken --- gpttype_adapter.cpp | 28 ++++++++++++++++++++-------- model_adapter.cpp | 4 ++++ model_adapter.h | 1 + 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 02aed8120..e35475545 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -194,15 +194,18 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector if(add_bos) { llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model)); - if(output_tokens.size()==0) + if(bostoadd != LLAMA_TOKEN_NULL) //if bos does not exist, do not add it { - output_tokens.push_back(bostoadd); - } - else - { - if(output_tokens[0]!=bostoadd) + if(output_tokens.size()==0) { - output_tokens.insert(output_tokens.begin(), 1, bostoadd); + output_tokens.push_back(bostoadd); + } + else + { + if(output_tokens[0]!=bostoadd) + { + output_tokens.insert(output_tokens.begin(), 1, bostoadd); + } } } } @@ -1870,6 +1873,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in } } + if(file_format_meta.model_architecture==GGUFArch::ARCH_RWKV) + { + printf("\nRWKV6 Overriding EOS and BOS IDs to 0\n"); + llamamodel->vocab.special_bos_id = llamamodel->vocab.special_eos_id = 0; + } + llama_ctx_params.flash_attn = kcpp_params->flash_attn; llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16)); llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16)); @@ -3085,7 +3094,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) if (!inputs.allow_eos_token && !inputs.bypass_eos_token) { // set the logit of the eos token to very low to avoid sampling it - logitsPtr[eosID] = lowestLogit; + if(eosID!=LLAMA_TOKEN_NULL) + { + logitsPtr[eosID] = lowestLogit; + } if(eotID!=-1) { logitsPtr[eotID] = lowestLogit; diff --git a/model_adapter.cpp b/model_adapter.cpp index ecde7d191..2dd7c9d08 100644 --- a/model_adapter.cpp +++ b/model_adapter.cpp @@ -314,6 +314,10 @@ void print_tok_vec(std::vector &embd) { fileformatmeta->model_architecture = GGUFArch::ARCH_QWEN2; } + else if(modelarch=="rwkv6") + { + fileformatmeta->model_architecture = GGUFArch::ARCH_RWKV; + } printf("Arch Category: %d\n",fileformatmeta->model_architecture); } diff --git a/model_adapter.h b/model_adapter.h index 591971562..7364710a6 100644 --- a/model_adapter.h +++ b/model_adapter.h @@ -58,6 +58,7 @@ enum GGUFArch ARCH_MAMBA = 3, ARCH_SOLAR = 4, ARCH_QWEN2 = 5, + ARCH_RWKV = 6, }; struct FileFormatExtraMeta