diff --git a/expose.h b/expose.h
index aa23b57a2..04d5b42e5 100644
--- a/expose.h
+++ b/expose.h
@@ -3,6 +3,7 @@
 const int stop_token_max = 24;
 const int ban_token_max = 16;
+const int ban_phrase_max = 16;
 const int tensor_split_max = 16;
 const int logit_bias_max = 24;
 const int dry_seq_break_max = 24;
@@ -106,6 +107,7 @@ struct generation_inputs
     const float smoothing_factor = 0.0f;
     const logit_bias logit_biases[logit_bias_max] = {};
     const char * banned_tokens[ban_token_max] = {};
+    const char * banned_phrases[ban_phrase_max] = {};
 };
 struct generation_outputs
 {
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index ab612db1b..0f81dbd56 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -108,6 +108,7 @@
 static std::vector<std::string> stop_sequence;
 static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
+static std::vector<std::string> banned_phrases;
 static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
 static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
 static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
@@ -2530,6 +2531,30 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }
 
+    //antislop phrase banning
+    banned_phrases.clear();
+    delayed_generated_tokens_limit = 0;
+    for(int x=0;x<ban_phrase_max;++x)
+    {
+        std::string word = inputs.banned_phrases[x];
+        if(word!="")
+        {
+            std::vector<int> toks;
+            TokenizeString(word, toks, file_format, false);
+            int tokcount = toks.size();
+            if(tokcount>0)
+            {
+                tokcount += 2; //add some extra buffer
+            }
+            delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
+            banned_phrases.push_back(word);
+        }
+    }
+    if(debugmode==1 && banned_phrases.size()>0)
+    {
+        printf("\nBanned a total of %zu phrases, with max token count of %d.\n",banned_phrases.size(),delayed_generated_tokens_limit);
+    }
+
     logit_biases.clear();
     for(int x=0;x<logit_bias_max;++x)
     {
diff --git a/koboldcpp.py b/koboldcpp.py
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ ... @@
+    for n in range(ban_phrase_max):
+        if n >= len(banned_phrases):
+            inputs.banned_phrases[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_phrases[n] = banned_phrases[n].encode("UTF-8")
+
     currentusergenkey = genkey
     totalgens += 1
     #early exit if aborted
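
Note for reviewers, not part of the patch: the gpttype_adapter.cpp hunk only *sizes* the hold-back window. `delayed_generated_tokens_limit` becomes the token count of the longest banned phrase plus a two-token buffer, presumably because phrase boundaries rarely align with token boundaries. The rollback/re-sampling side is not in this diff, so the following is a minimal, self-contained C++ sketch of how such a window can be used; the toy one-word-per-token "tokenizer", the rewind policy, and every name in it are illustrative assumptions, not KoboldCpp code.

```cpp
// Standalone sketch of a delayed-output window for phrase banning.
// Only the window-sizing rule (max phrase token count + 2) comes from the
// patch above; everything else is a hypothetical illustration.
#include <algorithm>
#include <deque>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // Toy token stream: one space-prefixed word per token.
    const std::vector<std::string> stream =
        {"Her", " eyes", " sparkled", " with", " mischief", " and", " wonder", "."};
    const std::vector<std::string> banned_phrases = {" sparkled with mischief"};

    // Size the hold-back window like the patch does: longest banned phrase
    // in tokens, plus a small buffer for misaligned token boundaries.
    size_t delayed_limit = 0;
    for (const std::string &p : banned_phrases)
    {
        size_t tokcount = (size_t)std::count(p.begin(), p.end(), ' '); // toy token count
        delayed_limit = std::max(delayed_limit, tokcount + 2);
    }

    std::deque<std::string> delayed; // generated but not yet streamed to the client
    for (const std::string &tok : stream)
    {
        delayed.push_back(tok);

        // If the held-back text now contains a banned phrase, rewind past it.
        std::string held;
        for (const std::string &t : delayed) { held += t; }
        for (const std::string &p : banned_phrases)
        {
            size_t pos = held.find(p);
            if (pos == std::string::npos) { continue; }
            // Drop tokens back to the phrase start; a real sampler would now
            // re-sample from here with the offending continuation banned.
            while (!delayed.empty() && held.size() > pos)
            {
                held.resize(held.size() - delayed.back().size());
                delayed.pop_back();
            }
        }

        // Anything older than the window can no longer become part of a
        // banned phrase, so it is safe to flush.
        while (delayed.size() > delayed_limit)
        {
            std::cout << delayed.front();
            delayed.pop_front();
        }
    }
    for (const std::string &t : delayed) { std::cout << t; } // final flush
    std::cout << "\n"; // prints: Her eyes and wonder.
}
```

The sketch simply drops the offending tokens to stay short, where a real sampler would re-sample from the rewind point. The key property holds either way: a token older than the window cannot extend into a banned phrase, so it can be streamed immediately, keeping latency bounded regardless of how many phrases are banned.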