Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-16 11:59:42 +00:00)
fixed fast forwarding context corruption after abort during prompt processing
This commit is contained in:
parent a11bba5893
commit 00a686fc72

1 changed file with 22 additions and 26 deletions
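What the diff does: the old abort path zeroed `remaining_tokens` and stashed the unused budget in `stopper_unused_tokens`, so an abort that landed while the prompt was still being ingested could fall through into the prompt-processing branch and leave the cached context out of step with what was actually decoded; the next request would then fast-forward against a corrupted token history. The fix replaces that bookkeeping with a single `early_abort` flag, checked in the loop condition, before sampling, before each stop-token check, and before ingesting further prompt tokens (note the new comment "do not ingest prompt if aborted!").

Below is a minimal standalone sketch of that pattern. It is not koboldcpp's actual code: `GenState`, `request_abort`, and the stubbed sampler are illustrative names.

// Minimal sketch of the early-abort pattern this commit adopts.
// One flag gates the loop, the sampling branch and the prompt-ingestion
// branch, so an abort mid-prompt stops cleanly without marking
// un-ingested tokens as consumed.
#include <atomic>
#include <cstdio>
#include <vector>

struct GenState {
    std::vector<int> context;   // tokens actually ingested/decoded so far
    std::vector<int> prompt;    // pending prompt tokens
    size_t consumed = 0;        // prompt tokens ingested so far
    int remaining = 0;          // generation budget
    std::atomic<bool> early_abort{false};
};

// Called from another thread, e.g. an abort endpoint.
void request_abort(GenState &s) { s.early_abort = true; }

void generate(GenState &s) {
    while (s.remaining > 0 && !s.early_abort) {
        if (s.consumed < s.prompt.size()) {
            // Still ingesting the prompt: record a token as consumed only
            // after it has genuinely been processed.
            s.context.push_back(s.prompt[s.consumed]);
            s.consumed++;
        } else {
            // Prompt done: sample the next token (stubbed as a constant).
            s.context.push_back(42);
            s.remaining--;
        }
    }
    // context now reflects exactly what was processed, so the next request
    // can fast-forward against it safely even if we were aborted mid-prompt.
    printf("ingested %zu/%zu prompt tokens, context size %zu\n",
           s.consumed, s.prompt.size(), s.context.size());
}

int main() {
    GenState s;
    s.prompt = {1, 2, 3, 4, 5};
    s.remaining = 8;
    request_abort(s);   // simulate an abort landing before processing
    generate(s);        // exits immediately; context stays consistent
    return 0;
}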
@@ -120,7 +120,7 @@ static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
 static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
 static std::vector<TopPicksData> top_picks_history;
 static int remaining_tokens = 0;
-static int stopper_unused_tokens = 0;
+static bool early_abort = false;
 static std::mutex concat_output_mtx;
 static std::string concat_output = "";
 static std::string concat_output_reader_copy_poll = ""; //for streaming
@@ -2666,8 +2666,7 @@ bool gpttype_generate_abort()
     {
         printf("\nWarning: KCPP text generation not initialized!\n");
     }
-    stopper_unused_tokens = remaining_tokens;
-    remaining_tokens = 0;
+    early_abort = true;
     return true;
 }

@@ -2801,6 +2800,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     dry_sequence_breakers.clear();
     dry_max_token_repeat.clear();
     top_picks_history.clear();
+    early_abort = false;

     double time0 = 0, time1 = 0, time2 = 0;
     timer_start();
@@ -3280,7 +3280,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     current_context_tokens.resize(n_past);

     remaining_tokens = kcpp_data->n_predict;
-    stopper_unused_tokens = 0;
     int input_consumed = 0;
     std::mt19937 rng(kcpp_data->seed);

@@ -3368,7 +3367,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("%s\n\n", RemoveBell(outstr).c_str());
     }

-    while (remaining_tokens > 0)
+    while (remaining_tokens > 0 && !early_abort)
     {
         gpt_vocab::id id = 0;
         // predict
@@ -3492,7 +3491,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

         n_past += embd.size();
         embd.clear();
-        if ((int)embd_inp.size() <= input_consumed)
+
+        if (!early_abort && (int)embd_inp.size() <= input_consumed) //if decoding was aborted, DO NOT perform any sampling
         {
             // out of user input, sample next token
             const float top_k = kcpp_data->top_k;
@@ -3740,46 +3740,43 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 }
             }

-            bool earlystopped = false;
+            if(!early_abort)
+            {
             if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
             {
-                stopper_unused_tokens = remaining_tokens;
                 if(allow_regular_prints)
                 {
                     printf("\n(EOS token triggered! ID:%d)",id);
                 }
-                remaining_tokens = 0;
+                early_abort = true;
                 last_stop_reason = stop_reason::EOS_TOKEN_HIT;
-                earlystopped = true;
             }
+            }

-            if(!earlystopped)
+            if(!early_abort)
             {
                 for (const auto &matched : special_stop_sequence)
                 {
                     if(id==matched)
                     {
-                        stopper_unused_tokens = remaining_tokens;
                         if(allow_regular_prints)
                         {
                             printf("\n(Special Stop Token Triggered! ID:%d)",matched);
                         }
-                        remaining_tokens = 0;
+                        early_abort = true;
                         last_stop_reason = stop_reason::EOS_TOKEN_HIT;
-                        earlystopped = true;
                         break;
                     }
                 }
             }

-            if(!earlystopped)
+            if(!early_abort)
             {
                 for (const auto &matched : stop_sequence)
                 {
                     if (concat_output.find(matched) != std::string::npos)
                     {
-                        stopper_unused_tokens = remaining_tokens;
-                        remaining_tokens = 0;
+                        early_abort = true;
                         if(allow_regular_prints)
                         {
                             auto match_clean = matched;
@@ -3787,7 +3784,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                             printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
                         }
                         last_stop_reason = stop_reason::CUSTOM_STOPPER;
-                        earlystopped = true;
                         break;
                     }
                 }
@@ -3807,7 +3803,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

             fflush(stdout);
         }
-        else
+        else if(!early_abort) //do not ingest prompt if aborted!
         {
             // some user input remains from prompt or interaction, forward it to processing
             while ((int)embd_inp.size() > input_consumed)
@@ -3926,10 +3922,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     time2 = timer_check();
     float pt1 = (time1*1000.0/(embd_inp.size()==0?1:embd_inp.size()));
     float ts1 = (1000.0/pt1);
-    int realnpredict = kcpp_data->n_predict-stopper_unused_tokens;
-    float pt2 = (time2*1000.0/(realnpredict==0?1:realnpredict));
+    int realnpredict = kcpp_data->n_predict-remaining_tokens;
+    float pt2 = (time2*1000.0/(realnpredict<=0?1:realnpredict));
     float ts2 = (1000.0/pt2);
-    float tokens_per_second = (realnpredict == 0 ? 0 : realnpredict / (time1 + time2));
+    float tokens_per_second = (realnpredict <= 0 ? 0 : realnpredict / (time1 + time2));
     printf("\n[%s] CtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",get_timestamp_str().c_str(),(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
     fflush(stdout);
     output.status = 1;
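Why the last hunk swaps the `==0` guards for `<=0`: with the flag-based abort, `remaining_tokens` is no longer forced to zero, so `realnpredict = n_predict - remaining_tokens` now counts tokens actually generated and is 0 when an abort lands before any sampling. The guards keep the ms/token and tokens-per-second math finite in that case. A hypothetical walk-through (numbers invented, not from the repo):

#include <cstdio>

// Hypothetical scenario: an abort lands during prompt processing, so
// nothing was generated and remaining_tokens still equals n_predict.
int main() {
    int n_predict = 200;
    int remaining_tokens = 200;  // untouched, since abort no longer zeroes it
    double time2 = 0.8;          // seconds spent in the generate phase

    int realnpredict = n_predict - remaining_tokens;                      // 0
    double pt2 = time2 * 1000.0 / (realnpredict <= 0 ? 1 : realnpredict); // guard avoids /0
    double tps = (realnpredict <= 0 ? 0 : realnpredict / time2);          // 0 T/s, not inf

    printf("generated=%d, ms/T=%.1f, T/s=%.2f\n", realnpredict, pt2, tps);
    return 0;
}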