diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 490d887c5..66f120187 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -4204,38 +4204,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs) llama_memory_clear(llama_get_memory(draft_ctx),true); } } - else if(embd_inp.size()==0) + else { - embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]); - current_context_tokens.pop_back(); - n_past -= 1; - //another dirty hack - int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0); - if(maxedpos==n_past) + if(current_context_tokens.size()>0 && last_n_tokens.size()>0) { - n_past += 1; - } - } - else if(embd_inp.size()>0 && current_context_tokens.size()>0 && last_n_tokens.size()>0) - { - int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0); - if(maxedpos+2==n_past) - { - //kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn - //does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token - //it might get wrongly fast forwarded and we will get an off by 1 error. - //todo: figure out a better way to solve this rubbish - int tail = last_n_tokens[last_n_tokens.size()-1]; - last_n_tokens.pop_back(); - current_context_tokens.pop_back(); - n_past -=1; - embd_inp.insert(embd_inp.begin(), 1, tail); - } - else if(maxedpos==n_past) - { - n_past += 1; + int maxedpos = llama_memory_seq_pos_max(llama_get_memory(llama_ctx_v4),0); + if(maxedpos+2==n_past) + { + //kcpp: a very dirty hack for rnn models. this happens because the very last token of the last turn + //does not actually get processed but is still added to current_context_tokens. if the instruct start tag starts with that same token + //it might get wrongly fast forwarded and we will get an off by 1 error. + //todo: figure out a better way to solve this rubbish + int tail = last_n_tokens[last_n_tokens.size()-1]; + last_n_tokens.pop_back(); + current_context_tokens.pop_back(); + n_past -=1; + embd_inp.insert(embd_inp.begin(), 1, tail); + } + else if(maxedpos==n_past) + { + n_past += 1; + } + //it is generally preferable to not have the embd_inp array empty unless doing so would cause an error + // if(embd_inp.size()==0 && current_context_tokens.size()>0) + // { + // embd_inp.push_back(current_context_tokens[current_context_tokens.size()-1]); + // current_context_tokens.pop_back(); + // n_past -= 1; + // } } } + } } else @@ -4272,7 +4271,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) bool blasmode = (embd_inp.size() >= 32 && kcpp_cpu_has_blas() && kcpp_data->n_batch>=32); - current_context_tokens.resize(n_past); + if(current_context_tokens.size()>n_past) + { + current_context_tokens.resize(n_past); + } remaining_tokens = kcpp_data->n_predict; int input_consumed = 0; @@ -4304,6 +4306,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } bool startedsampling = false; + bool firstdecodedone = false; //we CANNOT use logits if the first decode has not been executed yet. bool v3_use_scratch = true; //for normal inference always use scratch speculative_draft_result draft_results; //only use if drafting was used @@ -4362,8 +4365,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) std::string outstr = ""; // printf("\n[Debug: Dump Forwarded Input Tokens]\n"); // outstr += get_tok_vec_str(embd_inp); - outstr += "[Debug: n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]"; - //outstr += get_tok_vec_str(current_context_tokens); + // outstr += "\n"; + outstr += "[Debug: embd_inp="+std::to_string(embd_inp.size())+" n_past="+std::to_string(n_past)+" Context Size = " + std::to_string(current_context_tokens.size()) + "]"; + // outstr += "\n"; + // outstr += get_tok_vec_str(current_context_tokens); printf("%s\n\n", RemoveBell(outstr).c_str()); } @@ -4524,6 +4529,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) generation_finished = true; return output; } + firstdecodedone = true; } n_past += embd.size(); @@ -4591,6 +4597,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } while(logits_sampled0 && !abort_draft && !early_abort) { + if(!firstdecodedone && current_context_tokens.size()>0) + { + embd.clear(); + embd.push_back(current_context_tokens[current_context_tokens.size()-1]); + break; + } if(logits_sampled>0) { //this is not the first loop, so we need to increment some things @@ -5194,7 +5206,7 @@ size_t gpttype_save_state_kv(int slot) } } touch_slot(slot); - printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024)); + printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,savestates[slot].savestate_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024)); } if(draft_ctx)