From 00a686fc7202e66e2bd56ac7f0be0b230a5a5590 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Tue, 10 Dec 2024 22:37:40 +0800
Subject: [PATCH] fixed fast forwarding context corruption after abort during
 prompt processing

---
 gpttype_adapter.cpp | 48 +++++++++++++++++++++------------------------
 1 file changed, 22 insertions(+), 26 deletions(-)

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 0e17874db..1c804a75b 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -120,7 +120,7 @@ static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
 static std::unordered_map<gpt_vocab::id,int> dry_max_token_repeat;
 static std::vector<TopPicksData> top_picks_history;
 static int remaining_tokens = 0;
-static int stopper_unused_tokens = 0;
+static bool early_abort = false;
 static std::mutex concat_output_mtx;
 static std::string concat_output = "";
 static std::string concat_output_reader_copy_poll = ""; //for streaming
@@ -2666,8 +2666,7 @@ bool gpttype_generate_abort()
     {
         printf("\nWarning: KCPP text generation not initialized!\n");
     }
-    stopper_unused_tokens = remaining_tokens;
-    remaining_tokens = 0;
+    early_abort = true;
     return true;
 }
 
@@ -2801,6 +2800,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     dry_sequence_breakers.clear();
     dry_max_token_repeat.clear();
     top_picks_history.clear();
+    early_abort = false;
     double time0 = 0, time1 = 0, time2 = 0;
     timer_start();
 
@@ -3280,7 +3280,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     current_context_tokens.resize(n_past);
 
     remaining_tokens = kcpp_data->n_predict;
-    stopper_unused_tokens = 0;
     int input_consumed = 0;
     std::mt19937 rng(kcpp_data->seed);
 
@@ -3368,7 +3367,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("%s\n\n", RemoveBell(outstr).c_str());
     }
 
-    while (remaining_tokens > 0)
+    while (remaining_tokens > 0 && !early_abort)
     {
         gpt_vocab::id id = 0;
         // predict
@@ -3492,7 +3491,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 
             n_past += embd.size();
             embd.clear();
-            if ((int)embd_inp.size() <= input_consumed)
+
+            if (!early_abort && (int)embd_inp.size() <= input_consumed) //if decoding was aborted, DO NOT perform any sampling
             {
                 // out of user input, sample next token
                 const float top_k = kcpp_data->top_k;
@@ -3740,46 +3740,43 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                     }
                 }
 
-                bool earlystopped = false;
-                if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
+                if(!early_abort)
                 {
-                    stopper_unused_tokens = remaining_tokens;
-                    if(allow_regular_prints)
+                    if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
                     {
-                        printf("\n(EOS token triggered! ID:%d)",id);
+                        if(allow_regular_prints)
+                        {
+                            printf("\n(EOS token triggered! ID:%d)",id);
+                        }
+                        early_abort = true;
+                        last_stop_reason = stop_reason::EOS_TOKEN_HIT;
                     }
-                    remaining_tokens = 0;
-                    last_stop_reason = stop_reason::EOS_TOKEN_HIT;
-                    earlystopped = true;
                 }
 
-                if(!earlystopped)
+                if(!early_abort)
                 {
                     for (const auto &matched : special_stop_sequence)
                     {
                         if(id==matched)
                         {
-                            stopper_unused_tokens = remaining_tokens;
                             if(allow_regular_prints)
                             {
ID:%d)",matched); } - remaining_tokens = 0; + early_abort = true; last_stop_reason = stop_reason::EOS_TOKEN_HIT; - earlystopped = true; break; } } } - if(!earlystopped) + if(!early_abort) { for (const auto &matched : stop_sequence) { if (concat_output.find(matched) != std::string::npos) { - stopper_unused_tokens = remaining_tokens; - remaining_tokens = 0; + early_abort = true; if(allow_regular_prints) { auto match_clean = matched; @@ -3787,7 +3784,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs) printf("\n(Stop sequence triggered: %s)", match_clean.c_str()); } last_stop_reason = stop_reason::CUSTOM_STOPPER; - earlystopped = true; break; } } @@ -3807,7 +3803,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) fflush(stdout); } - else + else if(!early_abort) //do not ingest prompt if aborted! { // some user input remains from prompt or interaction, forward it to processing while ((int)embd_inp.size() > input_consumed) @@ -3926,10 +3922,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) time2 = timer_check(); float pt1 = (time1*1000.0/(embd_inp.size()==0?1:embd_inp.size())); float ts1 = (1000.0/pt1); - int realnpredict = kcpp_data->n_predict-stopper_unused_tokens; - float pt2 = (time2*1000.0/(realnpredict==0?1:realnpredict)); + int realnpredict = kcpp_data->n_predict-remaining_tokens; + float pt2 = (time2*1000.0/(realnpredict<=0?1:realnpredict)); float ts2 = (1000.0/pt2); - float tokens_per_second = (realnpredict == 0 ? 0 : realnpredict / (time1 + time2)); + float tokens_per_second = (realnpredict <= 0 ? 0 : realnpredict / (time1 + time2)); printf("\n[%s] CtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",get_timestamp_str().c_str(),(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second); fflush(stdout); output.status = 1;