wip integrating new rwkv

Concedo 2023-05-27 22:45:28 +08:00
parent fe63bfdb0f
commit 55e0fbf024
6 changed files with 1239 additions and 247 deletions

@@ -23,6 +23,7 @@
 #include "gpt2_v2.cpp"
 #include "gpt2_v3.cpp"
 #include "rwkv_v2.cpp"
+#include "rwkv_v3.cpp"
 #include "neox_v2.cpp"
 #include "neox_v3.cpp"
@@ -43,7 +44,7 @@ static gpt2_model gpt2_ctx_v3;
 static gpt_neox_v2_model neox_ctx_v2;
 static gpt_neox_model neox_ctx_v3;
-static rwkv_context * rwkv_ctx_v1;
+static rwkv_v2_context * rwkv_ctx_v2;
 static llama_v2_context_params llama_ctx_params_v2;
 static llama_context_params llama_ctx_params;
 static llama_v2_context * llama_ctx_v2;
@@ -390,17 +391,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }
     else if (file_format == FileFormat::RWKV_1)
     {
-        rwkv_ctx_v1 = rwkv_init_from_file(modelname.c_str(), n_threads);
+        rwkv_ctx_v2 = rwkv_v2_init_from_file(modelname.c_str(), n_threads);
         //setup buffers for rwkv state
         auto padding = 512u;
-        auto statebufsiz = rwkv_get_state_buffer_element_count(rwkv_ctx_v1) * sizeof(float) + padding;
-        auto logitbufsiz = rwkv_get_logits_buffer_element_count(rwkv_ctx_v1) * sizeof(float) + padding;
+        auto statebufsiz = rwkv_v2_get_state_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
+        auto logitbufsiz = rwkv_v2_get_logits_buffer_element_count(rwkv_ctx_v2) * sizeof(float) + padding;
         printf("\nRWKV Init: State Buffer:%u, Logit Buffer:%u\n", statebufsiz, logitbufsiz);
-        rwkv_ctx_v1->state_out = (float *)malloc(statebufsiz);
-        rwkv_ctx_v1->logits_out = (float *)malloc(logitbufsiz);
-        rwkv_ctx_v1->state_in = nullptr;
+        rwkv_ctx_v2->state_out = (float *)malloc(statebufsiz);
+        rwkv_ctx_v2->logits_out = (float *)malloc(logitbufsiz);
+        rwkv_ctx_v2->state_in = nullptr;
         n_batch = 1;
         std::string word;
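At load time the state and logits buffers are sized from the element counts the context reports, padded, and malloc'd once for reuse on every eval; state_in starting out as nullptr tells the library to begin from the model's built-in initial state. The same setup condensed into a hypothetical helper (only the rwkv_v2_* accessors come from the hunk above):

// Hypothetical helper condensing the buffer setup from the hunk above.
static bool alloc_rwkv_v2_buffers(rwkv_v2_context * ctx, size_t padding = 512)
{
    size_t statebufsiz = rwkv_v2_get_state_buffer_element_count(ctx) * sizeof(float) + padding;
    size_t logitbufsiz = rwkv_v2_get_logits_buffer_element_count(ctx) * sizeof(float) + padding;
    ctx->state_out  = (float *)malloc(statebufsiz);
    ctx->logits_out = (float *)malloc(logitbufsiz);
    ctx->state_in   = nullptr; // nullptr means "start from the model's initial state"
    return ctx->state_out != nullptr && ctx->logits_out != nullptr;
}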
@@ -414,15 +415,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
         printf("\nRWKV Vocab: %u\n",vocabsiz);
-        bool testeval = rwkv_eval(rwkv_ctx_v1, 0, rwkv_ctx_v1->state_in, rwkv_ctx_v1->state_out, rwkv_ctx_v1->logits_out);
+        bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0, rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
         if(!testeval)
         {
             printf("\nError: RWKV Init Eval Failed!\n");
         }
         logits.resize(vocabsiz);
-        memcpy(logits.data(), rwkv_ctx_v1->logits_out, sizeof(float)*vocabsiz);
+        memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*vocabsiz);
-        if (rwkv_ctx_v1 == NULL)
+        if (rwkv_ctx_v2 == NULL)
         {
             return ModelLoadResult::FAIL;
         }
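The single eval of token 0 is a warm-up that both sanity-checks the model and fills logits_out before generation starts. Note that the NULL check on rwkv_ctx_v2 only runs after the context has already been dereferenced; a sketch of the same calls with the check hoisted to the front:

// Same calls as the hunk above, reordered so a failed init is caught
// before the context is touched.
if (rwkv_ctx_v2 == NULL)
{
    return ModelLoadResult::FAIL;
}
bool testeval = rwkv_v2_eval(rwkv_ctx_v2, 0,
                             rwkv_ctx_v2->state_in,   // nullptr -> initial state
                             rwkv_ctx_v2->state_out,
                             rwkv_ctx_v2->logits_out);
if (!testeval)
{
    printf("\nError: RWKV Init Eval Failed!\n");
}
logits.resize(vocabsiz);
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * vocabsiz);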
@@ -838,11 +839,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         n_vocab = vocab.id_to_token.size(); //handled separately
         if(n_past==0)
         {
-            rwkv_ctx_v1->state_in = nullptr;
+            rwkv_ctx_v2->state_in = nullptr;
         }
         else
         {
-            rwkv_ctx_v1->state_in = rwkv_ctx_v1->state_out;
+            rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
             //if it's empty, push in the final previous token
             if(embd_inp.size()==0 && current_context_tokens.size()>0)
             {
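Unlike the llama and gpt paths, RWKV keeps no KV cache: the entire context is a fixed-size recurrent state, so resetting means passing nullptr as state_in and resuming means pointing state_in back at state_out. A sketch of consuming a prompt under that scheme (embd_inp and the context are the surrounding code's variables):

// Prompt ingestion, one token at a time, chaining the recurrent state.
rwkv_ctx_v2->state_in = nullptr; // n_past == 0: start from a clean state
for (auto tok : embd_inp)
{
    if (!rwkv_v2_eval(rwkv_ctx_v2, tok,
                      rwkv_ctx_v2->state_in,
                      rwkv_ctx_v2->state_out,
                      rwkv_ctx_v2->logits_out))
    {
        break; // eval failed
    }
    rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out; // next step reads this state back
}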
@@ -910,9 +911,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format==FileFormat::RWKV_1)
         {
-            evalres = rwkv_eval(rwkv_ctx_v1, embd[0], rwkv_ctx_v1->state_in, rwkv_ctx_v1->state_out, rwkv_ctx_v1->logits_out);
-            memcpy(logits.data(), rwkv_ctx_v1->logits_out, sizeof(float)*rwkv_vocab.size());
-            rwkv_ctx_v1->state_in = rwkv_ctx_v1->state_out;
+            evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0], rwkv_ctx_v2->state_in, rwkv_ctx_v2->state_out, rwkv_ctx_v2->logits_out);
+            memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float)*rwkv_vocab.size());
+            rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out;
         }
         else if(file_format==FileFormat::GPT2_1)
         {
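Each generation step then evaluates the last sampled token, copies the fresh logits out, and chains the state for the next step. The sketch below repeats that step and tacks on a greedy argmax purely for illustration; the real sampler lives elsewhere in this file, and <algorithm> is assumed for std::max_element:

#include <algorithm> // assumption: pulled in for std::max_element

// One RWKV generation step as in the hunk above, plus an illustrative pick.
evalres = rwkv_v2_eval(rwkv_ctx_v2, embd[0],
                       rwkv_ctx_v2->state_in,
                       rwkv_ctx_v2->state_out,
                       rwkv_ctx_v2->logits_out);
memcpy(logits.data(), rwkv_ctx_v2->logits_out, sizeof(float) * rwkv_vocab.size());
rwkv_ctx_v2->state_in = rwkv_ctx_v2->state_out; // next eval continues from here
int next_token = (int)(std::max_element(logits.begin(), logits.end()) - logits.begin());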