diff --git a/common/common.h b/common/common.h index 1cb6c12b3..6aa6b5761 100644 --- a/common/common.h +++ b/common/common.h @@ -113,6 +113,11 @@ struct gpt_params { int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate + float dry_multiplier = 0.0f; // penalty multiplier, 0.0 = disabled + float dry_base = 1.75f; // exponential base + int32_t dry_allowed_length = 2; // repeated sequences longer than this are penalized + int32_t dry_penalty_last_n = 0; // how many tokens to scan for repetitions (0 = entire context) + std::vector<std::string> dry_sequence_breakers; // DRY sequence breakers // DynaTemp! float dynatemp_range = 0.0f; // enables DynaTemp if greater than 0. dynatemp_min = temperature - dt_range, dynatemp_max = temperature + dt_range diff --git a/expose.h b/expose.h index 24b7b89a5..b620c1b2e 100644 --- a/expose.h +++ b/expose.h @@ -5,6 +5,7 @@ const int stop_token_max = 16; const int ban_token_max = 16; const int tensor_split_max = 16; const int logit_bias_max = 16; +const int dry_seq_break_max = 16; const int images_max = 4; // match kobold's sampler list and order @@ -83,6 +84,11 @@ struct generation_inputs const int mirostat = 0; const float mirostat_eta = 0.0f; const float mirostat_tau = 0.0f; + const float dry_multiplier = 0.0f; + const float dry_base = 0.0f; + const int dry_allowed_length = 0; + const int dry_penalty_last_n = 0; + const char * dry_sequence_breakers[dry_seq_break_max] = {}; const samplers sampler_order[KCPP_SAMPLER_MAX] = {}; const int sampler_len = 0; const bool allow_eos_token = false; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 0f78c088f..f5e4b3c3f 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -10,6 +10,7 @@ #include #include #include +#include <unordered_map> #include "model_adapter.h" #include "otherarch.h" #include "grammar-parser.h" @@ -106,6 +107,9 @@ static std::vector<std::string> stop_sequence; static std::vector<int>
special_stop_sequence; //for stop sequences that don't have a string representation static std::vector<std::string> banned_tokens; static std::vector<int> banned_token_ids; +static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token) +static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens +static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat; static std::vector top_picks; static int remaining_tokens = 0; static int stopper_unused_tokens = 0; @@ -305,6 +309,75 @@ static void print_tok_vec_str(std::vector<int> &vec) printf("\n%s", get_tok_vec_str(vec).c_str()); } +// Find tokens that completely contain `str`, either as a single token, or as a sequence of tokens. +// It's important to use a hash map for head tokens because some models have many of them. +// For example, the Llama 3 tokenizer has 6570 tokens containing the period ('.') character. +// Single tokens are allowed to extend past `str` at the front and back. This is to allow, for +// instance, the token '.\n' to be a head for both '.' and '\n'. However if a head token +// begins a multi-token sequence, the head can only extend past `str` at the beginning. The +// tail tokens are generated by tokenizing the remainder. +// If max_tail_len is >= 0, the maximum token length of a tail sequence is clamped to this value. +static void GetOverlappingTokenSequences(const std::string& str, std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& token_sequences, int max_tail_len = -1) { + for(int v=0;v<n_vocab;++v) { + std::string word = FileFormatTokenizeID(v, file_format, true); + if (word.find(str) != std::string::npos) { + auto its = token_sequences.equal_range(v); + bool empty = false; + for (auto it = its.first; it != its.second; ++it) { + if (it->second.empty()) { + empty = true; + break; + } + } + if (!empty) { + token_sequences.emplace(v, std::vector<gpt_vocab::id>()); + } + } else { + // Check whether a prefix of the string overlaps with a suffix of the token. + // Just do a naive O(N^2) search, since the worst case is limited by the + // maximum character length of a token in the vocabulary.
+ size_t word_len = word.size(), str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + // We matched to the end of the string. Since `str` is not contained in `word`, + // there must be trailing letters in `str`. + std::vector<int> tokenization; + TokenizeString(str.substr(i), tokenization, file_format, false); + if (max_tail_len >= 0 && tokenization.size() > max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization. + auto its = token_sequences.equal_range(v); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) + { + token_sequences.emplace(v, tokenization); + } + } + } + } + } +} llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng) { @@ -424,6 +497,208 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) candidates->size = last_idx; } +void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) { + if (penalty_multiplier == 0.0f || penalty_base == 0.0f) { + return; + } + if (penalty_range <= 0) { + penalty_range = n_ctx; + } + auto last_n_repeat = std::min(std::min((int)current_context_tokens.size(), penalty_range), n_ctx); + if (last_n_repeat <= allowed_length) { + return; + } + const llama_token * last_tokens = current_context_tokens.data() + current_context_tokens.size() - last_n_repeat; + + dry_repeat_count.assign(last_n_repeat, 0); + dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length.
+ // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (size_t i = 0; i < last_n_repeat; ++i) { + size_t ix = last_n_repeat - 1 - i; + auto its = restart_sequences.equal_range(last_tokens[ix]); + if (its.first == restart_sequences.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1.
+ // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= i) { + bool match = true; + for (size_t offset = 0; offset < seq_len; ++offset) { + // The +1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != last_tokens[ix + 1 + offset]) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = (int)i - longest_match; + break; + } + } + if (rep_limit <= allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). 
These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && last_tokens[last - n] == last_tokens[last - (n+k)]) { + ++n; + } + dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k+n-1; + } + } else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (dry_repeat_count[last - p] < right_part_len) { + int n = std::min(dry_repeat_count[last - p], rep_limit); + dry_repeat_count[last - k] = n; + } else { + int i = rt + 1; + while (i < last_n_repeat && last_tokens[last - i] == last_tokens[last - (i - k)]) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + dry_repeat_count[last - k] = n; + + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (size_t i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = dry_repeat_count[i]; + if (repeat_len >= allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. 
+ gpt_vocab::id token = last_tokens[i + 1]; + // Track the maximum sequence ending in this token. + const auto& it = dry_max_token_repeat.find(token); + if (it == dry_max_token_repeat.end() || it->second < repeat_len) { + dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (penalty_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(penalty_base); + } + + if (debugmode==1 && !dry_max_token_repeat.empty()) { + printf("DRY penalties ["); + } + size_t count = 0; + for (const auto& kvp: dry_max_token_repeat) { + gpt_vocab::id token = kvp.first; + int repeat_exp = kvp.second - allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = penalty_multiplier * pow(penalty_base, repeat_exp); + if (debugmode==1) + { + std::string tokenizedstr = FileFormatTokenizeID(token, file_format); + ::utreplace(tokenizedstr, "\n", "\\n"); + printf("%s(%s %.02f)", count == 0 ?
"" : " ", RemoveBell(tokenizedstr).c_str(), penalty); + } + candidates->data[token].logit -= penalty; + ++count; + } + if (debugmode==1 && !dry_max_token_repeat.empty()) { + printf("]\n"); + } +} + void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, llama_token_data_array * candidates_p) { auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); @@ -539,7 +814,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar } int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng, -int mirostat, float mirostat_tau, float mirostat_eta, const std::vector & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor) +int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, const std::vector & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor) { int id = 0; std::vector candidates; @@ -615,6 +890,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vectormirostat = inputs.mirostat; kcpp_params->mirostat_eta = inputs.mirostat_eta; kcpp_params->mirostat_tau = inputs.mirostat_tau; + kcpp_params->dry_multiplier = inputs.dry_multiplier; + kcpp_params->dry_base = inputs.dry_base; + kcpp_params->dry_allowed_length = inputs.dry_allowed_length; + kcpp_params->dry_penalty_last_n = inputs.dry_penalty_last_n; kcpp_params->dynatemp_range = inputs.dynatemp_range; kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent; kcpp_params->n_ctx = inputs.max_context_length; kcpp_params->smoothing_factor = inputs.smoothing_factor; + // Parse dry sequence breakers / restart 
sequences + kcpp_params->dry_sequence_breakers.clear(); + for(int x=0;x<dry_seq_break_max;++x) { + std::string word = inputs.dry_sequence_breakers[x]; + if(word!="") { + kcpp_params->dry_sequence_breakers.push_back(word); + } + } + dry_sequence_breakers.clear(); + if(kcpp_params->dry_sequence_breakers.size()>0) { + // Restrict the maximum length of sequences used as sequence breakers. There are + // very few use cases for a long sequence breaker, and limiting the max length + // prevents a potential denial of service attack in which long repetitive sequence + // breakers could result in slow DRY sampling with a suitably crafted context. + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + if(debugmode==1) { + printf("\nProcessing %zu dry break strings...",kcpp_params->dry_sequence_breakers.size()); + } + for (auto sequence_break: kcpp_params->dry_sequence_breakers) { + if (sequence_break.size() > MAX_CHAR_LEN) { + sequence_break.resize(MAX_CHAR_LEN); + } + GetOverlappingTokenSequences(sequence_break, dry_sequence_breakers, MAX_SEQ_LEN); + } + if(debugmode==1) { + int trivial = 0, non_trivial = 0; + for (const auto& seq: dry_sequence_breakers) { + if (seq.second.empty()) { + ++trivial; + } else { + ++non_trivial; + } + } + printf("\nFound a total of %zu restart heads, %d trivial, %d non-trivial.\n", dry_sequence_breakers.size(), trivial, non_trivial); + } + } + bool stream_sse = inputs.stream_sse; bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1; @@ -2303,7 +2622,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_params->rep_pen_slope, presence_penalty, top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng, - kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor); + kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, + kcpp_params->dry_multiplier, kcpp_params->dry_base, + kcpp_params->dry_allowed_length,
kcpp_params->dry_penalty_last_n, + sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor); if (grammar != nullptr) { grammar_accept_token(file_format, n_vocab, grammar, id); diff --git a/koboldcpp.py b/koboldcpp.py index 13369bbfc..5a8ddde29 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -21,6 +21,7 @@ stop_token_max = 16 ban_token_max = 16 tensor_split_max = 16 logit_bias_max = 16 +dry_seq_break_max = 16 images_max = 4 bias_min_value = -100.0 bias_max_value = 100.0 @@ -84,6 +85,11 @@ class generation_inputs(ctypes.Structure): ("mirostat", ctypes.c_int), ("mirostat_tau", ctypes.c_float), ("mirostat_eta", ctypes.c_float), + ("dry_multiplier", ctypes.c_float), + ("dry_base", ctypes.c_float), + ("dry_allowed_length", ctypes.c_int), + ("dry_penalty_last_n", ctypes.c_int), + ("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max), ("sampler_order", ctypes.c_int * sampler_order_max), ("sampler_len", ctypes.c_int), ("allow_eos_token", ctypes.c_bool), @@ -485,7 +491,7 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False): +def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, 
dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=['\n', ':', '"', '*'], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False): global maxctx, args, currentusergenkey, totalgens, pendingabortkey inputs = generation_inputs() inputs.prompt = prompt.encode("UTF-8") @@ -533,6 +539,24 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512 inputs.mirostat_eta = mirostat_eta else: inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0 + inputs.dry_multiplier = dry_multiplier + inputs.dry_base = dry_base + inputs.dry_allowed_length = dry_allowed_length + inputs.dry_penalty_last_n = dry_penalty_last_n + # Handle dry_sequence_breakers being passed as a json-encoded array of + # strings, rather than as an array of strings itself. This is to support + # SillyTavern, which passes sequence breakers to Oobabooga that way. + if isinstance(dry_sequence_breakers, str): + try: + dry_sequence_breakers = json.loads(dry_sequence_breakers) + except ValueError as e: + print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. 
Could not parse '{dry_sequence_breakers}': " + str(e)) + dry_sequence_breakers = [] + for n in range(dry_seq_break_max): + if n < len(dry_sequence_breakers): + inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8") + else: + inputs.dry_sequence_breakers[n] = "".encode("UTF-8") if sampler_order and 0 < len(sampler_order) <= sampler_order_max: try: for i, sampler in enumerate(sampler_order): @@ -967,6 +991,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): mirostat=genparams.get('mirostat', 0), mirostat_tau=genparams.get('mirostat_tau', 5.0), mirostat_eta=genparams.get('mirostat_eta', 0.1), + dry_multiplier=genparams.get('dry_multiplier', 0.0), + dry_base=genparams.get('dry_base', 1.75), + dry_allowed_length=genparams.get('dry_allowed_length', 2), + dry_penalty_last_n=genparams.get('dry_penalty_last_n', 0), + dry_sequence_breakers=genparams.get('dry_sequence_breakers', []), sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), seed=tryparseint(genparams.get('sampler_seed', -1)), stop_sequence=genparams.get('stop_sequence', []),