diff --git a/expose.h b/expose.h
index a84284f2e..97c4a38ab 100644
--- a/expose.h
+++ b/expose.h
@@ -1,11 +1,7 @@
 #pragma once
 #include <cmath>
 
-const int stop_token_max = 32;
-const int ban_token_max = 64;
 const int tensor_split_max = 16;
-const int logit_bias_max = 32;
-const int dry_seq_break_max = 24;
 const int images_max = 4;
 
 // match kobold's sampler list and order
@@ -84,11 +80,6 @@ struct generation_inputs
     const int mirostat = 0;
     const float mirostat_eta = 0.0f;
     const float mirostat_tau = 0.0f;
-    const float dry_multiplier = 0.0f;
-    const float dry_base = 0.0f;
-    const int dry_allowed_length = 0;
-    const int dry_penalty_last_n = 0;
-    const char * dry_sequence_breakers[dry_seq_break_max] = {};
     const float xtc_threshold = 0.0f;
     const float xtc_probability = 0.0f;
     const samplers sampler_order[KCPP_SAMPLER_MAX] = {};
@@ -96,7 +87,6 @@ struct generation_inputs
     const bool allow_eos_token = false;
     const bool bypass_eos_token = false;
     const bool render_special = false;
-    const char * stop_sequence[stop_token_max] = {};
     const bool stream_sse = false;
     const char * grammar = nullptr;
     const bool grammar_retain_state = false;
@@ -104,8 +94,18 @@ struct generation_inputs
     const float dynatemp_range = 0.0f;
     const float dynatemp_exponent = 1.0f;
     const float smoothing_factor = 0.0f;
-    const logit_bias logit_biases[logit_bias_max] = {};
-    const char * banned_tokens[ban_token_max] = {};
+    const float dry_multiplier = 0.0f;
+    const float dry_base = 0.0f;
+    const int dry_allowed_length = 0;
+    const int dry_penalty_last_n = 0;
+    const int dry_sequence_breakers_len = 0;
+    const char ** dry_sequence_breakers = nullptr;
+    const int stop_sequence_len = 0;
+    const char ** stop_sequence = nullptr;
+    const int logit_biases_len = 0;
+    const logit_bias * logit_biases = nullptr;
+    const int banned_tokens_len = 0;
+    const char ** banned_tokens = nullptr;
 };
 struct generation_outputs
 {
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index bdda33d29..501748dd8 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -2488,7 +2488,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     double time0 = 0, time1 = 0, time2 = 0;
     timer_start();
 
-    for(int x=0;x<stop_token_max;++x)
+    for(int x=0;x<inputs.stop_sequence_len;++x)
     {
         std::string stopper = inputs.stop_sequence[x];
         if(stopper!="")
@@ ... @@
     if (inputs.dry_multiplier > 0)
     {
-        for (int x = 0; x < dry_seq_break_max; ++x)
+        for (int x = 0; x < inputs.dry_sequence_breakers_len; ++x)
         {
             std::string word = inputs.dry_sequence_breakers[x];
             if (word != "")
diff --git a/koboldcpp.py b/koboldcpp.py
index e5f785ccc..42feb4b50 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -19,15 +19,17 @@ from datetime import datetime, timezone
 
 # constants
 sampler_order_max = 7
-stop_token_max = 32
-ban_token_max = 64
 tensor_split_max = 16
-logit_bias_max = 32
-dry_seq_break_max = 24
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
 
+# abuse prevention
+stop_token_max = 512
+ban_token_max = 1024
+logit_bias_max = 1024
+dry_seq_break_max = 256
+
 # global vars
 handle = None
 friendlymodelname = "inactive"
@@ -150,11 +152,6 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
-                ("dry_multiplier", ctypes.c_float),
-                ("dry_base", ctypes.c_float),
-                ("dry_allowed_length", ctypes.c_int),
-                ("dry_penalty_last_n", ctypes.c_int),
-                ("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
                 ("xtc_threshold", ctypes.c_float),
                 ("xtc_probability", ctypes.c_float),
                 ("sampler_order", ctypes.c_int * sampler_order_max),
@@ -162,7 +159,6 @@ class generation_inputs(ctypes.Structure):
                 ("allow_eos_token", ctypes.c_bool),
                 ("bypass_eos_token", ctypes.c_bool),
ctypes.c_bool), ("render_special", ctypes.c_bool), - ("stop_sequence", ctypes.c_char_p * stop_token_max), ("stream_sse", ctypes.c_bool), ("grammar", ctypes.c_char_p), ("grammar_retain_state", ctypes.c_bool), @@ -170,8 +166,18 @@ class generation_inputs(ctypes.Structure): ("dynatemp_range", ctypes.c_float), ("dynatemp_exponent", ctypes.c_float), ("smoothing_factor", ctypes.c_float), - ("logit_biases", logit_bias * logit_bias_max), - ("banned_tokens", ctypes.c_char_p * ban_token_max)] + ("dry_multiplier", ctypes.c_float), + ("dry_base", ctypes.c_float), + ("dry_allowed_length", ctypes.c_int), + ("dry_penalty_last_n", ctypes.c_int), + ("dry_sequence_breakers_len", ctypes.c_int), + ("dry_sequence_breakers", ctypes.POINTER(ctypes.c_char_p)), + ("stop_sequence_len", ctypes.c_int), + ("stop_sequence", ctypes.POINTER(ctypes.c_char_p)), + ("logit_biases_len", ctypes.c_int), + ("logit_biases", ctypes.POINTER(logit_bias)), + ("banned_tokens_len", ctypes.c_int), + ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))] class generation_outputs(ctypes.Structure): _fields_ = [("status", ctypes.c_int), @@ -982,11 +988,16 @@ def generate(genparams, is_quiet=False, stream_flag=False): except ValueError as e: print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. Could not parse '{dry_sequence_breakers}': " + str(e)) dry_sequence_breakers = [] - for n in range(dry_seq_break_max): - if dry_multiplier > 0 and n < len(dry_sequence_breakers): - inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8") - else: - inputs.dry_sequence_breakers[n] = "".encode("UTF-8") + + if dry_multiplier <= 0 or dry_sequence_breakers is None: # prevent explicitly set to None, retain old behavior + dry_sequence_breakers = [] + + dry_sequence_breakers = dry_sequence_breakers[:dry_seq_break_max] + inputs.dry_sequence_breakers_len = len(dry_sequence_breakers) + inputs.dry_sequence_breakers = (ctypes.c_char_p * inputs.dry_sequence_breakers_len)() + + for n, breaker in enumerate(dry_sequence_breakers): + inputs.dry_sequence_breakers[n] = breaker.encode("UTF-8") if sampler_order and 0 < len(sampler_order) <= sampler_order_max: try: @@ -1000,13 +1011,18 @@ def generate(genparams, is_quiet=False, stream_flag=False): except TypeError as e: print("ERROR: sampler_order must be a list of integers: " + str(e)) inputs.seed = seed - for n in range(stop_token_max): - if not stop_sequence or n >= len(stop_sequence): - inputs.stop_sequence[n] = "".encode("UTF-8") - elif stop_sequence[n]==None: - inputs.stop_sequence[n] = "".encode("UTF-8") + + if stop_sequence is None: + stop_sequence = [] + stop_sequence = stop_sequence[:stop_token_max] + inputs.stop_sequence_len = len(stop_sequence) + inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)() + + for n, sequence in enumerate(stop_sequence): + if sequence: + inputs.stop_sequence[n] = sequence.encode("UTF-8") else: - inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8") + inputs.stop_sequence[n] = "".encode("UTF-8") bias_list = [] try: @@ -1015,25 +1031,27 @@ def generate(genparams, is_quiet=False, stream_flag=False): except Exception as ex: print(f"Logit bias dictionary is invalid: {ex}") - for n in range(logit_bias_max): - if n >= len(bias_list): + bias_list = bias_list[:logit_bias_max] + inputs.logit_biases_len = len(bias_list) + inputs.logit_biases = (logit_bias * inputs.logit_biases_len)() + for n, lb in enumerate(bias_list): + try: + t_id = int(lb['key']) + bias = float(lb['value']) + t_id = -1 if t_id < 0 else t_id + 
+            bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
+            inputs.logit_biases[n] = logit_bias(t_id, bias)
+        except Exception as ex:
             inputs.logit_biases[n] = logit_bias(-1, 0.0)
-        else:
-            try:
-                t_id = int(bias_list[n]['key'])
-                bias = float(bias_list[n]['value'])
-                t_id = -1 if t_id < 0 else t_id
-                bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
-                inputs.logit_biases[n] = logit_bias(t_id, bias)
-            except Exception as ex:
-                inputs.logit_biases[n] = logit_bias(-1, 0.0)
-                print(f"Skipped unparsable logit bias:{ex}")
+            print(f"Skipped unparsable logit bias:{ex}")
 
-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+    if banned_tokens is None:
+        banned_tokens = []
+    banned_tokens = banned_tokens[:ban_token_max]
+    inputs.banned_tokens_len = len(banned_tokens)
+    inputs.banned_tokens = (ctypes.c_char_p * inputs.banned_tokens_len)()
+    for n, tok in enumerate(banned_tokens):
+        inputs.banned_tokens[n] = tok.encode("UTF-8")
 
     currentusergenkey = genkey
     totalgens += 1
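
Note on the calling convention this patch adopts: the fixed-size c_char_p arrays in generation_inputs are replaced by a pointer field plus an explicit _len count, so the Python caller allocates an array of exactly the needed size and the C++ side loops over *_len instead of a compile-time maximum. The sketch below is an illustrative, stand-alone example of that pattern, not code from this patch; the trimmed-down struct and the fill_stop_sequence helper are invented for demonstration and only mirror the field names added in the diff above.

import ctypes

# Trimmed-down stand-in for generation_inputs: only the stop-sequence pair of
# fields from the patch is reproduced here, purely for illustration.
class generation_inputs(ctypes.Structure):
    _fields_ = [("stop_sequence_len", ctypes.c_int),
                ("stop_sequence", ctypes.POINTER(ctypes.c_char_p))]

stop_token_max = 512  # caller-side cap, mirroring the "abuse prevention" constants

def fill_stop_sequence(inputs, stop_sequence):
    # None behaves like an empty list, matching the patched generate() logic.
    stop_sequence = (stop_sequence or [])[:stop_token_max]
    inputs.stop_sequence_len = len(stop_sequence)
    # Assigning a fresh ctypes array to the POINTER field is safe: the struct
    # keeps a reference to it, so the memory stays alive for the native call.
    inputs.stop_sequence = (ctypes.c_char_p * inputs.stop_sequence_len)()
    for n, sequence in enumerate(stop_sequence):
        inputs.stop_sequence[n] = (sequence or "").encode("UTF-8")

inputs = generation_inputs()
fill_stop_sequence(inputs, ["###", "\n\nUser:"])
# The C++ side would then iterate x over [0, stop_sequence_len) on stop_sequence[x].
print(inputs.stop_sequence_len, inputs.stop_sequence[0])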