mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
Dynamic sizes for sequences (#1157)
* Dynamic sizes for sequences * cleanup PR - move all dynamic fields to end of payload, ensure correct null handling to match existing behavior, add anti abuse limit of max 512 for dynamic fields * adjust anti abuse limits --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent
7f76425450
commit
8bb220329c
3 changed files with 74 additions and 56 deletions
98
koboldcpp.py
98
koboldcpp.py
|
@ -19,15 +19,17 @@ from datetime import datetime, timezone
|
|||
|
||||
# constants
|
||||
sampler_order_max = 7
|
||||
stop_token_max = 32
|
||||
ban_token_max = 64
|
||||
tensor_split_max = 16
|
||||
logit_bias_max = 32
|
||||
dry_seq_break_max = 24
|
||||
images_max = 4
|
||||
bias_min_value = -100.0
|
||||
bias_max_value = 100.0
|
||||
|
||||
# abuse prevention
|
||||
stop_token_max = 512
|
||||
ban_token_max = 1024
|
||||
logit_bias_max = 1024
|
||||
dry_seq_break_max = 256
|
||||
|
||||
# global vars
|
||||
handle = None
|
||||
friendlymodelname = "inactive"
|
||||
|
@ -150,11 +152,6 @@ class generation_inputs(ctypes.Structure):
|
|||
("mirostat", ctypes.c_int),
|
||||
("mirostat_tau", ctypes.c_float),
|
||||
("mirostat_eta", ctypes.c_float),
|
||||
("dry_multiplier", ctypes.c_float),
|
||||
("dry_base", ctypes.c_float),
|
||||
("dry_allowed_length", ctypes.c_int),
|
||||
("dry_penalty_last_n", ctypes.c_int),
|
||||
("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
|
||||
("xtc_threshold", ctypes.c_float),
|
||||
("xtc_probability", ctypes.c_float),
|
||||
("sampler_order", ctypes.c_int * sampler_order_max),
|
||||
|
@ -162,7 +159,6 @@ class generation_inputs(ctypes.Structure):
|
|||
("allow_eos_token", ctypes.c_bool),
|
||||
("bypass_eos_token", ctypes.c_bool),
|
||||
("render_special", ctypes.c_bool),
|
||||
("stop_sequence", ctypes.c_char_p * stop_token_max),
|
||||
("stream_sse", ctypes.c_bool),
|
||||
("grammar", ctypes.c_char_p),
|
||||
("grammar_retain_state", ctypes.c_bool),
|
||||
|
@ -170,8 +166,18 @@ class generation_inputs(ctypes.Structure):
|
|||
("dynatemp_range", ctypes.c_float),
|
||||
("dynatemp_exponent", ctypes.c_float),
|
||||
("smoothing_factor", ctypes.c_float),
|
||||
("logit_biases", logit_bias * logit_bias_max),
|
||||
("banned_tokens", ctypes.c_char_p * ban_token_max)]
|
||||
("dry_multiplier", ctypes.c_float),
|
||||
("dry_base", ctypes.c_float),
|
||||
("dry_allowed_length", ctypes.c_int),
|
||||
("dry_penalty_last_n", ctypes.c_int),
|
||||
("dry_sequence_breakers_len", ctypes.c_int),
|
||||
("dry_sequence_breakers", ctypes.POINTER(ctypes.c_char_p)),
|
||||
("stop_sequence_len", ctypes.c_int),
|
||||
("stop_sequence", ctypes.POINTER(ctypes.c_char_p)),
|
||||
("logit_biases_len", ctypes.c_int),
|
||||
("logit_biases", ctypes.POINTER(logit_bias)),
|
||||
("banned_tokens_len", ctypes.c_int),
|
||||
("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]
|
||||
|
||||
class generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
|
@ -982,11 +988,16 @@ def generate(genparams, is_quiet=False, stream_flag=False):
|
|||
except ValueError as e:
|
||||
print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. Could not parse '{dry_sequence_breakers}': " + str(e))
|
||||
dry_sequence_breakers = []
|
||||
for n in range(dry_seq_break_max):
|
||||
if dry_multiplier > 0 and n < len(dry_sequence_breakers):
|
||||
inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8")
|
||||
else:
|
||||
inputs.dry_sequence_breakers[n] = "".encode("UTF-8")
|
||||
|
||||
if dry_multiplier <= 0 or dry_sequence_breakers is None: # prevent explicitly set to None, retain old behavior
|
||||
dry_sequence_breakers = []
|
||||
|
||||
dry_sequence_breakers = dry_sequence_breakers[:dry_seq_break_max]
|
||||
inputs.dry_sequence_breakers_len = len(dry_sequence_breakers)
|
||||
inputs.dry_sequence_breakers = (ctypes.c_char_p * inputs.dry_sequence_breakers_len)()
|
||||
|
||||
for n, breaker in enumerate(dry_sequence_breakers):
|
||||
inputs.dry_sequence_breakers[n] = breaker.encode("UTF-8")
|
||||
|
||||
if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
|
||||
try:
|
||||
|
@ -1000,13 +1011,18 @@ def generate(genparams, is_quiet=False, stream_flag=False):
|
|||
except TypeError as e:
|
||||
print("ERROR: sampler_order must be a list of integers: " + str(e))
|
||||
inputs.seed = seed
|
||||
for n in range(stop_token_max):
|
||||
if not stop_sequence or n >= len(stop_sequence):
|
||||
inputs.stop_sequence[n] = "".encode("UTF-8")
|
||||
elif stop_sequence[n]==None:
|
||||
inputs.stop_sequence[n] = "".encode("UTF-8")
|
||||
|
||||
if stop_sequence is None:
|
||||
stop_sequence = []
|
||||
stop_sequence = stop_sequence[:stop_token_max]
|
||||
inputs.stop_sequence_len = len(stop_sequence)
|
||||
inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)()
|
||||
|
||||
for n, sequence in enumerate(stop_sequence):
|
||||
if sequence:
|
||||
inputs.stop_sequence[n] = sequence.encode("UTF-8")
|
||||
else:
|
||||
inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
|
||||
inputs.stop_sequence[n] = "".encode("UTF-8")
|
||||
|
||||
bias_list = []
|
||||
try:
|
||||
|
@ -1015,25 +1031,27 @@ def generate(genparams, is_quiet=False, stream_flag=False):
|
|||
except Exception as ex:
|
||||
print(f"Logit bias dictionary is invalid: {ex}")
|
||||
|
||||
for n in range(logit_bias_max):
|
||||
if n >= len(bias_list):
|
||||
bias_list = bias_list[:logit_bias_max]
|
||||
inputs.logit_biases_len = len(bias_list)
|
||||
inputs.logit_biases = (logit_bias * inputs.logit_biases_len)()
|
||||
for n, lb in enumerate(bias_list):
|
||||
try:
|
||||
t_id = int(lb['key'])
|
||||
bias = float(lb['value'])
|
||||
t_id = -1 if t_id < 0 else t_id
|
||||
bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
|
||||
inputs.logit_biases[n] = logit_bias(t_id, bias)
|
||||
except Exception as ex:
|
||||
inputs.logit_biases[n] = logit_bias(-1, 0.0)
|
||||
else:
|
||||
try:
|
||||
t_id = int(bias_list[n]['key'])
|
||||
bias = float(bias_list[n]['value'])
|
||||
t_id = -1 if t_id < 0 else t_id
|
||||
bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
|
||||
inputs.logit_biases[n] = logit_bias(t_id, bias)
|
||||
except Exception as ex:
|
||||
inputs.logit_biases[n] = logit_bias(-1, 0.0)
|
||||
print(f"Skipped unparsable logit bias:{ex}")
|
||||
print(f"Skipped unparsable logit bias:{ex}")
|
||||
|
||||
for n in range(ban_token_max):
|
||||
if not banned_tokens or n >= len(banned_tokens):
|
||||
inputs.banned_tokens[n] = "".encode("UTF-8")
|
||||
else:
|
||||
inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
|
||||
if banned_tokens is None:
|
||||
banned_tokens = []
|
||||
banned_tokens = banned_tokens[:ban_token_max]
|
||||
inputs.banned_tokens_len = len(banned_tokens)
|
||||
inputs.banned_tokens = (ctypes.c_char_p * inputs.banned_tokens_len)()
|
||||
for n, tok in enumerate(banned_tokens):
|
||||
inputs.banned_tokens[n] = tok.encode("UTF-8")
|
||||
|
||||
currentusergenkey = genkey
|
||||
totalgens += 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue