Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-10 17:14:36 +00:00
Dynamic sizes for sequences (#1157)
* Dynamic sizes for sequences
* Cleanup PR: move all dynamic fields to the end of the payload, ensure correct null handling to match existing behavior, and add an anti-abuse limit of max 512 for dynamic fields
* Adjust anti-abuse limits

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
parent: 7f76425450
commit: 8bb220329c

3 changed files with 74 additions and 56 deletions
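For orientation, here is a minimal runnable sketch of the pattern this commit adopts, with hypothetical names (demo_inputs, pack, and DEMO_MAX are illustrative only, not koboldcpp identifiers): each previously fixed-size array field in the payload becomes a (length, pointer) pair, sized per request and clamped to an anti-abuse cap.

import ctypes

DEMO_MAX = 512  # illustrative cap, in the spirit of the anti-abuse limits below

class demo_inputs(ctypes.Structure):
    # old style: ("stop_sequence", ctypes.c_char_p * 32) always shipped 32 slots;
    # new style: a count plus a pointer sized to the actual request
    _fields_ = [("stop_sequence_len", ctypes.c_int),
                ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

def pack(stop_sequence):
    inp = demo_inputs()
    stop_sequence = (stop_sequence or [])[:DEMO_MAX]  # None behaves like an empty list
    inp.stop_sequence_len = len(stop_sequence)
    inp.stop_sequence = (ctypes.c_char_p * len(stop_sequence))(
        *[s.encode("UTF-8") for s in stop_sequence])
    return inp

inp = pack(["###", "</s>"])
# the native consumer loops exactly stop_sequence_len times
print([inp.stop_sequence[i] for i in range(inp.stop_sequence_len)])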
expose.h (24 changes)
@@ -1,11 +1,7 @@
 #pragma once
 #include <cstdint>
 
-const int stop_token_max = 32;
-const int ban_token_max = 64;
 const int tensor_split_max = 16;
-const int logit_bias_max = 32;
-const int dry_seq_break_max = 24;
 const int images_max = 4;
 
 // match kobold's sampler list and order
@@ -84,11 +80,6 @@ struct generation_inputs
     const int mirostat = 0;
     const float mirostat_eta = 0.0f;
     const float mirostat_tau = 0.0f;
-    const float dry_multiplier = 0.0f;
-    const float dry_base = 0.0f;
-    const int dry_allowed_length = 0;
-    const int dry_penalty_last_n = 0;
-    const char * dry_sequence_breakers[dry_seq_break_max] = {};
     const float xtc_threshold = 0.0f;
     const float xtc_probability = 0.0f;
     const samplers sampler_order[KCPP_SAMPLER_MAX] = {};
@@ -96,7 +87,6 @@ struct generation_inputs
     const bool allow_eos_token = false;
     const bool bypass_eos_token = false;
     const bool render_special = false;
-    const char * stop_sequence[stop_token_max] = {};
     const bool stream_sse = false;
     const char * grammar = nullptr;
     const bool grammar_retain_state = false;
@@ -104,8 +94,18 @@ struct generation_inputs
     const float dynatemp_range = 0.0f;
     const float dynatemp_exponent = 1.0f;
     const float smoothing_factor = 0.0f;
-    const logit_bias logit_biases[logit_bias_max] = {};
-    const char * banned_tokens[ban_token_max] = {};
+    const float dry_multiplier = 0.0f;
+    const float dry_base = 0.0f;
+    const int dry_allowed_length = 0;
+    const int dry_penalty_last_n = 0;
+    const int dry_sequence_breakers_len = 0;
+    const char ** dry_sequence_breakers = nullptr;
+    const int stop_sequence_len = 0;
+    const char ** stop_sequence = nullptr;
+    const int logit_biases_len = 0;
+    const logit_bias * logit_biases = nullptr;
+    const int banned_tokens_len = 0;
+    const char ** banned_tokens = nullptr;
 };
 struct generation_outputs
 {
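Every new field lands after all existing members, which (presumably) is the point of "move all dynamic fields to end of payload": the offsets of the pre-existing fields stay identical on both sides of the FFI boundary, so only the tail of the Python mirror has to change in lockstep. A quick ctypes check of that property, using hypothetical two-field layouts rather than the real struct:

import ctypes

class old_layout(ctypes.Structure):
    _fields_ = [("seed", ctypes.c_int), ("temperature", ctypes.c_float)]

class new_layout(ctypes.Structure):
    _fields_ = [("seed", ctypes.c_int), ("temperature", ctypes.c_float),
                ("stop_sequence_len", ctypes.c_int),
                ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

# appending members never moves the ones that came before
for name, _ in old_layout._fields_:
    assert getattr(old_layout, name).offset == getattr(new_layout, name).offset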
gpttype_adapter.cpp (8 changes)

@@ -2488,7 +2488,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     double time0 = 0, time1 = 0, time2 = 0;
     timer_start();
 
-    for(int x=0;x<stop_token_max;++x)
+    for(int x=0;x<inputs.stop_sequence_len;++x)
     {
         std::string stopper = inputs.stop_sequence[x];
         if(stopper!="")
@@ -2516,7 +2516,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     delayed_generated_tokens_limit = 0;
     antislop_banned_token_ids.clear();
     banned_tokens.clear();
-    for(int x=0;x<ban_token_max;++x)
+    for(int x=0;x<inputs.banned_tokens_len;++x)
     {
         std::string word = inputs.banned_tokens[x];
         word = toLowerCase(word);
@@ -2574,7 +2574,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }
 
     logit_biases.clear();
-    for(int x=0;x<logit_bias_max;++x)
+    for(int x=0;x<inputs.logit_biases_len;++x)
     {
         int32_t t_id = inputs.logit_biases[x].token_id;
         float bias = inputs.logit_biases[x].bias;
@@ -2652,7 +2652,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 
     if (kcpp_data->dry_multiplier > 0)
     {
-        for (int x = 0; x < dry_seq_break_max; ++x)
+        for (int x = 0; x < inputs.dry_sequence_breakers_len; ++x)
         {
             std::string word = inputs.dry_sequence_breakers[x];
             if (word != "")
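On the native side the change is mechanical: every loop that previously scanned a full fixed-size array (stop_token_max, ban_token_max, logit_bias_max, dry_seq_break_max) now iterates exactly inputs.*_len entries, so the library never reads past what the launcher allocated, and the pre-existing empty-string checks such as if(stopper!="") are kept so null handling matches the old behavior.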
koboldcpp.py (98 changes)
@@ -19,15 +19,17 @@ from datetime import datetime, timezone
 
 # constants
 sampler_order_max = 7
-stop_token_max = 32
-ban_token_max = 64
 tensor_split_max = 16
-logit_bias_max = 32
-dry_seq_break_max = 24
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
 
+# abuse prevention
+stop_token_max = 512
+ban_token_max = 1024
+logit_bias_max = 1024
+dry_seq_break_max = 256
+
 # global vars
 handle = None
 friendlymodelname = "inactive"
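Note the *_max constants survive but change meaning: they are no longer allocation sizes baked into the struct, only request-time truncation caps. A hedged sketch of the clamp idiom that generate() applies further down (clamp is a hypothetical helper, not part of the diff):

stop_token_max = 512  # value from the hunk above

def clamp(seqs):
    # an omitted or explicitly null field behaves like an empty list
    return (seqs or [])[:stop_token_max]

assert clamp(None) == []
assert len(clamp(["x"] * 10000)) == stop_token_max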
@@ -150,11 +152,6 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
-                ("dry_multiplier", ctypes.c_float),
-                ("dry_base", ctypes.c_float),
-                ("dry_allowed_length", ctypes.c_int),
-                ("dry_penalty_last_n", ctypes.c_int),
-                ("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
                 ("xtc_threshold", ctypes.c_float),
                 ("xtc_probability", ctypes.c_float),
                 ("sampler_order", ctypes.c_int * sampler_order_max),
@@ -162,7 +159,6 @@ class generation_inputs(ctypes.Structure):
                 ("allow_eos_token", ctypes.c_bool),
                 ("bypass_eos_token", ctypes.c_bool),
                 ("render_special", ctypes.c_bool),
-                ("stop_sequence", ctypes.c_char_p * stop_token_max),
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
                 ("grammar_retain_state", ctypes.c_bool),
@@ -170,8 +166,18 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_range", ctypes.c_float),
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
-                ("logit_biases", logit_bias * logit_bias_max),
-                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
+                ("dry_multiplier", ctypes.c_float),
+                ("dry_base", ctypes.c_float),
+                ("dry_allowed_length", ctypes.c_int),
+                ("dry_penalty_last_n", ctypes.c_int),
+                ("dry_sequence_breakers_len", ctypes.c_int),
+                ("dry_sequence_breakers", ctypes.POINTER(ctypes.c_char_p)),
+                ("stop_sequence_len", ctypes.c_int),
+                ("stop_sequence", ctypes.POINTER(ctypes.c_char_p)),
+                ("logit_biases_len", ctypes.c_int),
+                ("logit_biases", ctypes.POINTER(logit_bias)),
+                ("banned_tokens_len", ctypes.c_int),
+                ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
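One ctypes subtlety worth noting (an aside, not something the diff spells out): a (c_char_p * n) array instance can be assigned directly to a POINTER(c_char_p) field, and the structure is expected to keep that array alive through its internal _objects bookkeeping, so the pointer should stay valid for the lifetime of the struct. A minimal check:

import ctypes

class holder(ctypes.Structure):
    _fields_ = [("n", ctypes.c_int),
                ("items", ctypes.POINTER(ctypes.c_char_p))]

h = holder()
h.n = 2
h.items = (ctypes.c_char_p * 2)(b"alpha", b"beta")  # the array decays to a pointer
assert [h.items[i] for i in range(h.n)] == [b"alpha", b"beta"]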
@@ -982,11 +988,16 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     except ValueError as e:
         print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. Could not parse '{dry_sequence_breakers}': " + str(e))
         dry_sequence_breakers = []
-    for n in range(dry_seq_break_max):
-        if dry_multiplier > 0 and n < len(dry_sequence_breakers):
-            inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8")
-        else:
-            inputs.dry_sequence_breakers[n] = "".encode("UTF-8")
+    if dry_multiplier <= 0 or dry_sequence_breakers is None: # prevent explicitly set to None, retain old behavior
+        dry_sequence_breakers = []
+
+    dry_sequence_breakers = dry_sequence_breakers[:dry_seq_break_max]
+    inputs.dry_sequence_breakers_len = len(dry_sequence_breakers)
+    inputs.dry_sequence_breakers = (ctypes.c_char_p * inputs.dry_sequence_breakers_len)()
+
+    for n, breaker in enumerate(dry_sequence_breakers):
+        inputs.dry_sequence_breakers[n] = breaker.encode("UTF-8")
 
     if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
         try:
@@ -1000,13 +1011,18 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     except TypeError as e:
         print("ERROR: sampler_order must be a list of integers: " + str(e))
     inputs.seed = seed
-    for n in range(stop_token_max):
-        if not stop_sequence or n >= len(stop_sequence):
-            inputs.stop_sequence[n] = "".encode("UTF-8")
-        elif stop_sequence[n]==None:
-            inputs.stop_sequence[n] = "".encode("UTF-8")
-        else:
-            inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
+    if stop_sequence is None:
+        stop_sequence = []
+
+    stop_sequence = stop_sequence[:stop_token_max]
+    inputs.stop_sequence_len = len(stop_sequence)
+    inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)()
+
+    for n, sequence in enumerate(stop_sequence):
+        if sequence:
+            inputs.stop_sequence[n] = sequence.encode("UTF-8")
+        else:
+            inputs.stop_sequence[n] = "".encode("UTF-8")
 
     bias_list = []
     try:
@@ -1015,25 +1031,27 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     except Exception as ex:
         print(f"Logit bias dictionary is invalid: {ex}")
 
-    for n in range(logit_bias_max):
-        if n >= len(bias_list):
-            inputs.logit_biases[n] = logit_bias(-1, 0.0)
-        else:
-            try:
-                t_id = int(bias_list[n]['key'])
-                bias = float(bias_list[n]['value'])
-                t_id = -1 if t_id < 0 else t_id
-                bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
-                inputs.logit_biases[n] = logit_bias(t_id, bias)
-            except Exception as ex:
-                inputs.logit_biases[n] = logit_bias(-1, 0.0)
-                print(f"Skipped unparsable logit bias:{ex}")
+    bias_list = bias_list[:logit_bias_max]
+    inputs.logit_biases_len = len(bias_list)
+    inputs.logit_biases = (logit_bias * inputs.logit_biases_len)()
+    for n, lb in enumerate(bias_list):
+        try:
+            t_id = int(lb['key'])
+            bias = float(lb['value'])
+            t_id = -1 if t_id < 0 else t_id
+            bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
+            inputs.logit_biases[n] = logit_bias(t_id, bias)
+        except Exception as ex:
+            inputs.logit_biases[n] = logit_bias(-1, 0.0)
+            print(f"Skipped unparsable logit bias:{ex}")
 
-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+    if banned_tokens is None:
+        banned_tokens = []
+    banned_tokens = banned_tokens[:ban_token_max]
+    inputs.banned_tokens_len = len(banned_tokens)
+    inputs.banned_tokens = (ctypes.c_char_p * inputs.banned_tokens_len)()
+    for n, tok in enumerate(banned_tokens):
+        inputs.banned_tokens[n] = tok.encode("UTF-8")
 
     currentusergenkey = genkey
     totalgens += 1