mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
fixed DRY
This commit is contained in:
parent
2cf6d16c40
commit
cd69ab218e
3 changed files with 10 additions and 4 deletions
|
@ -505,7 +505,7 @@ void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float pe
|
||||||
if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) {
|
if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (penalty_range <= 0) {
|
if (penalty_range <= 0 || penalty_range>n_ctx) {
|
||||||
penalty_range = n_ctx;
|
penalty_range = n_ctx;
|
||||||
}
|
}
|
||||||
auto last_n_repeat = std::min(std::min((int)current_context_tokens.size(), penalty_range), n_ctx);
|
auto last_n_repeat = std::min(std::min((int)current_context_tokens.size(), penalty_range), n_ctx);
|
||||||
|
@ -843,6 +843,8 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
|
||||||
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
|
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
|
||||||
|
|
||||||
//prefilter to top 5k tokens for improved speed
|
//prefilter to top 5k tokens for improved speed
|
||||||
llama_sample_top_k(nullptr, &candidates_p, 5000, 1);
|
llama_sample_top_k(nullptr, &candidates_p, 5000, 1);
|
||||||
|
|
||||||
|
@ -901,7 +903,6 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
|
||||||
break;
|
break;
|
||||||
case KCPP_SAMPLER_REP_PEN:
|
case KCPP_SAMPLER_REP_PEN:
|
||||||
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, &candidates_p);
|
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, &candidates_p);
|
||||||
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
|
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
|
||||||
|
@ -1956,6 +1957,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||||
last_stop_reason = stop_reason::OUT_OF_TOKENS;
|
last_stop_reason = stop_reason::OUT_OF_TOKENS;
|
||||||
stop_sequence.clear();
|
stop_sequence.clear();
|
||||||
special_stop_sequence.clear();
|
special_stop_sequence.clear();
|
||||||
|
dry_repeat_count.clear();
|
||||||
|
dry_sequence_breakers.clear();
|
||||||
|
dry_max_token_repeat.clear();
|
||||||
|
|
||||||
for(int x=0;x<stop_token_max;++x)
|
for(int x=0;x<stop_token_max;++x)
|
||||||
{
|
{
|
||||||
std::string stopper = inputs.stop_sequence[x];
|
std::string stopper = inputs.stop_sequence[x];
|
||||||
|
|
|
@ -12095,6 +12095,7 @@ Current version indicated by LITEVER below.
|
||||||
submit_payload.params.dry_multiplier = localsettings.dry_multiplier;
|
submit_payload.params.dry_multiplier = localsettings.dry_multiplier;
|
||||||
submit_payload.params.dry_base = localsettings.dry_base;
|
submit_payload.params.dry_base = localsettings.dry_base;
|
||||||
submit_payload.params.dry_allowed_length = localsettings.dry_allowed_length;
|
submit_payload.params.dry_allowed_length = localsettings.dry_allowed_length;
|
||||||
|
submit_payload.params.dry_penalty_last_n = localsettings.rep_pen_range;
|
||||||
submit_payload.params.dry_sequence_breakers = JSON.parse(JSON.stringify(localsettings.dry_sequence_breakers));
|
submit_payload.params.dry_sequence_breakers = JSON.parse(JSON.stringify(localsettings.dry_sequence_breakers));
|
||||||
}
|
}
|
||||||
//presence pen and logit bias for OAI and newer kcpp
|
//presence pen and logit bias for OAI and newer kcpp
|
||||||
|
|
|
@ -885,7 +885,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
|
||||||
typical_p = genparams.get('typical', 1.0)
|
typical_p = genparams.get('typical', 1.0)
|
||||||
tfs = genparams.get('tfs', 1.0)
|
tfs = genparams.get('tfs', 1.0)
|
||||||
rep_pen = genparams.get('rep_pen', 1.0)
|
rep_pen = genparams.get('rep_pen', 1.0)
|
||||||
rep_pen_range = genparams.get('rep_pen_range', 256)
|
rep_pen_range = genparams.get('rep_pen_range', 320)
|
||||||
rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
|
rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
|
||||||
presence_penalty = genparams.get('presence_penalty', 0.0)
|
presence_penalty = genparams.get('presence_penalty', 0.0)
|
||||||
mirostat = genparams.get('mirostat', 0)
|
mirostat = genparams.get('mirostat', 0)
|
||||||
|
@ -894,7 +894,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
|
||||||
dry_multiplier = genparams.get('dry_multiplier', 0.0)
|
dry_multiplier = genparams.get('dry_multiplier', 0.0)
|
||||||
dry_base = genparams.get('dry_base', 1.75)
|
dry_base = genparams.get('dry_base', 1.75)
|
||||||
dry_allowed_length = genparams.get('dry_allowed_length', 2)
|
dry_allowed_length = genparams.get('dry_allowed_length', 2)
|
||||||
dry_penalty_last_n = genparams.get('dry_penalty_last_n', 0)
|
dry_penalty_last_n = genparams.get('dry_penalty_last_n', 320)
|
||||||
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
|
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
|
||||||
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
|
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
|
||||||
seed = tryparseint(genparams.get('sampler_seed', -1))
|
seed = tryparseint(genparams.get('sampler_seed', -1))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue