mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Add the DRY dynamic N-gram anti-repetition sampler (#982)
* Add the DRY dynamic N-gram anti-repetition sampler The DRY (Do not Repeat Yourself) sampler is a dynamic N-gram repetition penalty that negatively scores tokens that would extend sequences that already appear in the context. See this discussion for a motivation and explanation of the sampler: https://github.com/oobabooga/text-generation-webui/pull/5677 This implementation of DRY mostly aligns with the obabooga version with a few modifications. It uses a more efficient linear scanning algorithm to identify repetitions. It also supports multi-token sequence breakers. As a limitation, this implementation reuses the rep pen range parameter, rather than introducing a new range just for the DRY sampler. There is a separate change to lite.koboldai.net that exposes the DRY sampler parameters to KoboldAI Lite, so none of the embed files have been changed as part of this commit. * Update default DRY parameters to match lite * Improve DRY token debug logging * Replace `and` with `&&` to fix MSVC compile error Little known fact: The C++98 standard defines `and` as an alternative token for the `&&` operator (along with a bunch of other digraphs). MSVC does not allow these without using the /Za option or including the <iso646.h> header. Change to the more standard operator to make this code more portable. * Fix MSVC compile error because log is not constexpr Replace the compile-time computation with a floating-point approximation of log(std::numeric_limits<float>::max()). * Remove unused llama sampler variables and clean up sequence breakers. * Remove KCPP_SAMPLER_DRY as a separate enum entry The DRY sampler is effectively a repetition penalty and there are very few reasons to apply it at a different place in sampler order than the standard single-token penalty. There are also multiple projects that have dependencies on the existing sampler IDs, including KoboldAI, KoboldAI Lite, and Silly Tavern. In order to minimize the impact of the dependencies of adding the DRY sampler to koboldcpp, it makes the most sense to not add a new ID for now, and instead to piggyback on KCPP_SAMPLER_REP_PEN. In the future if we find a use case for splitting the application of rep pen and DRY we can introduce a new enum entry then. * Add the dry_penalty_last_n to independently control DRY penalty range This parameter follows the oobabooga semantics: it's optional, with a default value of zero. Zero means that DRY should sample the entire context. Otherwise, it's the number of tokens from the end of the context that are scanned for repetitions. * Limit sequence breaker lengths in tokens and characters The core DRY sampler algorithm is linear in the context length, but there are several parts of the sampler related to multi-token sequence breakers that are potentially quadratic. Without any restrictions, a suitably crafted context and sequence breaker could result in a denial-of-service attack on a server running koboldcpp. This change limits the maximum number of characters and the maximum token length of a sequence breaker in order to limit the maximum overhead associated with the sampler. This change also improves some comments, adding more detail and changing the wording to increase clarity.
This commit is contained in:
parent
add0a88111
commit
264575426e
4 changed files with 365 additions and 3 deletions
31
koboldcpp.py
31
koboldcpp.py
|
@ -21,6 +21,7 @@ stop_token_max = 16
|
|||
ban_token_max = 16
|
||||
tensor_split_max = 16
|
||||
logit_bias_max = 16
|
||||
dry_seq_break_max = 16
|
||||
images_max = 4
|
||||
bias_min_value = -100.0
|
||||
bias_max_value = 100.0
|
||||
|
@ -84,6 +85,11 @@ class generation_inputs(ctypes.Structure):
|
|||
("mirostat", ctypes.c_int),
|
||||
("mirostat_tau", ctypes.c_float),
|
||||
("mirostat_eta", ctypes.c_float),
|
||||
("dry_multiplier", ctypes.c_float),
|
||||
("dry_base", ctypes.c_float),
|
||||
("dry_allowed_length", ctypes.c_int),
|
||||
("dry_penalty_last_n", ctypes.c_int),
|
||||
("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
|
||||
("sampler_order", ctypes.c_int * sampler_order_max),
|
||||
("sampler_len", ctypes.c_int),
|
||||
("allow_eos_token", ctypes.c_bool),
|
||||
|
@ -485,7 +491,7 @@ def load_model(model_filename):
|
|||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
|
||||
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=['\n', ':', '"', '*'], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
|
||||
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
|
||||
inputs = generation_inputs()
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
|
@ -533,6 +539,24 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
|
|||
inputs.mirostat_eta = mirostat_eta
|
||||
else:
|
||||
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
|
||||
inputs.dry_multiplier = dry_multiplier
|
||||
inputs.dry_base = dry_base
|
||||
inputs.dry_allowed_length = dry_allowed_length
|
||||
inputs.dry_penalty_last_n = dry_penalty_last_n
|
||||
# Handle dry_sequence_breakers being passed as a json-encoded array of
|
||||
# strings, rather than as an array of strings itself. This is to support
|
||||
# SillyTavern, which passes sequence breakers to Oobabooga that way.
|
||||
if isinstance(dry_sequence_breakers, str):
|
||||
try:
|
||||
dry_sequence_breakers = json.loads(dry_sequence_breakers)
|
||||
except ValueError as e:
|
||||
print(f"ERROR: dry_sequence_breakers must be an array of strings or a json encoded array of strings. Could not parse '{dry_sequence_breakers}': " + str(e))
|
||||
dry_sequence_breakers = []
|
||||
for n in range(dry_seq_break_max):
|
||||
if n < len(dry_sequence_breakers):
|
||||
inputs.dry_sequence_breakers[n] = dry_sequence_breakers[n].encode("UTF-8")
|
||||
else:
|
||||
inputs.dry_sequence_breakers[n] = "".encode("UTF-8")
|
||||
if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
|
||||
try:
|
||||
for i, sampler in enumerate(sampler_order):
|
||||
|
@ -967,6 +991,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
mirostat=genparams.get('mirostat', 0),
|
||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||
dry_multiplier=genparams.get('dry_multiplier', 0.0),
|
||||
dry_base=genparams.get('dry_base', 1.75),
|
||||
dry_allowed_length=genparams.get('dry_allowed_length', 2),
|
||||
dry_penalty_last_n=genparams.get('dry_penalty_last_n', 0),
|
||||
dry_sequence_breakers=genparams.get('dry_sequence_breakers', []),
|
||||
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
|
||||
seed=tryparseint(genparams.get('sampler_seed', -1)),
|
||||
stop_sequence=genparams.get('stop_sequence', []),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue