mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-23 04:19:08 +00:00
rename power law sampler to adaptive p
This commit is contained in:
parent
445aad5e00
commit
6548645aaa
5 changed files with 38 additions and 38 deletions
File diff suppressed because one or more lines are too long
2
expose.h
2
expose.h
|
|
@ -123,7 +123,7 @@ struct generation_inputs
|
|||
const float dynatemp_exponent = 1.0f;
|
||||
const float smoothing_factor = 0.0f;
|
||||
const float smoothing_curve = 1.0f;
|
||||
const float power_law_target = -1.0f;
|
||||
const float adaptive_target = -1.0f;
|
||||
const float dry_multiplier = 0.0f;
|
||||
const float dry_base = 0.0f;
|
||||
const int dry_allowed_length = 0;
|
||||
|
|
|
|||
|
|
@ -125,8 +125,8 @@ static std::vector<gpt_vocab::id> current_context_tokens;
|
|||
static size_t mem_per_token = 0;
|
||||
static std::vector<float> logits;
|
||||
static std::vector<int> smartcontext;
|
||||
static float power_law_weighted_sum = 0; //power law sampling state vars
|
||||
static float power_law_total_weight = 0;
|
||||
static float adaptive_p_weighted_sum = 0; //adaptive p sampling state vars
|
||||
static float adaptive_p_total_weight = 0;
|
||||
static std::vector<std::string> stop_sequence;
|
||||
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
|
||||
static std::vector<std::string> banned_tokens;
|
||||
|
|
@ -1267,7 +1267,7 @@ void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float pe
|
|||
}
|
||||
}
|
||||
|
||||
void sample_power_law(
|
||||
void sample_adaptive_p(
|
||||
float target, // desired average probability (0..1), <=0 disables
|
||||
float & weighted_sum, // persistent EMA state
|
||||
float & total_weight, // persistent EMA state
|
||||
|
|
@ -1290,7 +1290,7 @@ llama_token_data_array * cur_p)
|
|||
// compute the adapted target probability for the current sampling step
|
||||
float computed_target = std::clamp(total_weight == 0.0f ? target : 2.0f * target - (weighted_sum / total_weight),0.0f, 1.0f);
|
||||
|
||||
// power law transform
|
||||
// adaptive p transform
|
||||
const float k = 4.0f; // controls sharpness
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float dist = (cur_p->data[i].p - computed_target) * inv_width;
|
||||
|
|
@ -1301,18 +1301,18 @@ llama_token_data_array * cur_p)
|
|||
cur_p->sorted = false;
|
||||
sample_softmax(cur_p);
|
||||
|
||||
//update EMA history AFTER sampling, update_power_law_history(original_prob[idx])
|
||||
//update EMA history AFTER sampling, update_adaptive_p_history(original_prob[idx])
|
||||
}
|
||||
inline void power_law_update_history(float selected_token_prob, float & weighted_sum, float & total_weight) {
|
||||
inline void adaptive_p_update_history(float selected_token_prob, float & weighted_sum, float & total_weight) {
|
||||
// decay controls how quickly history influence fades (0.0 to 0.99)
|
||||
// lower values = faster adaptation, more reactive to recent tokens
|
||||
// higher values = slower adaptation, more stable over time
|
||||
// effective history length ≈ 1/(1-decay) tokens
|
||||
// example: decay=0.5 --> ~2 tokens; decay=0.9 --> ~10 tokens; decay=0.95 --> ~20 tokens
|
||||
// keep <= 0.99 to prevent unbounded accumulation
|
||||
const float power_law_decay = 0.90f;
|
||||
weighted_sum = selected_token_prob + power_law_decay * weighted_sum;
|
||||
total_weight = 1.0f + power_law_decay * total_weight;
|
||||
const float adaptive_p_decay = 0.90f;
|
||||
weighted_sum = selected_token_prob + adaptive_p_decay * weighted_sum;
|
||||
total_weight = 1.0f + adaptive_p_decay * total_weight;
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -1741,7 +1741,7 @@ void sample_guidance(struct llama_context * ctx, struct llama_context * guidance
|
|||
|
||||
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
|
||||
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
|
||||
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor, float smoothing_curve, float power_law_target)
|
||||
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor, float smoothing_curve, float adaptive_target)
|
||||
{
|
||||
// printf("SampleLogits called with: n_ctx=%d, n_vocab=%d, rep_pen_range=%d, rep_pen=%f, rep_pen_slope=%f, presence_penalty=%f, top_k=%f, top_a=%f, top_p=%f, min_p=%f, typical_p=%f, tfs=%f, nsigma=%f, temp=%f, mirostat=%d, mirostat_tau=%f, mirostat_eta=%f, dry_multiplier=%f, dry_base=%f, dry_allowed_length=%d, dry_penalty_last_n=%d, xtc_threshold=%f, xtc_probability=%f, sampler_order_size=%zu, dynatemp_range=%f, dynatemp_exponent=%f, smoothing_factor=%f\n",
|
||||
// n_ctx, n_vocab, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, top_k, top_a, top_p, min_p, typical_p, tfs, nsigma, temp, mirostat, mirostat_tau, mirostat_eta, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, xtc_threshold, xtc_probability, sampler_order.size(), dynatemp_range, dynatemp_exponent, smoothing_factor);
|
||||
|
|
@ -1847,8 +1847,8 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
|
|||
}
|
||||
//xtc always last
|
||||
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng);
|
||||
//power law must be last, it messes up all probs
|
||||
sample_power_law(power_law_target, power_law_weighted_sum, power_law_total_weight, &candidates_p);
|
||||
//adaptive p must be last, it messes up all probs
|
||||
sample_adaptive_p(adaptive_target, adaptive_p_weighted_sum, adaptive_p_total_weight, &candidates_p);
|
||||
id = sample_token(&candidates_p, rng);
|
||||
}
|
||||
|
||||
|
|
@ -3444,8 +3444,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
}
|
||||
}
|
||||
|
||||
power_law_weighted_sum = 0;
|
||||
power_law_total_weight = 0;
|
||||
adaptive_p_weighted_sum = 0;
|
||||
adaptive_p_total_weight = 0;
|
||||
|
||||
//handle custom token bans and antislop phrase banning
|
||||
banned_phrases.clear();
|
||||
|
|
@ -3655,7 +3655,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
kcpp_data->n_ctx = inputs.max_context_length;
|
||||
kcpp_data->smoothing_factor = inputs.smoothing_factor;
|
||||
kcpp_data->smoothing_curve = inputs.smoothing_curve;
|
||||
kcpp_data->power_law_target = inputs.power_law_target;
|
||||
kcpp_data->adaptive_target = inputs.adaptive_target;
|
||||
|
||||
// Parse dry sequence breakers / restart sequences
|
||||
kcpp_data->dry_sequence_breakers.clear();
|
||||
|
|
@ -4484,7 +4484,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
const float dynatemp_exponent = kcpp_data->dynatemp_exponent;
|
||||
const float smoothing_factor = kcpp_data->smoothing_factor;
|
||||
const float smoothing_curve = kcpp_data->smoothing_curve;
|
||||
const float power_law_target = kcpp_data->power_law_target;
|
||||
const float adaptive_target = kcpp_data->adaptive_target;
|
||||
|
||||
if (!startedsampling)
|
||||
{
|
||||
|
|
@ -4562,9 +4562,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
lowestLogit = LowestLogit(logits);
|
||||
}
|
||||
|
||||
//if power law sampling is used, we need to cache the original probabilities
|
||||
//if adaptive p sampling is used, we need to cache the original probabilities
|
||||
std::vector<llama_token_data> original_candidates;
|
||||
if(power_law_target > 0.0f)
|
||||
if(adaptive_target > 0.0f)
|
||||
{
|
||||
original_candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
|
|
@ -4618,11 +4618,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
|
||||
kcpp_data->dry_multiplier, kcpp_data->dry_base,
|
||||
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
|
||||
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor, smoothing_curve, power_law_target);
|
||||
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor, smoothing_curve, adaptive_target);
|
||||
|
||||
if (power_law_target > 0.0f) {
|
||||
if (adaptive_target > 0.0f) {
|
||||
float original_prob = original_candidates[id].p;
|
||||
power_law_update_history(original_prob, power_law_weighted_sum, power_law_total_weight);
|
||||
adaptive_p_update_history(original_prob, adaptive_p_weighted_sum, adaptive_p_total_weight);
|
||||
}
|
||||
|
||||
if(draft_used)
|
||||
|
|
|
|||
|
|
@ -265,7 +265,7 @@ class generation_inputs(ctypes.Structure):
|
|||
("dynatemp_exponent", ctypes.c_float),
|
||||
("smoothing_factor", ctypes.c_float),
|
||||
("smoothing_curve", ctypes.c_float),
|
||||
("power_law_target", ctypes.c_float),
|
||||
("adaptive_target", ctypes.c_float),
|
||||
("dry_multiplier", ctypes.c_float),
|
||||
("dry_base", ctypes.c_float),
|
||||
("dry_allowed_length", ctypes.c_int),
|
||||
|
|
@ -1604,8 +1604,8 @@ def generate(genparams, stream_flag=False):
|
|||
dynatemp_exponent = tryparsefloat(genparams.get('dynatemp_exponent', 1.0),1.0)
|
||||
smoothing_factor = tryparsefloat(genparams.get('smoothing_factor', 0.0),0.0)
|
||||
smoothing_curve = tryparsefloat(genparams.get('smoothing_curve', 1.0),1.0)
|
||||
power_law_target = tryparsefloat(genparams.get('power_law_target', -1.0),-1.0)
|
||||
if power_law_target>0 and min_p<=0 and top_p>=1.0: #power law sampler requires a truncation sampler first, force a tiny min-p
|
||||
adaptive_target = tryparsefloat(genparams.get('adaptive_target', -1.0),-1.0)
|
||||
if adaptive_target>0 and min_p<=0 and top_p>=1.0: #adaptive p sampler requires a truncation sampler first, force a tiny min-p
|
||||
min_p = 0.01
|
||||
logit_biases = genparams.get('logit_bias', {})
|
||||
render_special = genparams.get('render_special', False)
|
||||
|
|
@ -1670,7 +1670,7 @@ def generate(genparams, stream_flag=False):
|
|||
inputs.dynatemp_exponent = dynatemp_exponent
|
||||
inputs.smoothing_factor = smoothing_factor
|
||||
inputs.smoothing_curve = smoothing_curve
|
||||
inputs.power_law_target = power_law_target
|
||||
inputs.adaptive_target = adaptive_target
|
||||
inputs.grammar = grammar.encode("UTF-8")
|
||||
inputs.grammar_retain_state = grammar_retain_state
|
||||
inputs.allow_eos_token = not ban_eos_token
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ struct kcpp_params {
|
|||
float xtc_probability = 0;
|
||||
float dynatemp_range = 0.0f; // enables DynaTemp if neq 0. dynatemp_min = temperature - dt_range, dynatemp_max = temperature + dt_range
|
||||
float dynatemp_exponent = 1.0f;
|
||||
float power_law_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
|
||||
float adaptive_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
|
||||
|
||||
std::string model_filename = ""; // model path
|
||||
std::string prompt = "";
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue