wip antislop

This commit is contained in:
Concedo 2024-10-07 20:19:22 +08:00
parent 740c5e01cb
commit 65f3c68399
3 changed files with 56 additions and 1 deletions

View file

@ -108,6 +108,7 @@ static std::vector<std::string> stop_sequence;
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
static std::vector<std::string> banned_tokens;
static std::vector<int> banned_token_ids;
static std::vector<std::string> banned_phrases;
static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
@ -2530,6 +2531,30 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
//antislop phrase banning
banned_phrases.clear();
delayed_generated_tokens_limit = 0;
for(int x=0;x<ban_phrase_max;++x)
{
std::string word = inputs.banned_phrases[x];
if(word!="")
{
std::vector<int> toks;
TokenizeString(word, toks, file_format, false);
int tokcount = toks.size();
if(tokcount>0)
{
tokcount += 2; //add some extra buffer
}
delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
banned_phrases.push_back(word);
}
}
if(debugmode==1 && banned_phrases.size()>0)
{
printf("\nBanned a total of %zu phrases, with max token count of %d.\n",banned_phrases.size(),delayed_generated_tokens_limit);
}
logit_biases.clear();
for(int x=0;x<logit_bias_max;++x)
{
@ -3234,6 +3259,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("]\n");
}
//anti slop detection
// for (const auto &matched : stop_sequence)
// {
// if (concat_output.find(matched) != std::string::npos)
// {
// stopper_unused_tokens = remaining_tokens;
// remaining_tokens = 0;
// if(allow_regular_prints)
// {
// auto match_clean = matched;
// replace_all(match_clean, "\n", "\\n");
// printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
// }
// last_stop_reason = stop_reason::CUSTOM_STOPPER;
// earlystopped = true;
// break;
// }
// }
bool earlystopped = false;
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
{