Add custom_token_bans (#1153)

This commit is contained in:
Maya 2024-10-10 18:45:07 +03:00 committed by GitHub
parent a3b104a422
commit 3dab63887f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -912,6 +912,11 @@ def generate(genparams, is_quiet=False, stream_flag=False):
banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
banned_tokens = genparams.get('banned_tokens', banned_strings)
bypass_eos_token = genparams.get('bypass_eos', False)
custom_token_bans = genparams.get('custom_token_bans', '')
for tok in custom_token_bans.split(','):
if tok.isdigit():
logit_biases[tok] = -999
inputs = generation_inputs()
inputs.prompt = prompt.encode("UTF-8")