wip anti slop sampler

This commit is contained in:
Concedo 2024-10-09 13:34:47 +08:00
parent f78f8d3d45
commit 36e9bac98f

View file

@ -18,6 +18,8 @@
#include <map>
#include <cstdint>
#include <string>
#include <cctype>
#include <locale>
//for easier compilation
//concat source files into one file for compilation purposes
@ -406,6 +408,19 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
}
}
// Returns a copy of `str` with the ASCII letters 'A'-'Z' converted to 'a'-'z'.
//
// The conversion works byte by byte and deliberately touches ASCII only: every
// other byte, including the lead and continuation bytes of multi-byte UTF-8
// sequences, is copied through unchanged. This keeps UTF-8 input intact, but it
// also means non-ASCII letters are NOT case-folded.
//
// The result does not depend on the process-wide std::locale. A locale-aware
// per-byte tolower could rewrite bytes >= 0x80 under a single-byte locale and
// thereby corrupt UTF-8 text.
static std::string toLowerCase(const std::string& str) {
    std::string result;
    result.reserve(str.size()); // output has exactly the same length as the input
    for (const char ch : str) {
        const bool is_ascii_upper = (ch >= 'A' && ch <= 'Z');
        result += is_ascii_upper ? static_cast<char>(ch - 'A' + 'a') : ch;
    }
    return result;
}
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
{
if(amount_rewind<=0 || current_context_tokens.size()==0)
@ -2544,7 +2559,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
int tokcount = toks.size();
if(tokcount>0)
{
tokcount += 1; //add some extra buffer
tokcount += 3; //add some extra buffer
}
delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
banned_phrases.push_back(word);
@ -3260,18 +3275,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
//anti slop detection
std::string scanstr = "";
for(int i=0;i<delayed_generated_tokens.size();++i)
{
scanstr += delayed_generated_tokens[i];
}
scanstr = toLowerCase(scanstr);
for (const auto &matched : banned_phrases)
{
if (concat_output.find(matched) != std::string::npos)
std::string matched_lower = toLowerCase(matched);
if (scanstr.find(matched_lower) != std::string::npos)
{
std::vector<int> toks;
TokenizeString(matched, toks, file_format, false);
int tokcount = toks.size();
//find the position in the string that contains all necessary tokens
std::string checkstr = "";
int rewind_amt = 0;
for(int i=delayed_generated_tokens.size()-1;i>=0;--i)
{
checkstr = delayed_generated_tokens[i] + checkstr;
++rewind_amt;
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
{
break;
}
}
delayed_generated_tokens.resize(delayed_generated_tokens.size()-rewind_amt);
ContextRewind(embd,current_context_tokens,n_past,last_n_tokens,rewind_amt);
if(allow_regular_prints)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),tokcount);
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),rewind_amt);
}
break;
}