rwkv is done

This commit is contained in:
Concedo 2023-04-18 20:55:01 +08:00
parent a76b15b581
commit ea01771dd5
5 changed files with 62 additions and 12 deletions

View file

@ -232,8 +232,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
print_tok_vec(embd_inp);
//truncate to front of the prompt if its too long
int32_t nctx = params.n_ctx;
@ -258,7 +257,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
bool approved_format = (file_format==FileFormat::GPT2_2 || file_format==FileFormat::GPTJ_3);
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
// bool blasmode = false;
int original_batch = params.n_batch;
@ -304,6 +303,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
else if(file_format == FileFormat::RWKV_1)
{
n_vocab = vocab.id_to_token.size(); //handled seperately
rwkv_context_v1->state_in = nullptr;
}
else
{
@ -333,9 +333,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if(file_format==FileFormat::RWKV_1)
{
printf("\nsiz:%d val:%d\n",embd.size(),embd[0]);
evalres = rwkv_eval(rwkv_context_v1, embd[0], rwkv_context_v1->state_in, rwkv_context_v1->state_out, rwkv_context_v1->logits_out);
memcpy(logits.data(), rwkv_context_v1->logits_out, sizeof(float)*rwkv_vocab.size());
rwkv_context_v1->state_in = rwkv_context_v1->state_out;
}
else if(file_format==FileFormat::GPT2_1)
{