add presence penalty

This commit is contained in:
Concedo 2023-12-19 23:18:56 +08:00
parent da2db0302c
commit 3f863eed72
3 changed files with 18 additions and 9 deletions

View file

@ -386,7 +386,7 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
candidates->size = last_idx;
}
void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p)
void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float presence_penalty, llama_token_data_array * candidates_p)
{
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
@ -414,6 +414,8 @@ void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_dat
} else {
candidates->data[i].logit /= penalty;
}
candidates->data[i].logit -= presence_penalty;
}
candidates->sorted = false;
@ -474,7 +476,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
{
int id = 0;
@ -494,7 +496,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
sample_temperature(&candidates_p, temp);
if (mirostat == 1)
{
@ -531,7 +533,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
sample_temperature(&candidates_p, temp);
break;
case KCPP_SAMPLER_REP_PEN:
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
break;
default:
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
@ -1442,6 +1444,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
params.temp = inputs.temperature;
params.repeat_last_n = inputs.rep_pen_range;
params.repeat_penalty = inputs.rep_pen;
params.presence_penalty = inputs.presence_penalty;
params.mirostat = inputs.mirostat;
params.mirostat_eta = inputs.mirostat_eta;
params.mirostat_tau = inputs.mirostat_tau;
@ -1836,6 +1839,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
const float temp = params.temp;
const float top_a = inputs.top_a;
const float repeat_penalty = params.repeat_penalty;
const float presence_penalty = params.presence_penalty;
const float typical_p = params.typical_p;
const float tfs_z = params.tfs_z;
@ -1891,7 +1895,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, presence_penalty,
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar);