diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 574acc49e..02888a1c7 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -18,6 +18,8 @@ #include #include #include +#include +#include //for easier compilation //concat source files into one file for compilation purposes @@ -406,6 +408,19 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_ } } +// Function to convert a UTF-8 encoded string to lowercase +static std::string toLowerCase(const std::string& str) { + std::string result; + std::locale loc; + + for (char ch : str) { + result += std::tolower(ch, loc); // Use locale-aware tolower + } + + return result; +} + + void ContextRewind(std::vector &embd, std::vector ¤t_context_tokens, int &n_past, std::vector &last_n_tokens, const int amount_rewind) { if(amount_rewind<=0 || current_context_tokens.size()==0) @@ -2544,7 +2559,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) int tokcount = toks.size(); if(tokcount>0) { - tokcount += 1; //add some extra buffer + tokcount += 3; //add some extra buffer } delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit); banned_phrases.push_back(word); @@ -3260,18 +3275,37 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } //anti slop detection + std::string scanstr = ""; + for(int i=0;i toks; - TokenizeString(matched, toks, file_format, false); - int tokcount = toks.size(); + //find the position in the string that contains all necessary tokens + std::string checkstr = ""; + int rewind_amt = 0; + for(int i=delayed_generated_tokens.size()-1;i>=0;--i) + { + checkstr = delayed_generated_tokens[i] + checkstr; + ++rewind_amt; + if (toLowerCase(checkstr).find(matched_lower) != std::string::npos) + { + break; + } + } + delayed_generated_tokens.resize(delayed_generated_tokens.size()-rewind_amt); + ContextRewind(embd,current_context_tokens,n_past,last_n_tokens,rewind_amt); + if(allow_regular_prints) { auto match_clean = matched; replace_all(match_clean, "\n", "\\n"); - printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),tokcount); + printf("\n(Banned Phrase Detected: %s - Rewinding %d tokens)\n", match_clean.c_str(),rewind_amt); } break; }