Dynamic sizes for sequences (#1157)

* Dynamic sizes for sequences

* cleanup PR - move all dynamic fields to the end of the payload, ensure correct null handling to match existing behavior, add an anti-abuse limit of max 512 for dynamic fields

* adjust anti-abuse limits
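To make the approach concrete, here is a minimal sketch of the pattern, with invented names (demo_inputs, fill_stop_sequence and demo_stop_token_max are illustrative only, not part of the actual change): fixed-size members stay at the front of the ctypes payload, each dynamic sequence becomes a trailing length-plus-pointer pair, an explicit None collapses to an empty list as before, and the list is truncated to its cap before the backing array is allocated.

    import ctypes

    demo_stop_token_max = 512  # illustrative anti-abuse cap, mirroring the PR's limits

    class demo_inputs(ctypes.Structure):
        # Fixed-size members first...
        _fields_ = [("temperature", ctypes.c_float),
                    # ...and each dynamic sequence as a trailing (length, pointer) pair.
                    ("stop_sequence_len", ctypes.c_int),
                    ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

    def fill_stop_sequence(inputs, stop_sequence):
        if stop_sequence is None:  # an explicitly-passed None behaves like an empty list
            stop_sequence = []
        stop_sequence = stop_sequence[:demo_stop_token_max]  # anti-abuse truncation
        inputs.stop_sequence_len = len(stop_sequence)
        # Allocate exactly as many slots as needed; ctypes converts the array
        # to the POINTER(c_char_p) field on assignment.
        inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)()
        for n, sequence in enumerate(stop_sequence):
            # Null entries become empty strings rather than NULL pointers.
            inputs.stop_sequence[n] = sequence.encode("UTF-8") if sequence else b""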

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
Authored by Maya on 2024-10-16 18:55:11 +03:00, committed by GitHub
parent 7f76425450
commit 8bb220329c
3 changed files with 74 additions and 56 deletions


@@ -19,15 +19,17 @@ from datetime import datetime, timezone
 # constants
 sampler_order_max = 7
-stop_token_max = 32
-ban_token_max = 64
 tensor_split_max = 16
-logit_bias_max = 32
-dry_seq_break_max = 24
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0

+# abuse prevention
+stop_token_max = 512
+ban_token_max = 1024
+logit_bias_max = 1024
+dry_seq_break_max = 256
 # global vars
 handle = None
 friendlymodelname = "inactive"
@@ -150,11 +152,6 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
-                ("dry_multiplier", ctypes.c_float),
-                ("dry_base", ctypes.c_float),
-                ("dry_allowed_length", ctypes.c_int),
-                ("dry_penalty_last_n", ctypes.c_int),
-                ("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
                 ("xtc_threshold", ctypes.c_float),
                 ("xtc_probability", ctypes.c_float),
                 ("sampler_order", ctypes.c_int * sampler_order_max),
@@ -162,7 +159,6 @@ class generation_inputs(ctypes.Structure):
                 ("allow_eos_token", ctypes.c_bool),
                 ("bypass_eos_token", ctypes.c_bool),
                 ("render_special", ctypes.c_bool),
-                ("stop_sequence", ctypes.c_char_p * stop_token_max),
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
                 ("grammar_retain_state", ctypes.c_bool),
@@ -170,8 +166,18 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_range", ctypes.c_float),
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
-                ("logit_biases", logit_bias * logit_bias_max),
-                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
+                ("dry_multiplier", ctypes.c_float),
+                ("dry_base", ctypes.c_float),
+                ("dry_allowed_length", ctypes.c_int),
+                ("dry_penalty_last_n", ctypes.c_int),
+                ("dry_sequence_breakers_len", ctypes.c_int),
+                ("dry_sequence_breakers", ctypes.POINTER(ctypes.c_char_p)),
+                ("stop_sequence_len", ctypes.c_int),
+                ("stop_sequence", ctypes.POINTER(ctypes.c_char_p)),
+                ("logit_biases_len", ctypes.c_int),
+                ("logit_biases", ctypes.POINTER(logit_bias)),
+                ("banned_tokens_len", ctypes.c_int),
+                ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -982,11 +988,16 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         except ValueError as e:
             print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. Could not parse '{dry_sequence_breakers}': " + str(e))
             dry_sequence_breakers = []
-    for n in range(dry_seq_break_max):
-        if dry_multiplier > 0 and n < len(dry_sequence_breakers):
-            inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8")
-        else:
-            inputs.dry_sequence_breakers[n] = "".encode("UTF-8")
+    if dry_multiplier <= 0 or dry_sequence_breakers is None: # prevent explicitly set to None, retain old behavior
+        dry_sequence_breakers = []
+    dry_sequence_breakers = dry_sequence_breakers[:dry_seq_break_max]
+    inputs.dry_sequence_breakers_len = len(dry_sequence_breakers)
+    inputs.dry_sequence_breakers = (ctypes.c_char_p * inputs.dry_sequence_breakers_len)()
+    for n, breaker in enumerate(dry_sequence_breakers):
+        inputs.dry_sequence_breakers[n] = breaker.encode("UTF-8")

     if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
         try:
@@ -1000,13 +1011,18 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         except TypeError as e:
             print("ERROR: sampler_order must be a list of integers: " + str(e))
     inputs.seed = seed
-    for n in range(stop_token_max):
-        if not stop_sequence or n >= len(stop_sequence):
-            inputs.stop_sequence[n] = "".encode("UTF-8")
-        elif stop_sequence[n]==None:
-            inputs.stop_sequence[n] = "".encode("UTF-8")
+    if stop_sequence is None:
+        stop_sequence = []
+    stop_sequence = stop_sequence[:stop_token_max]
+    inputs.stop_sequence_len = len(stop_sequence)
+    inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)()
+    for n, sequence in enumerate(stop_sequence):
+        if sequence:
+            inputs.stop_sequence[n] = sequence.encode("UTF-8")
         else:
-            inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
+            inputs.stop_sequence[n] = "".encode("UTF-8")

     bias_list = []
     try:
@@ -1015,25 +1031,27 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     except Exception as ex:
         print(f"Logit bias dictionary is invalid: {ex}")

-    for n in range(logit_bias_max):
-        if n >= len(bias_list):
+    bias_list = bias_list[:logit_bias_max]
+    inputs.logit_biases_len = len(bias_list)
+    inputs.logit_biases = (logit_bias * inputs.logit_biases_len)()
+    for n, lb in enumerate(bias_list):
+        try:
+            t_id = int(lb['key'])
+            bias = float(lb['value'])
+            t_id = -1 if t_id < 0 else t_id
+            bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
+            inputs.logit_biases[n] = logit_bias(t_id, bias)
+        except Exception as ex:
             inputs.logit_biases[n] = logit_bias(-1, 0.0)
-        else:
-            try:
-                t_id = int(bias_list[n]['key'])
-                bias = float(bias_list[n]['value'])
-                t_id = -1 if t_id < 0 else t_id
-                bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
-                inputs.logit_biases[n] = logit_bias(t_id, bias)
-            except Exception as ex:
-                inputs.logit_biases[n] = logit_bias(-1, 0.0)
-                print(f"Skipped unparsable logit bias:{ex}")
+            print(f"Skipped unparsable logit bias:{ex}")

-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+    if banned_tokens is None:
+        banned_tokens = []
+    banned_tokens = banned_tokens[:ban_token_max]
+    inputs.banned_tokens_len = len(banned_tokens)
+    inputs.banned_tokens = (ctypes.c_char_p * inputs.banned_tokens_len)()
+    for n, tok in enumerate(banned_tokens):
+        inputs.banned_tokens[n] = tok.encode("UTF-8")

     currentusergenkey = genkey
     totalgens += 1
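For completeness, a small sketch of the consuming side of such a field (again a toy struct, not the real generation_inputs): a POINTER(ctypes.c_char_p) member carries no size information of its own, so the companion *_len value is the only safe bound when walking the entries.

    import ctypes

    class pair_demo(ctypes.Structure):
        # the (length, pointer) shape used for every dynamic field in this change
        _fields_ = [("banned_tokens_len", ctypes.c_int),
                    ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]

    demo = pair_demo()
    tokens = ["foo", "bar"]
    demo.banned_tokens_len = len(tokens)
    demo.banned_tokens = (ctypes.c_char_p * len(tokens))(*[t.encode("UTF-8") for t in tokens])

    # Walk exactly banned_tokens_len entries; indexing past the end is undefined.
    for i in range(demo.banned_tokens_len):
        print(demo.banned_tokens[i].decode("UTF-8"))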