Add smoothing_curve sampling parameter alongside smoothing_factor (not yet tested)

This commit is contained in:
LostRuins Concedo 2025-11-17 23:07:35 +08:00
parent 3fe0e39b62
commit 281542aa0d
5 changed files with 203 additions and 70 deletions

File diff suppressed because one or more lines are too long

View file

@ -119,6 +119,7 @@ struct generation_inputs
const float dynatemp_range = 0.0f;
const float dynatemp_exponent = 1.0f;
const float smoothing_factor = 0.0f;
const float smoothing_curve = 1.0f;
const float dry_multiplier = 0.0f;
const float dry_base = 0.0f;
const int dry_allowed_length = 0;

View file

@ -1477,7 +1477,7 @@ void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
sample_softmax(cur_p);
}
void sample_entropy(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor) {
void sample_entropy(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor, float smoothing_curve) {
// no need to do anything if there is only one (or zero) candidates
if (cur_p->size <= 1) {
return;
@ -1525,19 +1525,20 @@ void sample_entropy(llama_token_data_array * cur_p, float min_temp, float max_te
// Only apply smoothing if smoothing_factor is > 0. Do not change base implementation otherwise.
if (smoothing_factor > 0 && cur_p->size > 1) {
sample_softmax(cur_p);
float h = cur_p->data[0].logit; // Find the maximum logit for h to be added after the transformation
// Apply quadratic transformation using the smoothing_factor
for (size_t i = 0; i < cur_p->size; ++i)
{
float h = cur_p->data[0].logit; // Find the maximum logit for h to be added after the transformation
// Apply the modified quadratic transformation using the smoothing_factor and smoothing_curve
for (size_t i = 0; i < cur_p->size; ++i) {
float logit_shifted = cur_p->data[i].logit - h;
cur_p->data[i].logit = -smoothing_factor * logit_shifted * logit_shifted + h;
float k = (3 - smoothing_curve) / 2;
float s = (smoothing_curve - 1) / 2;
cur_p->data[i].logit = -(k * smoothing_factor * logit_shifted * logit_shifted) + (s * smoothing_factor * logit_shifted * logit_shifted * logit_shifted) + h;
}
sample_softmax(cur_p);
}
}
void sample_temperature(llama_token_data_array * candidates_p, float temp, float smoothing_factor)
void sample_temperature(llama_token_data_array * candidates_p, float temp, float smoothing_factor, float smoothing_curve)
{
bool isgreedy = false;
if (temp <= 0)
@ -1555,11 +1556,12 @@ void sample_temperature(llama_token_data_array * candidates_p, float temp, float
if (smoothing_factor > 0 && candidates_p->size > 1) {
sample_softmax(candidates_p);
float h = candidates_p->data[0].logit; // Find the maximum logit for h to be added after the transformation
// Apply quadratic transformation using the smoothing_factor
for (size_t i = 0; i < candidates_p->size; ++i)
{
// Apply the modified quadratic transformation using the smoothing_factor and smoothing_curve
for (size_t i = 0; i < candidates_p->size; ++i) {
float logit_shifted = candidates_p->data[i].logit - h;
candidates_p->data[i].logit = -smoothing_factor * logit_shifted * logit_shifted + h;
float k = (3 - smoothing_curve) / 2;
float s = (smoothing_curve - 1) / 2;
candidates_p->data[i].logit = -(k * smoothing_factor * logit_shifted * logit_shifted) + (s * smoothing_factor * logit_shifted * logit_shifted * logit_shifted) + h;
}
sample_softmax(candidates_p);
}
@ -1645,7 +1647,7 @@ void sample_guidance(struct llama_context * ctx, struct llama_context * guidance
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor, float smoothing_curve)
{
// printf("SampleLogits called with: n_ctx=%d, n_vocab=%d, rep_pen_range=%d, rep_pen=%f, rep_pen_slope=%f, presence_penalty=%f, top_k=%f, top_a=%f, top_p=%f, min_p=%f, typical_p=%f, tfs=%f, nsigma=%f, temp=%f, mirostat=%d, mirostat_tau=%f, mirostat_eta=%f, dry_multiplier=%f, dry_base=%f, dry_allowed_length=%d, dry_penalty_last_n=%d, xtc_threshold=%f, xtc_probability=%f, sampler_order_size=%zu, dynatemp_range=%f, dynatemp_exponent=%f, smoothing_factor=%f\n",
// n_ctx, n_vocab, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, top_k, top_a, top_p, min_p, typical_p, tfs, nsigma, temp, mirostat, mirostat_tau, mirostat_eta, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, xtc_threshold, xtc_probability, sampler_order.size(), dynatemp_range, dynatemp_exponent, smoothing_factor);
@ -1689,7 +1691,7 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, &candidates_p);
sample_temperature(&candidates_p, temp, smoothing_factor);
sample_temperature(&candidates_p, temp, smoothing_factor, smoothing_curve);
if (mirostat == 1)
{
id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
@ -1730,11 +1732,11 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
dynatemp_min = dynatemp_min<0?0:dynatemp_min;
dynatemp_max = dynatemp_max<0?0:dynatemp_max;
dynatemp_exponent = dynatemp_exponent<0?0:dynatemp_exponent;
sample_entropy(&candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor);
sample_entropy(&candidates_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor, smoothing_curve);
}
else
{
sample_temperature(&candidates_p, temp, smoothing_factor);
sample_temperature(&candidates_p, temp, smoothing_factor, smoothing_curve);
}
if (nsigma > 0.0f)
{
@ -3441,6 +3443,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->dynatemp_exponent = inputs.dynatemp_exponent;
kcpp_data->n_ctx = inputs.max_context_length;
kcpp_data->smoothing_factor = inputs.smoothing_factor;
kcpp_data->smoothing_curve = inputs.smoothing_curve;
// Parse dry sequence breakers / restart sequences
kcpp_data->dry_sequence_breakers.clear();
@ -4114,6 +4117,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
const float dynatemp_range = kcpp_data->dynatemp_range;
const float dynatemp_exponent = kcpp_data->dynatemp_exponent;
const float smoothing_factor = kcpp_data->smoothing_factor;
const float smoothing_curve = kcpp_data->smoothing_curve;
if (!startedsampling)
{
@ -4220,7 +4224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
kcpp_data->dry_multiplier, kcpp_data->dry_base,
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor, smoothing_curve);
if(draft_used)
{

View file

@ -260,6 +260,7 @@ class generation_inputs(ctypes.Structure):
("dynatemp_range", ctypes.c_float),
("dynatemp_exponent", ctypes.c_float),
("smoothing_factor", ctypes.c_float),
("smoothing_curve", ctypes.c_float),
("dry_multiplier", ctypes.c_float),
("dry_base", ctypes.c_float),
("dry_allowed_length", ctypes.c_int),
@ -1557,6 +1558,7 @@ def generate(genparams, stream_flag=False):
dynatemp_range = tryparsefloat(genparams.get('dynatemp_range', 0.0),0.0)
dynatemp_exponent = tryparsefloat(genparams.get('dynatemp_exponent', 1.0),1.0)
smoothing_factor = tryparsefloat(genparams.get('smoothing_factor', 0.0),0.0)
smoothing_curve = tryparsefloat(genparams.get('smoothing_curve', 1.0),1.0)
logit_biases = genparams.get('logit_bias', {})
render_special = genparams.get('render_special', False)
banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
@ -1619,6 +1621,7 @@ def generate(genparams, stream_flag=False):
inputs.dynatemp_range = dynatemp_range
inputs.dynatemp_exponent = dynatemp_exponent
inputs.smoothing_factor = smoothing_factor
inputs.smoothing_curve = smoothing_curve
inputs.grammar = grammar.encode("UTF-8")
inputs.grammar_retain_state = grammar_retain_state
inputs.allow_eos_token = not ban_eos_token

View file

@ -33,6 +33,7 @@ struct kcpp_params {
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
float smoothing_factor = 0.00f; // 0.00 = disabled
float smoothing_curve = 1.00f; // 1.0 = disabled
float repeat_penalty = 1.10f; // 1.0 = disabled
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float rep_pen_slope = 1.0f;