mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
wip anti slop sampler
This commit is contained in:
parent
f78f8d3d45
commit
36e9bac98f
1 changed files with 40 additions and 6 deletions
|
@ -18,6 +18,8 @@
|
|||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
#include <locale>
|
||||
|
||||
//for easier compilation
|
||||
//concat source files into one file for compilation purposes
|
||||
|
@ -406,6 +408,19 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
|
|||
}
|
||||
}
|
||||
|
||||
// Returns a copy of `str` with the ASCII letters 'A'-'Z' folded to 'a'-'z'.
//
// The input is treated as a raw byte sequence (typically UTF-8). Only the 26
// ASCII uppercase letters are changed; every other byte is copied through
// untouched, so multi-byte UTF-8 sequences (whose bytes are all >= 0x80) are
// never altered. Non-ASCII letters (e.g. accented or Cyrillic capitals) are
// therefore NOT case-folded.
//
// The folding is deliberately independent of the process-wide locale: a
// locale-aware per-byte tolower could remap individual bytes of a UTF-8
// sequence under a single-byte locale and corrupt the text being matched.
static std::string toLowerCase(const std::string& str) {
    std::string result;
    result.reserve(str.size()); // output always has exactly the input length

    for (const char ch : str) {
        const bool is_ascii_upper = (ch >= 'A' && ch <= 'Z');
        result += is_ascii_upper ? static_cast<char>(ch - 'A' + 'a') : ch;
    }

    return result;
}
|
||||
|
||||
|
||||
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
|
||||
{
|
||||
if(amount_rewind<=0 || current_context_tokens.size()==0)
|
||||
|
@ -2544,7 +2559,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
int tokcount = toks.size();
|
||||
if(tokcount>0)
|
||||
{
|
||||
tokcount += 1; //add some extra buffer
|
||||
tokcount += 3; //add some extra buffer
|
||||
}
|
||||
delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
|
||||
banned_phrases.push_back(word);
|
||||
|
@ -3260,18 +3275,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
}
|
||||
|
||||
//anti slop detection
|
||||
std::string scanstr = "";
|
||||
for(int i=0;i<delayed_generated_tokens.size();++i)
|
||||
{
|
||||
scanstr += delayed_generated_tokens[i];
|
||||
}
|
||||
scanstr = toLowerCase(scanstr);
|
||||
for (const auto &matched : banned_phrases)
|
||||
{
|
||||
if (concat_output.find(matched) != std::string::npos)
|
||||
std::string matched_lower = toLowerCase(matched);
|
||||
if (scanstr.find(matched_lower) != std::string::npos)
|
||||
{
|
||||
std::vector<int> toks;
|
||||
TokenizeString(matched, toks, file_format, false);
|
||||
int tokcount = toks.size();
|
||||
//find the position in the string that contains all necessary tokens
|
||||
std::string checkstr = "";
|
||||
int rewind_amt = 0;
|
||||
for(int i=delayed_generated_tokens.size()-1;i>=0;--i)
|
||||
{
|
||||
checkstr = delayed_generated_tokens[i] + checkstr;
|
||||
++rewind_amt;
|
||||
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
delayed_generated_tokens.resize(delayed_generated_tokens.size()-rewind_amt);
|
||||
ContextRewind(embd,current_context_tokens,n_past,last_n_tokens,rewind_amt);
|
||||
|
||||
if(allow_regular_prints)
|
||||
{
|
||||
auto match_clean = matched;
|
||||
replace_all(match_clean, "\n", "\\n");
|
||||
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),tokcount);
|
||||
printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),rewind_amt);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue