Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
Dynamic sizes for sequences (#1157)
* Dynamic sizes for sequences

* cleanup PR - move all dynamic fields to end of payload, ensure correct null handling to match existing behavior, add anti abuse limit of max 512 for dynamic fields

* adjust anti abuse limits

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
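The heart of the change is a marshalling-pattern swap: each fixed-size array field in the FFI payload becomes a pair of fields, an explicit element count plus a pointer, so the Python side allocates exactly as many slots as the request needs. A minimal sketch of the before/after pattern (field names borrowed from the real structs, but this is an illustration, not the full payload):

    import ctypes

    # Before: a fixed-size array always marshals stop_token_max slots,
    # padding unused ones with empty strings.
    class OldInputs(ctypes.Structure):
        _fields_ = [("stop_sequence", ctypes.c_char_p * 32)]

    # After: an explicit length plus a pointer, sized per request.
    class NewInputs(ctypes.Structure):
        _fields_ = [("stop_sequence_len", ctypes.c_int),
                    ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

    inp = NewInputs()
    seqs = ["###", "\n\n"]
    arr = (ctypes.c_char_p * len(seqs))(*(s.encode("UTF-8") for s in seqs))
    inp.stop_sequence = arr            # a ctypes array converts to the pointer type
    inp.stop_sequence_len = len(seqs)  # the C side trusts this count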
parent 7f76425450
commit 8bb220329c

3 changed files with 74 additions and 56 deletions
expose.h (24 changed lines)
@@ -1,11 +1,7 @@
 #pragma once
 #include <cstdint>
 
-const int stop_token_max = 32;
-const int ban_token_max = 64;
 const int tensor_split_max = 16;
-const int logit_bias_max = 32;
-const int dry_seq_break_max = 24;
 const int images_max = 4;
 
 // match kobold's sampler list and order
@@ -84,11 +80,6 @@ struct generation_inputs
     const int mirostat = 0;
     const float mirostat_eta = 0.0f;
     const float mirostat_tau = 0.0f;
-    const float dry_multiplier = 0.0f;
-    const float dry_base = 0.0f;
-    const int dry_allowed_length = 0;
-    const int dry_penalty_last_n = 0;
-    const char * dry_sequence_breakers[dry_seq_break_max] = {};
     const float xtc_threshold = 0.0f;
     const float xtc_probability = 0.0f;
     const samplers sampler_order[KCPP_SAMPLER_MAX] = {};
@@ -96,7 +87,6 @@ struct generation_inputs
     const bool allow_eos_token = false;
     const bool bypass_eos_token = false;
     const bool render_special = false;
-    const char * stop_sequence[stop_token_max] = {};
     const bool stream_sse = false;
     const char * grammar = nullptr;
     const bool grammar_retain_state = false;
@@ -104,8 +94,18 @@ struct generation_inputs
     const float dynatemp_range = 0.0f;
     const float dynatemp_exponent = 1.0f;
     const float smoothing_factor = 0.0f;
-    const logit_bias logit_biases[logit_bias_max] = {};
-    const char * banned_tokens[ban_token_max] = {};
+    const float dry_multiplier = 0.0f;
+    const float dry_base = 0.0f;
+    const int dry_allowed_length = 0;
+    const int dry_penalty_last_n = 0;
+    const int dry_sequence_breakers_len = 0;
+    const char ** dry_sequence_breakers = nullptr;
+    const int stop_sequence_len = 0;
+    const char ** stop_sequence = nullptr;
+    const int logit_biases_len = 0;
+    const logit_bias * logit_biases = nullptr;
+    const int banned_tokens_len = 0;
+    const char ** banned_tokens = nullptr;
 };
 struct generation_outputs
 {
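Note how every new field is appended after the existing ones rather than inserted where the old arrays sat. Appending keeps the byte offset of every pre-existing field identical, which is presumably what the "move all dynamic fields to end of payload" cleanup in the commit message refers to. A small sketch of the invariant, using hypothetical two-field structs rather than the real payload:

    import ctypes

    class V1(ctypes.Structure):  # hypothetical old layout
        _fields_ = [("temperature", ctypes.c_float),
                    ("mirostat", ctypes.c_int)]

    class V2(ctypes.Structure):  # same fields, dynamic ones appended at the end
        _fields_ = [("temperature", ctypes.c_float),
                    ("mirostat", ctypes.c_int),
                    ("stop_sequence_len", ctypes.c_int),
                    ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

    # offsets of all pre-existing fields are unchanged
    assert V2.temperature.offset == V1.temperature.offset
    assert V2.mirostat.offset == V1.mirostat.offset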
gpttype_adapter.cpp (20 changed lines)

@@ -2488,7 +2488,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     double time0 = 0, time1 = 0, time2 = 0;
     timer_start();
 
-    for(int x=0;x<stop_token_max;++x)
+    for(int x=0;x<inputs.stop_sequence_len;++x)
     {
         std::string stopper = inputs.stop_sequence[x];
         if(stopper!="")
@@ -2516,7 +2516,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     delayed_generated_tokens_limit = 0;
     antislop_banned_token_ids.clear();
     banned_tokens.clear();
-    for(int x=0;x<ban_token_max;++x)
+    for(int x=0;x<inputs.banned_tokens_len;++x)
     {
         std::string word = inputs.banned_tokens[x];
         word = toLowerCase(word);
@@ -2574,7 +2574,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }
 
     logit_biases.clear();
-    for(int x=0;x<logit_bias_max;++x)
+    for(int x=0;x<inputs.logit_biases_len;++x)
     {
         int32_t t_id = inputs.logit_biases[x].token_id;
         float bias = inputs.logit_biases[x].bias;
@@ -2652,7 +2652,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 
     if (kcpp_data->dry_multiplier > 0)
     {
-        for (int x = 0; x < dry_seq_break_max; ++x)
+        for (int x = 0; x < inputs.dry_sequence_breakers_len; ++x)
         {
             std::string word = inputs.dry_sequence_breakers[x];
             if (word != "")
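On the C++ side each loop bound changes from a compile-time maximum to the count supplied in the payload, so correctness now rests on the Python marshaller never reporting a length larger than the array it actually allocated. Deriving both the length field and the allocation size from the same clamped list makes that invariant hard to break; a hedged sketch of the idea, with a hypothetical helper name:

    import ctypes

    def marshal_strings(strings, cap):
        # Hypothetical helper mirroring the pattern used in koboldcpp.py:
        # clamp to the anti-abuse cap, then allocate exactly len(strings)
        # slots, so a C loop running to the reported length can never
        # read past the end of the array.
        strings = [] if strings is None else strings[:cap]
        arr = (ctypes.c_char_p * len(strings))(
            *((s or "").encode("UTF-8") for s in strings))
        return len(strings), arr

    length, arr = marshal_strings(["###", "User:"], cap=512)
    assert length == len(arr)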
koboldcpp.py (86 changed lines)
@@ -19,15 +19,17 @@ from datetime import datetime, timezone
 
 # constants
 sampler_order_max = 7
-stop_token_max = 32
-ban_token_max = 64
 tensor_split_max = 16
-logit_bias_max = 32
-dry_seq_break_max = 24
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
+
+# abuse prevention
+stop_token_max = 512
+ban_token_max = 1024
+logit_bias_max = 1024
+dry_seq_break_max = 256
 
 # global vars
 handle = None
 friendlymodelname = "inactive"
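The former hard struct limits (32/64/32/24) are repurposed as much looser anti-abuse caps: they no longer size any allocation, they only bound how many elements a client may submit. For instance:

    stop_token_max = 512                 # anti-abuse cap from this commit

    hostile = ["x"] * 10_000             # oversized client payload
    accepted = hostile[:stop_token_max]
    assert len(accepted) == 512          # allocation stays bounded regardless of input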
@@ -150,11 +152,6 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
-                ("dry_multiplier", ctypes.c_float),
-                ("dry_base", ctypes.c_float),
-                ("dry_allowed_length", ctypes.c_int),
-                ("dry_penalty_last_n", ctypes.c_int),
-                ("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
                 ("xtc_threshold", ctypes.c_float),
                 ("xtc_probability", ctypes.c_float),
                 ("sampler_order", ctypes.c_int * sampler_order_max),
@@ -162,7 +159,6 @@ class generation_inputs(ctypes.Structure):
                 ("allow_eos_token", ctypes.c_bool),
                 ("bypass_eos_token", ctypes.c_bool),
                 ("render_special", ctypes.c_bool),
-                ("stop_sequence", ctypes.c_char_p * stop_token_max),
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
                 ("grammar_retain_state", ctypes.c_bool),
@@ -170,8 +166,18 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_range", ctypes.c_float),
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
-                ("logit_biases", logit_bias * logit_bias_max),
-                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
+                ("dry_multiplier", ctypes.c_float),
+                ("dry_base", ctypes.c_float),
+                ("dry_allowed_length", ctypes.c_int),
+                ("dry_penalty_last_n", ctypes.c_int),
+                ("dry_sequence_breakers_len", ctypes.c_int),
+                ("dry_sequence_breakers", ctypes.POINTER(ctypes.c_char_p)),
+                ("stop_sequence_len", ctypes.c_int),
+                ("stop_sequence", ctypes.POINTER(ctypes.c_char_p)),
+                ("logit_biases_len", ctypes.c_int),
+                ("logit_biases", ctypes.POINTER(logit_bias)),
+                ("banned_tokens_len", ctypes.c_int),
+                ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
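One subtlety of the new pointer fields: assigning a freshly constructed (ctypes.c_char_p * n)() array to a POINTER(ctypes.c_char_p) field works because ctypes converts arrays to pointers on assignment, and, as I understand ctypes semantics, it also stashes a reference to the array in the structure's _objects, so the buffer outlives any local variable that created it. A reduced sketch:

    import ctypes

    class Inputs(ctypes.Structure):  # reduced stand-in for generation_inputs
        _fields_ = [("stop_sequence_len", ctypes.c_int),
                    ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

    inp = Inputs()
    arr = (ctypes.c_char_p * 2)(b"###", b"User:")
    inp.stop_sequence = arr        # array-to-pointer conversion on assignment
    inp.stop_sequence_len = 2
    del arr                        # ctypes keeps the buffer referenced via
                                   # inp._objects, so this should still be safe:
    assert inp.stop_sequence[0] == b"###"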
@@ -982,11 +988,16 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         except ValueError as e:
             print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. Could not parse '{dry_sequence_breakers}': " + str(e))
             dry_sequence_breakers = []
-    for n in range(dry_seq_break_max):
-        if dry_multiplier > 0 and n < len(dry_sequence_breakers):
-            inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8")
-        else:
-            inputs.dry_sequence_breakers[n] = "".encode("UTF-8")
+
+    if dry_multiplier <= 0 or dry_sequence_breakers is None: # prevent explicitly set to None, retain old behavior
+        dry_sequence_breakers = []
+
+    dry_sequence_breakers = dry_sequence_breakers[:dry_seq_break_max]
+    inputs.dry_sequence_breakers_len = len(dry_sequence_breakers)
+    inputs.dry_sequence_breakers = (ctypes.c_char_p * inputs.dry_sequence_breakers_len)()
+
+    for n, breaker in enumerate(dry_sequence_breakers):
+        inputs.dry_sequence_breakers[n] = breaker.encode("UTF-8")
 
     if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
         try:
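The `is None` check exists because a client can explicitly send null for the field; treating that the same as a disabled sampler (dry_multiplier <= 0) reproduces the old behavior, where every slot of the fixed array was filled with empty strings. Zero-length ctypes arrays are legal, so neither case needs special handling downstream:

    import ctypes

    dry_multiplier = 0.0
    dry_sequence_breakers = None             # explicitly null in the request

    if dry_multiplier <= 0 or dry_sequence_breakers is None:
        dry_sequence_breakers = []

    arr = (ctypes.c_char_p * len(dry_sequence_breakers))()  # length 0 is valid
    assert len(arr) == 0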
@@ -1000,13 +1011,18 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         except TypeError as e:
             print("ERROR: sampler_order must be a list of integers: " + str(e))
     inputs.seed = seed
-    for n in range(stop_token_max):
-        if not stop_sequence or n >= len(stop_sequence):
-            inputs.stop_sequence[n] = "".encode("UTF-8")
-        elif stop_sequence[n]==None:
-            inputs.stop_sequence[n] = "".encode("UTF-8")
-        else:
-            inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
+
+    if stop_sequence is None:
+        stop_sequence = []
+    stop_sequence = stop_sequence[:stop_token_max]
+    inputs.stop_sequence_len = len(stop_sequence)
+    inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)()
+
+    for n, sequence in enumerate(stop_sequence):
+        if sequence:
+            inputs.stop_sequence[n] = sequence.encode("UTF-8")
+        else:
+            inputs.stop_sequence[n] = "".encode("UTF-8")
 
     bias_list = []
     try:
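The rewritten stop-sequence loop collapses the old three-way branch (missing list, None element, real string) into a single truthiness test, which maps both None and the empty string to an empty C string exactly as before:

    stop_sequence = ["###", None, ""]
    encoded = [(s.encode("UTF-8") if s else b"") for s in stop_sequence]
    assert encoded == [b"###", b"", b""]  # None and "" both become empty strings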
@@ -1015,13 +1031,13 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     except Exception as ex:
         print(f"Logit bias dictionary is invalid: {ex}")
 
-    for n in range(logit_bias_max):
-        if n >= len(bias_list):
-            inputs.logit_biases[n] = logit_bias(-1, 0.0)
-        else:
+    bias_list = bias_list[:logit_bias_max]
+    inputs.logit_biases_len = len(bias_list)
+    inputs.logit_biases = (logit_bias * inputs.logit_biases_len)()
+    for n, lb in enumerate(bias_list):
         try:
-            t_id = int(bias_list[n]['key'])
-            bias = float(bias_list[n]['value'])
+            t_id = int(lb['key'])
+            bias = float(lb['value'])
             t_id = -1 if t_id < 0 else t_id
             bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
             inputs.logit_biases[n] = logit_bias(t_id, bias)
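The clamping expression kept as context here bounds every bias into [bias_min_value, bias_max_value]; pulled out as a function, it reads:

    bias_min_value, bias_max_value = -100.0, 100.0

    def clamp_bias(bias):
        # same nested conditional as in the hunk above
        return bias_max_value if bias > bias_max_value else (
            bias_min_value if bias < bias_min_value else bias)

    assert clamp_bias(250.0) == 100.0
    assert clamp_bias(-250.0) == -100.0
    assert clamp_bias(1.5) == 1.5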
@@ -1029,11 +1045,13 @@ def generate(genparams, is_quiet=False, stream_flag=False):
             inputs.logit_biases[n] = logit_bias(-1, 0.0)
             print(f"Skipped unparsable logit bias:{ex}")
 
-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+    if banned_tokens is None:
+        banned_tokens = []
+    banned_tokens = banned_tokens[:ban_token_max]
+    inputs.banned_tokens_len = len(banned_tokens)
+    inputs.banned_tokens = (ctypes.c_char_p * inputs.banned_tokens_len)()
+    for n, tok in enumerate(banned_tokens):
+        inputs.banned_tokens[n] = tok.encode("UTF-8")
 
     currentusergenkey = genkey
     totalgens += 1