antislop sampler working

This commit is contained in:
Concedo 2024-10-09 16:33:04 +08:00
parent 36e9bac98f
commit 9b614d46bd
18 changed files with 54 additions and 8881 deletions

View file

@ -125,7 +125,7 @@ static std::vector<logit_bias> logit_biases;
static int delayed_generated_tokens_limit = 0;
std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling
static std::map<int,std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index
inline bool IsNanCheck(float f)
{
@ -2549,6 +2549,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
//antislop phrase banning
banned_phrases.clear();
delayed_generated_tokens_limit = 0;
antislop_banned_token_ids.clear();
for(int x=0;x<ban_phrase_max;++x)
{
std::string word = inputs.banned_phrases[x];
@ -3212,6 +3213,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
//handle temp bans from antislop
if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
std::vector<int>& bans = antislop_banned_token_ids[n_past];
print_tok_vec_str(bans);
for(int t=0;t<bans.size();++t)
{
logitsPtr[bans[t]]=lowestLogit;
}
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope, presence_penalty,
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
@ -3275,39 +3286,55 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
//anti slop detection
std::string scanstr = "";
for(int i=0;i<delayed_generated_tokens.size();++i)
if (banned_phrases.size() > 0)
{
scanstr += delayed_generated_tokens[i];
}
scanstr = toLowerCase(scanstr);
for (const auto &matched : banned_phrases)
{
std::string matched_lower = toLowerCase(matched);
if (scanstr.find(matched_lower) != std::string::npos)
std::string scanstr = "";
for (int i = 0; i < delayed_generated_tokens.size(); ++i)
{
//find the position in the string that contains all necessary tokens
std::string checkstr = "";
int rewind_amt = 0;
for(int i=delayed_generated_tokens.size()-1;i>=0;--i)
scanstr += delayed_generated_tokens[i];
}
scanstr = toLowerCase(scanstr);
for (const auto &matched : banned_phrases)
{
std::string matched_lower = toLowerCase(matched);
if (scanstr.find(matched_lower) != std::string::npos)
{
checkstr = delayed_generated_tokens[i] + checkstr;
++rewind_amt;
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
//find the position in the string that contains all necessary tokens
std::string checkstr = "";
int rewind_amt = 0;
for (int i = delayed_generated_tokens.size() - 1; i >= 0; --i)
{
checkstr = delayed_generated_tokens[i] + checkstr;
++rewind_amt;
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
{
break;
}
}
if (rewind_amt > 0 && (current_context_tokens.size() - rewind_amt) > 0)
{
int last_tok = current_context_tokens[current_context_tokens.size() - rewind_amt];
delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);
// Check if the key exists
int banindex = n_past+1;
if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
antislop_banned_token_ids[banindex] = std::vector<int>();
}
std::vector<int>& current_ids = antislop_banned_token_ids[banindex];
current_ids.push_back(last_tok);
if (allow_regular_prints && debugmode == 1)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Add ID %d to banlist at index %d, and rewinding %d tokens)\n", match_clean.c_str(), last_tok, banindex, rewind_amt);
}
break;
}
}
delayed_generated_tokens.resize(delayed_generated_tokens.size()-rewind_amt);
ContextRewind(embd,current_context_tokens,n_past,last_n_tokens,rewind_amt);
if(allow_regular_prints)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),rewind_amt);
}
break;
}
}