incomplete fix for rnn models, load state works but logits slightly different

This commit is contained in:
Concedo 2026-02-28 11:52:24 +08:00
parent 14d82bb38e
commit dd08d675f2

View file

@ -4204,38 +4204,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
llama_memory_clear(llama_get_memory(draft_ctx),true);
}
}
else if(embd_inp.size()==0)
else
{
embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
current_context_tokens.pop_back();
n_past -= 1;
//another dirty hack
int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0);
if(maxedpos==n_past)
if(current_context_tokens.size()>0 && last_n_tokens.size()>0)
{
n_past += 1;
}
}
else if(embd_inp.size()>0 && current_context_tokens.size()>0 && last_n_tokens.size()>0)
{
int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0);
if(maxedpos+2==n_past)
{
//kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn
//does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token
//it might get wrongly fast forwarded and we will get an off by 1 error.
//todo: figure out a better way to solve this rubbish
int tail = last_n_tokens[last_n_tokens.size()-1];
last_n_tokens.pop_back();
current_context_tokens.pop_back();
n_past -=1;
embd_inp.insert(embd_inp.begin(), 1, tail);
}
else if(maxedpos==n_past)
{
n_past += 1;
int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0);
if(maxedpos+2==n_past)
{
//kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn
//does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token
//it might get wrongly fast forwarded and we will get an off by 1 error.
//todo: figure out a better way to solve this rubbish
int tail = last_n_tokens[last_n_tokens.size()-1];
last_n_tokens.pop_back();
current_context_tokens.pop_back();
n_past -=1;
embd_inp.insert(embd_inp.begin(), 1, tail);
}
else if(maxedpos==n_past)
{
n_past += 1;
}
//it is generally preferable to not have the embd_inp array empty unless doing so would cause an error
// if(embd_inp.size()==0 && current_context_tokens.size()>0)
// {
// embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]);
// current_context_tokens.pop_back();
// n_past -= 1;
// }
}
}
}
}
else
@ -4272,7 +4271,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
bool blasmode = (embd_inp.size() >= 32 && kcpp_cpu_has_blas() && kcpp_data->n_batch>=32);
current_context_tokens.resize(n_past);
if(current_context_tokens.size()>n_past)
{
current_context_tokens.resize(n_past);
}
remaining_tokens = kcpp_data->n_predict;
int input_consumed = 0;
@ -4304,6 +4306,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
bool startedsampling = false;
bool firstdecodedone = false; //we CANNOT use logits if the first decode has not been executed yet.
bool v3_use_scratch = true; //for normal inference always use scratch
speculative_draft_result draft_results; //only use if drafting was used
@ -4362,8 +4365,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
std::string outstr = "";
// printf("\n[Debug: Dump Forwarded Input Tokens]\n");
// outstr += get_tok_vec_str(embd_inp);
outstr += "[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]";
//outstr += get_tok_vec_str(current_context_tokens);
// outstr += "\n";
outstr += "[Debug: embd_inp="+std::to_string(embd_inp.size())+" n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]";
// outstr += "\n";
// outstr += get_tok_vec_str(current_context_tokens);
printf("%s\n\n", RemoveBell(outstr).c_str());
}
@ -4524,6 +4529,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
generation_finished = true;
return output;
}
firstdecodedone = true;
}
n_past += embd.size();
@ -4591,6 +4597,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
while(logits_sampled<logits_to_sample && remaining_tokens>0 && !abort_draft && !early_abort)
{
if(!firstdecodedone && current_context_tokens.size()>0)
{
embd.clear();
embd.push_back(current_context_tokens[current_context_tokens.size()-1]);
break;
}
if(logits_sampled>0)
{
//this is not the first loop, so we need to increment some things
@ -5194,7 +5206,7 @@ size_t gpttype_save_state_kv(int slot)
}
}
touch_slot(slot);
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,savestates[slot].savestate_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
}
if(draft_ctx)