mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
Try out the new RWKV implementation; it seems worse so far, so this change may be reverted.
This commit is contained in:
parent
632bf27b65
commit
e1a7042943
4 changed files with 825 additions and 370 deletions
|
@ -431,6 +431,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
else //rwkv_2
|
||||
{
|
||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||
|
||||
if(inputs.gpulayers>0)
|
||||
{
|
||||
rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
|
||||
}
|
||||
|
||||
const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
|
||||
const size_t n_vocab = header.n_vocab;
|
||||
printf("\nDetected Vocab: %d",n_vocab);
|
||||
|
@ -811,7 +817,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
{
|
||||
params.top_k = 120; //to disable top_k we actually need to increase this value to a very high number
|
||||
}
|
||||
if (params.seed <= 0)
|
||||
if (params.seed <= 0 || params.seed==0xFFFFFFFF)
|
||||
{
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
@ -1060,14 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
}
|
||||
else
|
||||
{
|
||||
if(embd.size()>1)
|
||||
{
|
||||
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
}
|
||||
else
|
||||
{
|
||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
}
|
||||
// if(embd.size()>1)
|
||||
// {
|
||||
// evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
|
||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
|
||||
//}
|
||||
|
||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue