rename power law sampler to adaptive p

This commit is contained in:
Concedo 2025-12-27 17:50:58 +08:00
parent 445aad5e00
commit 6548645aaa
5 changed files with 38 additions and 38 deletions

File diff suppressed because one or more lines are too long

View file

@ -123,7 +123,7 @@ struct generation_inputs
const float dynatemp_exponent = 1.0f;
const float smoothing_factor = 0.0f;
const float smoothing_curve = 1.0f;
const float power_law_target = -1.0f;
const float adaptive_target = -1.0f;
const float dry_multiplier = 0.0f;
const float dry_base = 0.0f;
const int dry_allowed_length = 0;

View file

@ -125,8 +125,8 @@ static std::vector<gpt_vocab::id> current_context_tokens;
static size_t mem_per_token = 0;
static std::vector<float> logits;
static std::vector<int> smartcontext;
static float power_law_weighted_sum = 0; //power law sampling state vars
static float power_law_total_weight = 0;
static float adaptive_p_weighted_sum = 0; //adaptive p sampling state vars
static float adaptive_p_total_weight = 0;
static std::vector<std::string> stop_sequence;
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
static std::vector<std::string> banned_tokens;
@ -1267,7 +1267,7 @@ void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float pe
}
}
void sample_power_law(
void sample_adaptive_p(
float target, // desired average probability (0..1), <=0 disables
float & weighted_sum, // persistent EMA state
float & total_weight, // persistent EMA state
@ -1290,7 +1290,7 @@ llama_token_data_array * cur_p)
// compute the adapted target probability for the current sampling step
float computed_target = std::clamp(total_weight == 0.0f ? target : 2.0f * target - (weighted_sum / total_weight),0.0f, 1.0f);
// power law transform
// adaptive p transform
const float k = 4.0f; // controls sharpness
for (size_t i = 0; i < cur_p->size; ++i) {
float dist = (cur_p->data[i].p - computed_target) * inv_width;
@ -1301,18 +1301,18 @@ llama_token_data_array * cur_p)
cur_p->sorted = false;
sample_softmax(cur_p);
//update EMA history AFTER sampling, update_power_law_history(original_prob[idx])
//update EMA history AFTER sampling: the caller invokes adaptive_p_update_history() with the selected token's original (pre-transform) probability
}
inline void power_law_update_history(float selected_token_prob, float & weighted_sum, float & total_weight) {
inline void adaptive_p_update_history(float selected_token_prob, float & weighted_sum, float & total_weight) {
// decay controls how quickly history influence fades (0.0 to 0.99)
// lower values = faster adaptation, more reactive to recent tokens
// higher values = slower adaptation, more stable over time
// effective history length ≈ 1/(1-decay) tokens
// example: decay=0.5 --> ~2 tokens; decay=0.9 --> ~10 tokens; decay=0.95 --> ~20 tokens
// keep <= 0.99 to prevent unbounded accumulation
const float power_law_decay = 0.90f;
weighted_sum = selected_token_prob + power_law_decay * weighted_sum;
total_weight = 1.0f + power_law_decay * total_weight;
const float adaptive_p_decay = 0.90f;
weighted_sum = selected_token_prob + adaptive_p_decay * weighted_sum;
total_weight = 1.0f + adaptive_p_decay * total_weight;
}
@ -1741,7 +1741,7 @@ void sample_guidance(struct llama_context * ctx, struct llama_context * guidance
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor, float smoothing_curve, float power_law_target)
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor, float smoothing_curve, float adaptive_target)
{
// printf("SampleLogits called with: n_ctx=%d, n_vocab=%d, rep_pen_range=%d, rep_pen=%f, rep_pen_slope=%f, presence_penalty=%f, top_k=%f, top_a=%f, top_p=%f, min_p=%f, typical_p=%f, tfs=%f, nsigma=%f, temp=%f, mirostat=%d, mirostat_tau=%f, mirostat_eta=%f, dry_multiplier=%f, dry_base=%f, dry_allowed_length=%d, dry_penalty_last_n=%d, xtc_threshold=%f, xtc_probability=%f, sampler_order_size=%zu, dynatemp_range=%f, dynatemp_exponent=%f, smoothing_factor=%f\n",
// n_ctx, n_vocab, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, top_k, top_a, top_p, min_p, typical_p, tfs, nsigma, temp, mirostat, mirostat_tau, mirostat_eta, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, xtc_threshold, xtc_probability, sampler_order.size(), dynatemp_range, dynatemp_exponent, smoothing_factor);
@ -1847,8 +1847,8 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
}
//xtc always last
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng);
//power law must be last, it messes up all probs
sample_power_law(power_law_target, power_law_weighted_sum, power_law_total_weight, &candidates_p);
//adaptive p must be last, it messes up all probs
sample_adaptive_p(adaptive_target, adaptive_p_weighted_sum, adaptive_p_total_weight, &candidates_p);
id = sample_token(&candidates_p, rng);
}
@ -3444,8 +3444,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
power_law_weighted_sum = 0;
power_law_total_weight = 0;
adaptive_p_weighted_sum = 0;
adaptive_p_total_weight = 0;
//handle custom token bans and antislop phrase banning
banned_phrases.clear();
@ -3655,7 +3655,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->n_ctx = inputs.max_context_length;
kcpp_data->smoothing_factor = inputs.smoothing_factor;
kcpp_data->smoothing_curve = inputs.smoothing_curve;
kcpp_data->power_law_target = inputs.power_law_target;
kcpp_data->adaptive_target = inputs.adaptive_target;
// Parse dry sequence breakers / restart sequences
kcpp_data->dry_sequence_breakers.clear();
@ -4484,7 +4484,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
const float dynatemp_exponent = kcpp_data->dynatemp_exponent;
const float smoothing_factor = kcpp_data->smoothing_factor;
const float smoothing_curve = kcpp_data->smoothing_curve;
const float power_law_target = kcpp_data->power_law_target;
const float adaptive_target = kcpp_data->adaptive_target;
if (!startedsampling)
{
@ -4562,9 +4562,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
lowestLogit = LowestLogit(logits);
}
//if power law sampling is used, we need to cache the original probabilities
//if adaptive p sampling is used, we need to cache the original probabilities
std::vector<llama_token_data> original_candidates;
if(power_law_target > 0.0f)
if(adaptive_target > 0.0f)
{
original_candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@ -4618,11 +4618,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
kcpp_data->dry_multiplier, kcpp_data->dry_base,
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor, smoothing_curve, power_law_target);
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor, smoothing_curve, adaptive_target);
if (power_law_target > 0.0f) {
if (adaptive_target > 0.0f) {
float original_prob = original_candidates[id].p;
power_law_update_history(original_prob, power_law_weighted_sum, power_law_total_weight);
adaptive_p_update_history(original_prob, adaptive_p_weighted_sum, adaptive_p_total_weight);
}
if(draft_used)

View file

@ -265,7 +265,7 @@ class generation_inputs(ctypes.Structure):
("dynatemp_exponent", ctypes.c_float),
("smoothing_factor", ctypes.c_float),
("smoothing_curve", ctypes.c_float),
("power_law_target", ctypes.c_float),
("adaptive_target", ctypes.c_float),
("dry_multiplier", ctypes.c_float),
("dry_base", ctypes.c_float),
("dry_allowed_length", ctypes.c_int),
@ -1604,8 +1604,8 @@ def generate(genparams, stream_flag=False):
dynatemp_exponent = tryparsefloat(genparams.get('dynatemp_exponent', 1.0),1.0)
smoothing_factor = tryparsefloat(genparams.get('smoothing_factor', 0.0),0.0)
smoothing_curve = tryparsefloat(genparams.get('smoothing_curve', 1.0),1.0)
power_law_target = tryparsefloat(genparams.get('power_law_target', -1.0),-1.0)
if power_law_target>0 and min_p<=0 and top_p>=1.0: #power law sampler requires a truncation sampler first, force a tiny min-p
adaptive_target = tryparsefloat(genparams.get('adaptive_target', -1.0),-1.0)
if adaptive_target>0 and min_p<=0 and top_p>=1.0: #adaptive p sampler requires a truncation sampler first, force a tiny min-p
min_p = 0.01
logit_biases = genparams.get('logit_bias', {})
render_special = genparams.get('render_special', False)
@ -1670,7 +1670,7 @@ def generate(genparams, stream_flag=False):
inputs.dynatemp_exponent = dynatemp_exponent
inputs.smoothing_factor = smoothing_factor
inputs.smoothing_curve = smoothing_curve
inputs.power_law_target = power_law_target
inputs.adaptive_target = adaptive_target
inputs.grammar = grammar.encode("UTF-8")
inputs.grammar_retain_state = grammar_retain_state
inputs.allow_eos_token = not ban_eos_token

View file

@ -50,7 +50,7 @@ struct kcpp_params {
float xtc_probability = 0;
float dynatemp_range = 0.0f; // enables DynaTemp if neq 0. dynatemp_min = temperature - dt_range, dynatemp_max = temperature + dt_range
float dynatemp_exponent = 1.0f;
float power_law_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
float adaptive_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
std::string model_filename = ""; // model path
std::string prompt = "";