wip antislop

This commit is contained in:
Concedo 2024-10-07 20:19:22 +08:00
parent 740c5e01cb
commit 65f3c68399
3 changed files with 56 additions and 1 deletions

View file

@ -21,6 +21,7 @@ from datetime import datetime, timezone
sampler_order_max = 7
stop_token_max = 24
ban_token_max = 16
ban_phrase_max = 16
tensor_split_max = 16
logit_bias_max = 24
dry_seq_break_max = 24
@ -171,7 +172,8 @@ class generation_inputs(ctypes.Structure):
("dynatemp_exponent", ctypes.c_float),
("smoothing_factor", ctypes.c_float),
("logit_biases", logit_bias * logit_bias_max),
("banned_tokens", ctypes.c_char_p * ban_token_max)]
("banned_tokens", ctypes.c_char_p * ban_token_max),
("banned_phrases", ctypes.c_char_p * ban_phrase_max)]
class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
@ -910,6 +912,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
logit_biases = genparams.get('logit_bias', {})
render_special = genparams.get('render_special', False)
banned_tokens = genparams.get('banned_tokens', [])
banned_phrases = genparams.get('banned_phrases', [])
bypass_eos_token = genparams.get('bypass_eos', False)
inputs = generation_inputs()
@ -1028,6 +1031,12 @@ def generate(genparams, is_quiet=False, stream_flag=False):
else:
inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
for n in range(ban_phrase_max):
if not banned_phrases or n >= len(banned_phrases):
inputs.banned_phrases[n] = "".encode("UTF-8")
else:
inputs.banned_phrases[n] = banned_phrases[n].encode("UTF-8")
currentusergenkey = genkey
totalgens += 1
#early exit if aborted