Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)

Commit 55e0fbf024 (parent fe63bfdb0f): wip integrating new rwkv
6 changed files with 1239 additions and 247 deletions
@@ -23,6 +23,7 @@
 #include "gpt2_v2.cpp"
 #include "gpt2_v3.cpp"
 #include "rwkv_v2.cpp"
+#include "rwkv_v3.cpp"
 #include "neox_v2.cpp"
 #include "neox_v3.cpp"
 
@@ -43,7 +44,7 @@ static gpt2_model gpt2_ctx_v3;
 static gpt_neox_v2_model neox_ctx_v2;
 static gpt_neox_model neox_ctx_v3;
 
-static rwkv_context * rwkv_ctx_v1;
+static rwkv_v2_context * rwkv_ctx_v2;
 static llama_v2_context_params llama_ctx_params_v2;
 static llama_context_params llama_ctx_params;
 static llama_v2_context * llama_ctx_v2;
@@ -390,17 +391,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }
     else if (file_format == FileFormat::RWKV_1)
     {
-        rwkv_ctx_v1 = rwkv_init_from_file(modelname.c_str(), n_threads);
+        rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
 
         //setup buffers for rwkv state
         auto padding = 512u;
-        auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v1) * sizeof(float) + padding;
-        auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v1) * sizeof(float) + padding;
+        auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
+        auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
 
         printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
-        rwkv_ctx_v1->state_out = (float *)malloc(statebufsiz);
-        rwkv_ctx_v1->logits_out = (float *)malloc(logitbufsiz);
-        rwkv_ctx_v1->state_in = nullptr;
+        rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
+        rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
+        rwkv_ctx_v2->state_in = nullptr;
         n_batch = 1;
 
         std::string word;
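For reference, the load path in the hunk above reduces to: initialize the backend context from the model file, size the state and logits buffers from the element counts the backend reports (floats converted to bytes, plus a small padding), allocate them on the context, and run a single sanity eval. Below is a minimal sketch of that pattern, assuming only the rwkv_v2_* names and signatures shown in this diff; the helper name load_rwkv_v2 is hypothetical, and the NULL check is moved up front compared to the adapter code.

// Sketch only: assumes the rwkv_v2_* API exactly as used in the hunk above.
#include <cstdlib>
#include <string>
#include "rwkv_v2.cpp"   // pulled in the same way as the include list above

static bool load_rwkv_v2(const std::string & modelname, int n_threads, rwkv_v2_context *& ctx)
{
    ctx = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
    if (ctx == NULL)
    {
        return false; // the adapter returns ModelLoadResult::FAIL in this case
    }

    // Buffers are sized from the backend-reported element counts,
    // converted to bytes, with a fixed 512-byte padding.
    auto padding = 512u;
    auto statebufsiz = rwkv_v2_get_state_buffer_element_count(ctx) * sizeof(float) + padding;
    auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(ctx) * sizeof(float) + padding;

    ctx->state_out  = (float *)malloc(statebufsiz);
    ctx->logits_out = (float *)malloc(logitbufsiz);
    ctx->state_in   = nullptr; // nullptr means "start from a fresh state"

    // Sanity check: evaluate token 0 once; failure here indicates a broken load.
    return rwkv_v2_eval(ctx, 0, ctx->state_in, ctx->state_out, ctx->logits_out);
}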
@@ -414,15 +415,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
         printf("\nRWKV Vocab: %u\n",vocabsiz);
 
-        bool testeval = rwkv_eval(rwkv_ctx_v1, 0, rwkv_ctx_v1->state_in, rwkv_ctx_v1->state_out, rwkv_ctx_v1->logits_out);
+        bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
         if(!testeval)
         {
             printf("\nError: RWKV Init Eval Failed!\n");
         }
         logits.resize(vocabsiz);
-        memcpy(logits.data(), rwkv_ctx_v1->logits_out, sizeof(float)*vocabsiz);
+        memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz);
 
-        if (rwkv_ctx_v1 == NULL)
+        if (rwkv_ctx_v2 == NULL)
         {
             return ModelLoadResult::FAIL;
         }
@@ -838,11 +839,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             n_vocab = vocab.id_to_token.size(); //handled seperately
             if(n_past==0)
             {
-                rwkv_ctx_v1->state_in = nullptr;
+                rwkv_ctx_v2->state_in = nullptr;
             }
             else
             {
-                rwkv_ctx_v1->state_in = rwkv_ctx_v1->state_out;
+                rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
                 //if it's empty, push in the final previous token
                 if(embd_inp.size()==0 && current_context_tokens.size()>0)
                 {
@@ -910,9 +911,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             }
             else if(file_format==FileFormat::RWKV_1)
             {
-                evalres = rwkv_eval(rwkv_ctx_v1, embd[0], rwkv_ctx_v1->state_in, rwkv_ctx_v1->state_out, rwkv_ctx_v1->logits_out);
-                memcpy(logits.data(), rwkv_ctx_v1->logits_out, sizeof(float)*rwkv_vocab.size());
-                rwkv_ctx_v1->state_in = rwkv_ctx_v1->state_out;
+                evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
+                memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
+                rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
             }
             else if(file_format==FileFormat::GPT2_1)
             {
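The generation-side hunks follow the same rename and make the recurrent state handling explicit: on the first token (n_past==0) the input state is nullptr, otherwise the previous output state is fed back in, and after every eval the logits are copied out and state_out becomes the next state_in. A minimal per-token sketch under the same assumptions follows; the helper name rwkv_v2_generate_step is hypothetical.

// Sketch only: one generation step for the RWKV_1 path, mirroring the hunk above.
#include <cstring>
#include <vector>
#include "rwkv_v2.cpp"

static bool rwkv_v2_generate_step(rwkv_v2_context * ctx, int token, bool first_step,
                                  std::vector<float> & logits)
{
    // Reset the recurrent state on the first token, otherwise continue from it.
    ctx->state_in = first_step ? nullptr : ctx->state_out;

    bool evalres = rwkv_v2_eval(ctx, token, ctx->state_in, ctx->state_out, ctx->logits_out);
    if (!evalres)
    {
        return false;
    }

    // logits is assumed to have been resized to the vocab size at load time,
    // as the adapter does with logits.resize(vocabsiz).
    memcpy(logits.data(), ctx->logits_out, sizeof(float) * logits.size());
    ctx->state_in = ctx->state_out; // carry state forward for the next token
    return true;
}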