mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip anti slop sampler
This commit is contained in:
parent
f78f8d3d45
commit
36e9bac98f
1 changed files with 40 additions and 6 deletions
|
@ -18,6 +18,8 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <cctype>
|
||||||
|
#include <locale>
|
||||||
|
|
||||||
//for easier compilation
|
//for easier compilation
|
||||||
//concat source files into one file for compilation purposes
|
//concat source files into one file for compilation purposes
|
||||||
|
@ -406,6 +408,19 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Lowercase the ASCII letters of a (possibly UTF-8 encoded) string.
// Only bytes in the ASCII range 'A'-'Z' are changed; every byte >= 0x80
// (i.e. any part of a multibyte UTF-8 sequence) is passed through
// byte-identical, so UTF-8 text is never corrupted. Non-ASCII letters are
// intentionally NOT lowercased — full Unicode case folding would require a
// Unicode library, and ASCII folding is sufficient for phrase matching here.
// Note: the previous locale-based version could mangle UTF-8 continuation
// bytes under a non-"C" global locale; this version is locale-independent.
static std::string toLowerCase(const std::string& str) {
    std::string result;
    result.reserve(str.size()); // output is always the same length as input
    for (char ch : str) {
        // Cast to unsigned char first: passing a negative char to the
        // <cctype> tolower is undefined behavior.
        const unsigned char uc = static_cast<unsigned char>(ch);
        result += (uc < 0x80) ? static_cast<char>(std::tolower(uc)) : ch;
    }
    return result;
}
|
||||||
|
|
||||||
|
|
||||||
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
|
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
|
||||||
{
|
{
|
||||||
if(amount_rewind<=0 || current_context_tokens.size()==0)
|
if(amount_rewind<=0 || current_context_tokens.size()==0)
|
||||||
|
@ -2544,7 +2559,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||||
int tokcount = toks.size();
|
int tokcount = toks.size();
|
||||||
if(tokcount>0)
|
if(tokcount>0)
|
||||||
{
|
{
|
||||||
tokcount += 1; //add some extra buffer
|
tokcount += 3; //add some extra buffer
|
||||||
}
|
}
|
||||||
delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
|
delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
|
||||||
banned_phrases.push_back(word);
|
banned_phrases.push_back(word);
|
||||||
|
@ -3260,18 +3275,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||||
}
|
}
|
||||||
|
|
||||||
//anti slop detection
|
//anti slop detection
|
||||||
|
std::string scanstr = "";
|
||||||
|
for(int i=0;i<delayed_generated_tokens.size();++i)
|
||||||
|
{
|
||||||
|
scanstr += delayed_generated_tokens[i];
|
||||||
|
}
|
||||||
|
scanstr = toLowerCase(scanstr);
|
||||||
for (const auto &matched : banned_phrases)
|
for (const auto &matched : banned_phrases)
|
||||||
{
|
{
|
||||||
if (concat_output.find(matched) != std::string::npos)
|
std::string matched_lower = toLowerCase(matched);
|
||||||
|
if (scanstr.find(matched_lower) != std::string::npos)
|
||||||
{
|
{
|
||||||
std::vector<int> toks;
|
//find the position in the string that contains all necessary tokens
|
||||||
TokenizeString(matched, toks, file_format, false);
|
std::string checkstr = "";
|
||||||
int tokcount = toks.size();
|
int rewind_amt = 0;
|
||||||
|
for(int i=delayed_generated_tokens.size()-1;i>=0;--i)
|
||||||
|
{
|
||||||
|
checkstr = delayed_generated_tokens[i] + checkstr;
|
||||||
|
++rewind_amt;
|
||||||
|
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delayed_generated_tokens.resize(delayed_generated_tokens.size()-rewind_amt);
|
||||||
|
ContextRewind(embd,current_context_tokens,n_past,last_n_tokens,rewind_amt);
|
||||||
|
|
||||||
if(allow_regular_prints)
|
if(allow_regular_prints)
|
||||||
{
|
{
|
||||||
auto match_clean = matched;
|
auto match_clean = matched;
|
||||||
replace_all(match_clean, "\n", "\\n");
|
replace_all(match_clean, "\n", "\\n");
|
||||||
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),tokcount);
|
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),rewind_amt);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue