mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip antislop

This commit is contained in:
parent 740c5e01cb
commit 65f3c68399

3 changed files with 56 additions and 1 deletion
expose.h (2 changes)

@@ -3,6 +3,7 @@
 const int stop_token_max = 24;
 const int ban_token_max = 16;
+const int ban_phrase_max = 16;
 const int tensor_split_max = 16;
 const int logit_bias_max = 24;
 const int dry_seq_break_max = 24;

@@ -106,6 +107,7 @@ struct generation_inputs
     const float smoothing_factor = 0.0f;
     const logit_bias logit_biases[logit_bias_max] = {};
     const char * banned_tokens[ban_token_max] = {};
+    const char * banned_phrases[ban_phrase_max] = {};
 };
 struct generation_outputs
 {
gpttype_adapter.cpp (44 changes)

@@ -108,6 +108,7 @@ static std::vector<std::string> stop_sequence;
 static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
+static std::vector<std::string> banned_phrases;
 static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
 static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
 static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
@@ -2530,6 +2531,30 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }
+
+    //antislop phrase banning
+    banned_phrases.clear();
+    delayed_generated_tokens_limit = 0;
+    for(int x=0;x<ban_phrase_max;++x)
+    {
+        std::string word = inputs.banned_phrases[x];
+        if(word!="")
+        {
+            std::vector<int> toks;
+            TokenizeString(word, toks, file_format, false);
+            int tokcount = toks.size();
+            if(tokcount>0)
+            {
+                tokcount += 2; //add some extra buffer
+            }
+            delayed_generated_tokens_limit = (tokcount>delayed_generated_tokens_limit?tokcount:delayed_generated_tokens_limit);
+            banned_phrases.push_back(word);
+        }
+    }
+    if(debugmode==1 && banned_phrases.size()>0)
+    {
+        printf("\nBanned a total of %zu phrases, with max token count of %d.\n",banned_phrases.size(),delayed_generated_tokens_limit);
+    }

     logit_biases.clear();
     for(int x=0;x<logit_bias_max;++x)
     {
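The block above only prepares the ban list and sizes a delay window: each phrase is tokenized, two tokens of slack are added, and the largest count becomes delayed_generated_tokens_limit, i.e. how far the output stream must lag behind generation so a phrase can still be caught before its tokens are emitted. A minimal Python sketch of that sizing logic follows; `tokenize` is a hypothetical stand-in for the C++ TokenizeString call, not a real koboldcpp Python API.

# Sketch of the delay-window sizing done above; `tokenize` is hypothetical.
from typing import Callable, List

def compute_delay_window(banned_phrases: List[str],
                         tokenize: Callable[[str], List[int]]) -> int:
    limit = 0
    for word in banned_phrases:
        if word == "":
            continue  # empty slots of the fixed-size array are skipped
        tokcount = len(tokenize(word))
        if tokcount > 0:
            tokcount += 2  # extra buffer, mirroring the C++ loop
        limit = max(limit, tokcount)
    return limit

# e.g. with a toy whitespace tokenizer: 4 tokens + 2 slack = 6
assert compute_delay_window(["a tapestry of slop", ""], lambda s: s.split()) == 6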
@@ -3234,6 +3259,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("]\n");
     }
+
+    //anti slop detection
+    // for (const auto &matched : stop_sequence)
+    // {
+    //     if (concat_output.find(matched) != std::string::npos)
+    //     {
+    //         stopper_unused_tokens = remaining_tokens;
+    //         remaining_tokens = 0;
+    //         if(allow_regular_prints)
+    //         {
+    //             auto match_clean = matched;
+    //             replace_all(match_clean, "\n", "\\n");
+    //             printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
+    //         }
+    //         last_stop_reason = stop_reason::CUSTOM_STOPPER;
+    //         earlystopped = true;
+    //         break;
+    //     }
+    // }

     bool earlystopped = false;
     if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
     {
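The detection half is still commented out (hence "wip"): the draft above reuses the stop-sequence scan, which would end generation outright rather than merely suppress the phrase. The delay window sized earlier hints at the likely intended mechanism: hold back the newest tokens from the stream, scan the held-back text for banned phrases, and rewind when one appears. Below is a rough, hypothetical Python sketch of that shape; none of these names exist in the repo, and `rewind` is assumed to roll the KV cache back and bias sampling away from the offending continuation.

# Hypothetical sketch of delayed streaming with phrase rollback; an
# illustration of the antislop idea, not code from this commit.
from collections import deque

def stream_with_phrase_ban(next_token, detokenize, rewind,
                           banned_phrases, delay_limit, max_tokens):
    held = deque()      # generated but not yet released to the client
    released = []
    for _ in range(max_tokens):
        held.append(next_token())
        text = detokenize(list(held))
        if any(p in text for p in banned_phrases):
            rewind(len(held))   # assumed: roll back KV cache, resample
            held.clear()        # a real implementation must also avoid
            continue            # re-generating the same continuation forever
        if len(held) > delay_limit:
            released.append(held.popleft())  # outside the window: safe to emit
    released.extend(held)  # flush the tail once generation finishes
    return released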
koboldcpp.py (11 changes)

@@ -21,6 +21,7 @@ from datetime import datetime, timezone
 sampler_order_max = 7
 stop_token_max = 24
 ban_token_max = 16
+ban_phrase_max = 16
 tensor_split_max = 16
 logit_bias_max = 24
 dry_seq_break_max = 24
@@ -171,7 +172,8 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
                 ("logit_biases", logit_bias * logit_bias_max),
-                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
+                ("banned_tokens", ctypes.c_char_p * ban_token_max),
+                ("banned_phrases", ctypes.c_char_p * ban_phrase_max)]

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
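Note how the closing `]` migrates from banned_tokens to the new last field: a ctypes.Structure must mirror the C struct in expose.h field for field, in the same order and with the same array sizes, or every field after the mismatch is read at the wrong offset. A standalone illustration of the pairing (a cut-down hypothetical struct, not repo code):

import ctypes

ban_phrase_max = 16  # must match `const int ban_phrase_max = 16;` in expose.h

class phrase_inputs(ctypes.Structure):   # hypothetical reduced struct
    _fields_ = [("banned_phrases", ctypes.c_char_p * ban_phrase_max)]

inp = phrase_inputs()
inp.banned_phrases[0] = "testing".encode("UTF-8")
assert inp.banned_phrases[0] == b"testing"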
@@ -910,6 +912,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     logit_biases = genparams.get('logit_bias', {})
     render_special = genparams.get('render_special', False)
     banned_tokens = genparams.get('banned_tokens', [])
+    banned_phrases = genparams.get('banned_phrases', [])
     bypass_eos_token = genparams.get('bypass_eos', False)

     inputs = generation_inputs()
@@ -1028,6 +1031,12 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         else:
             inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")

+    for n in range(ban_phrase_max):
+        if not banned_phrases or n >= len(banned_phrases):
+            inputs.banned_phrases[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_phrases[n] = banned_phrases[n].encode("UTF-8")
+
     currentusergenkey = genkey
     totalgens += 1
     #early exit if aborted
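End to end, a client would then pass the new field next to the existing ones in the JSON body of a generate request. A hedged usage sketch, assuming the HTTP layer forwards the request body as genparams (the key name banned_phrases comes from the genparams.get call above; port and response shape follow the usual KoboldCpp defaults):

import json, urllib.request

payload = {
    "prompt": "Once upon a time",
    "max_length": 80,
    "banned_tokens": [],                                # existing token bans
    "banned_phrases": ["tapestry of", "shivers down"],  # new antislop field
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate",  # default local KoboldCpp endpoint
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["results"][0]["text"])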