unify antislop and token bans

Concedo 2024-10-10 18:21:07 +08:00
parent a6bf568fda
commit fe5479f286
4 changed files with 37 additions and 108 deletions


@@ -2511,26 +2511,48 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }
-    //handle custom token bans
+    //handle custom token bans and antislop phrase banning
+    banned_phrases.clear();
+    delayed_generated_tokens_limit = 0;
+    antislop_banned_token_ids.clear();
     banned_tokens.clear();
     for(int x=0;x<ban_token_max;++x)
     {
         std::string word = inputs.banned_tokens[x];
         word = toLowerCase(word);
         if(word!="")
         {
-            banned_tokens.push_back(word);
+            std::vector<int> toks;
+            TokenizeString(word, toks, file_format, false);
+            int tokcount = toks.size();
+            if(tokcount==0)
+            {
+                continue;
+            }
+            if(tokcount==1 && word.length()<2) //only use banned tokens for single characters
+            {
+                banned_tokens.push_back(word);
+            }
+            else
+            {
+                tokcount += 3; //add some extra buffer
+                delayed_generated_tokens_limit = (tokcount > delayed_generated_tokens_limit ? tokcount : delayed_generated_tokens_limit);
+                banned_phrases.push_back(word);
+            }
         }
     }
     banned_token_ids.clear();
     if(banned_tokens.size()>0)
     {
         if(debugmode==1)
         {
-            printf("\nBanning %zu token sequences...",banned_tokens.size());
+            printf("\nBanning %zu single character sequences...",banned_tokens.size());
         }
         for(int v=0;v<n_vocab;++v)
         {
             std::string word = FileFormatTokenizeID(v,file_format, true);
             word = toLowerCase(word);
             for(int i=0;i<banned_tokens.size();++i)
             {
                 if (word.find(banned_tokens[i]) != std::string::npos)
@@ -2542,30 +2564,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
         if(debugmode==1)
         {
-            printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
+            printf("\nBanned a total of %zu individual tokens.\n",banned_token_ids.size());
         }
     }
-    //antislop phrase banning
-    banned_phrases.clear();
-    delayed_generated_tokens_limit = 0;
-    antislop_banned_token_ids.clear();
-    for(int x=0;x<ban_phrase_max;++x)
-    {
-        std::string word = inputs.banned_phrases[x];
-        if(word!="")
-        {
-            std::vector<int> toks;
-            TokenizeString(word, toks, file_format, false);
-            int tokcount = toks.size();
-            if(tokcount>0)
-            {
-                tokcount += 3; //add some extra buffer
-            }
-            delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
-            banned_phrases.push_back(word);
-        }
-    }
     if(debugmode==1 && banned_phrases.size()>0)
     {
         printf("\nBanned a total of %zu phrases, with max token count of %d.\n",banned_phrases.size(),delayed_generated_tokens_limit);