Fixed fast-forwarding context corruption after an abort during prompt processing

This commit is contained in:
Concedo 2024-12-10 22:37:40 +08:00
parent a11bba5893
commit 00a686fc72

View file

@ -120,7 +120,7 @@ static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
// File-scope generation state shared by gpttype_generate() and gpttype_generate_abort().

// Cleared in gpttype_generate() before a new generation begins.
static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
// Cleared in gpttype_generate() before a new generation begins.
static std::vector<TopPicksData> top_picks_history;
// Token budget for the current generation: set to kcpp_data->n_predict before the
// generation loop, which keeps running only while this is > 0. Also subtracted from
// n_predict to report how many tokens were actually produced.
static int remaining_tokens = 0;
// Copy of remaining_tokens taken at the moment a stop condition (abort, EOS, special
// stop token, stop sequence) zeroes remaining_tokens; subtracted from n_predict when
// computing the generated-token count for the timing report.
static int stopper_unused_tokens = 0;
// Stop flag for the generation loop: set to true by gpttype_generate_abort() and by the
// EOS / special-stop-token / stop-sequence checks, reset to false in gpttype_generate()
// before generation begins. While set, sampling and further prompt ingestion are skipped.
static bool early_abort = false;
// NOTE(review): presumably guards concat_output against concurrent readers — confirm at the use sites.
static std::mutex concat_output_mtx;
// Accumulated output text; searched for each configured stop sequence after sampling.
static std::string concat_output = "";
static std::string concat_output_reader_copy_poll = ""; //for streaming
@ -2666,8 +2666,7 @@ bool gpttype_generate_abort()
{
printf("\nWarning: KCPP text generation not initialized!\n");
}
stopper_unused_tokens = remaining_tokens;
remaining_tokens = 0;
early_abort = true;
return true;
}
@ -2801,6 +2800,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
dry_sequence_breakers.clear();
dry_max_token_repeat.clear();
top_picks_history.clear();
early_abort = false;
double time0 = 0, time1 = 0, time2 = 0;
timer_start();
@ -3280,7 +3280,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
current_context_tokens.resize(n_past);
remaining_tokens = kcpp_data->n_predict;
stopper_unused_tokens = 0;
int input_consumed = 0;
std::mt19937 rng(kcpp_data->seed);
@ -3368,7 +3367,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("%s\n\n", RemoveBell(outstr).c_str());
}
while (remaining_tokens > 0)
while (remaining_tokens > 0 && !early_abort)
{
gpt_vocab::id id = 0;
// predict
@ -3492,7 +3491,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
n_past += embd.size();
embd.clear();
if ((int)embd_inp.size() <= input_consumed)
if (!early_abort && (int)embd_inp.size() <= input_consumed) //if decoding was aborted, DO NOT perform any sampling
{
// out of user input, sample next token
const float top_k = kcpp_data->top_k;
@ -3740,46 +3740,43 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
bool earlystopped = false;
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
if(!early_abort)
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
{
printf("\n(EOS token triggered! ID:%d)",id);
if(allow_regular_prints)
{
printf("\n(EOS token triggered! ID:%d)",id);
}
early_abort = true;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
}
remaining_tokens = 0;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
earlystopped = true;
}
if(!earlystopped)
if(!early_abort)
{
for (const auto &matched : special_stop_sequence)
{
if(id==matched)
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
{
printf("\n(Special Stop Token Triggered! ID:%d)",matched);
}
remaining_tokens = 0;
early_abort = true;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
earlystopped = true;
break;
}
}
}
if(!earlystopped)
if(!early_abort)
{
for (const auto &matched : stop_sequence)
{
if (concat_output.find(matched) != std::string::npos)
{
stopper_unused_tokens = remaining_tokens;
remaining_tokens = 0;
early_abort = true;
if(allow_regular_prints)
{
auto match_clean = matched;
@ -3787,7 +3784,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
}
last_stop_reason = stop_reason::CUSTOM_STOPPER;
earlystopped = true;
break;
}
}
@ -3807,7 +3803,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
fflush(stdout);
}
else
else if(!early_abort) //do not ingest prompt if aborted!
{
// some user input remains from prompt or interaction, forward it to processing
while ((int)embd_inp.size() > input_consumed)
@ -3926,10 +3922,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
time2 = timer_check();
float pt1 = (time1*1000.0/(embd_inp.size()==0?1:embd_inp.size()));
float ts1 = (1000.0/pt1);
int realnpredict = kcpp_data->n_predict-stopper_unused_tokens;
float pt2 = (time2*1000.0/(realnpredict==0?1:realnpredict));
int realnpredict = kcpp_data->n_predict-remaining_tokens;
float pt2 = (time2*1000.0/(realnpredict<=0?1:realnpredict));
float ts2 = (1000.0/pt2);
float tokens_per_second = (realnpredict == 0 ? 0 : realnpredict / (time1 + time2));
float tokens_per_second = (realnpredict <= 0 ? 0 : realnpredict / (time1 + time2));
printf("\n[%s] CtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",get_timestamp_str().c_str(),(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
fflush(stdout);
output.status = 1;