diff --git a/koboldcpp.py b/koboldcpp.py index 9c2eb574c..fa515aacf 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -912,6 +912,11 @@ def generate(genparams, is_quiet=False, stream_flag=False): banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name banned_tokens = genparams.get('banned_tokens', banned_strings) bypass_eos_token = genparams.get('bypass_eos', False) + custom_token_bans = genparams.get('custom_token_bans', '') + + for tok in custom_token_bans.split(','): + if tok.isdigit(): + logit_biases[tok] = -999 inputs = generation_inputs() inputs.prompt = prompt.encode("UTF-8")