try out the new rwkv but it seems worse, may revert

Concedo 2023-07-02 00:10:56 +08:00
parent 632bf27b65
commit e1a7042943
4 changed files with 825 additions and 370 deletions


@@ -431,6 +431,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         else //rwkv_2
         {
             rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
+            if(inputs.gpulayers>0)
+            {
+                rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
+            }
             const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
             const size_t n_vocab = header.n_vocab;
             printf("\nDetected Vocab: %d",n_vocab);
@@ -811,7 +817,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         params.top_k = 120; //to disable top_k we actually need to increase this value to a very high number
     }
-    if (params.seed <= 0)
+    if (params.seed <= 0 || params.seed==0xFFFFFFFF)
     {
         params.seed = time(NULL);
     }
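
The seed hunk above additionally treats 0xFFFFFFFF, i.e. a -1 that has passed through an unsigned 32-bit conversion, as "no seed chosen" and falls back to the current time. A standalone sketch of that rule, assuming the seed arrives as a raw 32-bit value; pick_seed is an illustrative helper, not code from the commit:

    // Sketch only: normalise a user-supplied seed the way the hunk above does.
    #include <cstdint>
    #include <ctime>

    static uint32_t pick_seed(uint32_t requested)
    {
        // 0 and 0xFFFFFFFF (-1 after unsigned conversion) both mean "pick a seed for me".
        if (requested == 0 || requested == 0xFFFFFFFFu)
        {
            requested = (uint32_t) time(nullptr);
        }
        return requested;
    }
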
@@ -1060,14 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else
         {
-            if(embd.size()>1)
-            {
-                evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
-            }
-            else
-            {
-                evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
-            }
+            // if(embd.size()>1)
+            // {
+            // evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+            // }
+            // else
+            // {
+            bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
+            evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
+            //}
             memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
             rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
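
The last hunk disables the batched rwkv_eval_sequence path and feeds tokens one at a time through rwkv_eval, passing nullptr as logits_out while the prompt is still being consumed (the ignoreLogits flag) so logits are only produced when they will actually be sampled; state_in is then pointed at state_out so the same buffer carries the recurrent state forward. Below is a simplified sketch of that pattern, assuming the rwkv_eval signature used above and caller-owned state and logits buffers; eval_prompt is an illustrative helper, not code from the commit, and it skips logits for every prompt token except the last rather than reproducing the hunk's exact ignoreLogits condition.

    // Sketch only: per-token prompt evaluation that skips logits until the last token.
    #include <cstdint>
    #include <vector>
    #include "rwkv.h"

    static bool eval_prompt(rwkv_context * ctx,
                            const std::vector<uint32_t> & prompt,
                            std::vector<float> & state,   // recurrent state, already initialised and carried between calls
                            std::vector<float> & logits)  // sized to the model's vocabulary
    {
        for (size_t i = 0; i < prompt.size(); ++i)
        {
            const bool last = (i + 1 == prompt.size());
            // Earlier prompt tokens only need to advance the state, so request no logits for them.
            if (!rwkv_eval(ctx, prompt[i], state.data(), state.data(),
                           last ? logits.data() : nullptr))
            {
                return false;
            }
        }
        return true;
    }
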