add top n sigma sampler from llama.cpp (#1384)

* Add N Sigma Sampler

* update nsigma sampler chain

* xtc position fix

* remove stray newline

---------

Co-authored-by: CasualAutopsy <casual_autopsy@outlook.com>
This commit is contained in:
EquinoxPsychosis 2025-02-21 01:31:42 -05:00 committed by GitHub
parent 5f74ee3c3b
commit 2740af3660
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 61 additions and 4 deletions

View file

@ -82,6 +82,7 @@ struct generation_inputs
const float min_p = 0.0f;
const float typical_p = 0;
const float tfs = 0;
const float nsigma = -1.0f;
const float rep_pen = 0;
const int rep_pen_range = 0;
const float rep_pen_slope = 1.0f;

View file

@ -1428,6 +1428,35 @@ void sampler_typical(llama_token_data_array * cur_p, float p, size_t min_keep) {
cur_p->sorted = false;
}
void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
// find max logit and calculate mean
float nsigmax = cur_p->data[0].logit;
float logits_sum = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > nsigmax) {
nsigmax = cur_p->data[i].logit;
}
logits_sum += cur_p->data[i].logit;
}
float nsigmean = logits_sum / cur_p->size;
// calculate standard deviation
float nsigacc = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
nsigacc += pow(cur_p->data[i].logit - nsigmean, 2);
}
float nsigstd = sqrt(nsigacc / cur_p->size);
//apply mask
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit < nsigmax - (nsigma * nsigstd)) {
cur_p->data[i].logit -= 999.0f;
}
}
sample_softmax(cur_p);
}
void sample_entropy(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor) {
// no need to do anything if there is only one (or zero) candidates
if (cur_p->size <= 1) {
@ -1561,7 +1590,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
{
@ -1584,8 +1613,10 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}
if (nsigma <= 0.0f){
//dry always first as logits cannot be resorted
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
}
//prefilter to top 3k tokens for improved speed
sample_top_k(&candidates_p, 3000);
@ -1605,6 +1636,25 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
}
}
else if (nsigma > 0.0f)
{
sample_top_k(&candidates_p, top_k);
if (dynatemp_range > 0) {
float dynatemp_min = temp - dynatemp_range;
float dynatemp_max = temp + dynatemp_range;
//do not allow negative values
dynatemp_min = dynatemp_min < 0 ? 0 : dynatemp_min;
dynatemp_max = dynatemp_max < 0 ? 0 : dynatemp_max;
dynatemp_exponent = dynatemp_exponent < 0 ? 0 : dynatemp_exponent;
sample_entropy(&candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor);
} else {
sample_temperature(&candidates_p, temp, smoothing_factor);
}
sample_top_n_sigma(&candidates_p, nsigma);
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng);
id = sample_token(&candidates_p, rng);
}
else
{
for (int i = 0; i < sampler_order.size(); i++)
@ -2999,6 +3049,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->min_p = inputs.min_p;
kcpp_data->typical_p = inputs.typical_p;
kcpp_data->tfs_z = inputs.tfs;
kcpp_data->nsigma = inputs.nsigma;
kcpp_data->temp = inputs.temperature;
kcpp_data->repeat_last_n = inputs.rep_pen_range;
kcpp_data->rep_pen_slope = inputs.rep_pen_slope;
@ -3529,6 +3580,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
const float presence_penalty = kcpp_data->presence_penalty;
const float typical_p = kcpp_data->typical_p;
const float tfs_z = kcpp_data->tfs_z;
const float nsigma = kcpp_data->nsigma;
const float dynatemp_range = kcpp_data->dynatemp_range;
const float dynatemp_exponent = kcpp_data->dynatemp_exponent;
const float smoothing_factor = kcpp_data->smoothing_factor;
@ -3624,7 +3676,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope, presence_penalty,
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
top_k, top_a, top_p, min_p, typical_p, tfs_z, nsigma, temp, rng,
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
kcpp_data->dry_multiplier, kcpp_data->dry_base,
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,

View file

@ -194,6 +194,7 @@ class generation_inputs(ctypes.Structure):
("min_p", ctypes.c_float),
("typical_p", ctypes.c_float),
("tfs", ctypes.c_float),
("nsigma", ctypes.c_float),
("rep_pen", ctypes.c_float),
("rep_pen_range", ctypes.c_int),
("rep_pen_slope", ctypes.c_float),
@ -1116,6 +1117,7 @@ def generate(genparams, stream_flag=False):
min_p = float(genparams.get('min_p', 0.0))
typical_p = float(genparams.get('typical', 1.0))
tfs = float(genparams.get('tfs', 1.0))
nsigma = float(genparams.get('nsigma', -1.0))
rep_pen = float(genparams.get('rep_pen', 1.0))
rep_pen_range = int(genparams.get('rep_pen_range', 320))
rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0))
@ -1182,6 +1184,7 @@ def generate(genparams, stream_flag=False):
inputs.min_p = min_p
inputs.typical_p = typical_p
inputs.tfs = tfs
inputs.nsigma = nsigma
inputs.rep_pen = rep_pen
inputs.rep_pen_range = rep_pen_range
inputs.rep_pen_slope = rep_pen_slope

View file

@ -29,6 +29,7 @@ struct kcpp_params {
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.0f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float nsigma = -1.00f; // -1.0 - disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
float smoothing_factor = 0.00f; // 0.00 = disabled