wip anti slop sampler

This commit is contained in:
Concedo 2024-10-09 13:34:47 +08:00
parent f78f8d3d45
commit 36e9bac98f

View file

@ -18,6 +18,8 @@
#include <map> #include <map>
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include <cctype>
#include <locale>
//for easier compilation //for easier compilation
//concat source files into one file for compilation purposes //concat source files into one file for compilation purposes
@ -406,6 +408,19 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
} }
} }
// Returns a copy of `str` with the ASCII letters 'A'-'Z' folded to 'a'-'z'.
// The input is treated as a raw byte sequence: every byte outside the ASCII
// uppercase range is copied through unchanged, so multi-byte UTF-8 sequences
// are never altered (non-ASCII letters are therefore NOT case-folded).
// The result does not depend on the global C or C++ locale.
static std::string toLowerCase(const std::string& str) {
    std::string result;
    result.reserve(str.size()); // output has exactly the same byte length as the input
    for (const char ch : str) {
        // Explicit range check instead of std::tolower: a locale-aware tolower can
        // remap bytes >= 0x80 under a single-byte locale, which would corrupt UTF-8.
        const bool is_ascii_upper = (ch >= 'A' && ch <= 'Z');
        result += is_ascii_upper ? static_cast<char>(ch - 'A' + 'a') : ch;
    }
    return result;
}
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind) void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
{ {
if(amount_rewind<=0 || current_context_tokens.size()==0) if(amount_rewind<=0 || current_context_tokens.size()==0)
@ -2544,7 +2559,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
int tokcount = toks.size(); int tokcount = toks.size();
if(tokcount>0) if(tokcount>0)
{ {
tokcount += 1; //add some extra buffer tokcount += 3; //add some extra buffer
} }
delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit); delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
banned_phrases.push_back(word); banned_phrases.push_back(word);
@ -3260,18 +3275,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
} }
//anti slop detection //anti slop detection
std::string scanstr = "";
for(int i=0;i<delayed_generated_tokens.size();++i)
{
scanstr += delayed_generated_tokens[i];
}
scanstr = toLowerCase(scanstr);
for (const auto &matched : banned_phrases) for (const auto &matched : banned_phrases)
{ {
if (concat_output.find(matched) != std::string::npos) std::string matched_lower = toLowerCase(matched);
if (scanstr.find(matched_lower) != std::string::npos)
{ {
std::vector<int> toks; //find the position in the string that contains all necessary tokens
TokenizeString(matched, toks, file_format, false); std::string checkstr = "";
int tokcount = toks.size(); int rewind_amt = 0;
for(int i=delayed_generated_tokens.size()-1;i>=0;--i)
{
checkstr = delayed_generated_tokens[i] + checkstr;
++rewind_amt;
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
{
break;
}
}
delayed_generated_tokens.resize(delayed_generated_tokens.size()-rewind_amt);
ContextRewind(embd,current_context_tokens,n_past,last_n_tokens,rewind_amt);
if(allow_regular_prints) if(allow_regular_prints)
{ {
auto match_clean = matched; auto match_clean = matched;
replace_all(match_clean, "\n", "\\n"); replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),tokcount); printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),rewind_amt);
} }
break; break;
} }