Merge remote-tracking branch 'ycros/improve-sampler-api-access' into concedo_experimental

Commit 784628a2be by Concedo, 2023-07-04 16:38:32 +08:00
3 changed files with 84 additions and 8 deletions

@@ -219,8 +219,16 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
    candidates->size = last_idx;
}
void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p)
{
    auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
    llama_sample_repetition_penalty(nullptr, &candidates_p,
        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
        last_n_repeat, rep_pen);
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta)
int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX])
{
    int id = 0;
    std::vector<llama_token_data> candidates;
@@ -231,11 +239,11 @@ int mirostat, float mirostat_tau, float mirostat_eta)
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    // Apply penalties
    auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
    llama_sample_repetition_penalty(nullptr, &candidates_p,
        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
        last_n_repeat, rep_pen);
    // Run this except for when we are going to do the sampler reordering case below
    if (temp <= 0 || mirostat > 0 || sampler_len == 0)
    {
        apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
    }
    // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
    //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -261,6 +269,37 @@ int mirostat, float mirostat_tau, float mirostat_eta)
        llama_sample_temperature(nullptr, &candidates_p, temp);
        id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
    }
    else if (sampler_len > 0)
    {
        for (int i = 0; i < sampler_len; i++) {
            switch (sampler_order[i]) {
                case KCPP_SAMPLER_TOP_K:
                    llama_sample_top_k(nullptr, &candidates_p, top_k,1);
                    break;
                case KCPP_SAMPLER_TOP_A:
                    sample_top_a(&candidates_p,top_a,1);
                    break;
                case KCPP_SAMPLER_TOP_P:
                    llama_sample_top_p(nullptr, &candidates_p, top_p,1);
                    break;
                case KCPP_SAMPLER_TFS:
                    llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
                    break;
                case KCPP_SAMPLER_TYP:
                    llama_sample_typical(nullptr, &candidates_p, typical_p,1);
                    break;
                case KCPP_SAMPLER_TEMP:
                    llama_sample_temperature(nullptr, &candidates_p, temp);
                    break;
                case KCPP_SAMPLER_REP_PEN:
                    apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
                    break;
                default:
                    break;
            }
        }
        id = sample_token(&candidates_p, rng);
    }
    else
    {
        // Temperature sampling
@@ -1235,7 +1274,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
            id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
                              top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
                              params.mirostat,params.mirostat_tau,params.mirostat_eta);
                              params.mirostat, params.mirostat_tau, params.mirostat_eta,
                              inputs.sampler_len, inputs.sampler_order);
            last_n_tokens.erase(last_n_tokens.begin());
            last_n_tokens.push_back(id);
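
Below is a minimal, self-contained sketch of how a caller could drive the new sampler-order parameters introduced by this merge. The enum values and field names mirror what is visible in the diff (KCPP_SAMPLER_*, sampler_order, sampler_len); the real declarations live in the koboldcpp headers and may differ, and generation_inputs_sketch is a hypothetical stand-in for generation_inputs used only for illustration.

// Assumed enum mirroring the KCPP_SAMPLER_* names used in the switch above.
enum samplers
{
    KCPP_SAMPLER_TOP_K,
    KCPP_SAMPLER_TOP_A,
    KCPP_SAMPLER_TOP_P,
    KCPP_SAMPLER_TFS,
    KCPP_SAMPLER_TYP,
    KCPP_SAMPLER_TEMP,
    KCPP_SAMPLER_REP_PEN,
    KCPP_SAMPLER_MAX
};

// Hypothetical stand-in for the relevant fields of generation_inputs.
struct generation_inputs_sketch
{
    samplers sampler_order[KCPP_SAMPLER_MAX];
    unsigned int sampler_len;
};

int main()
{
    generation_inputs_sketch inputs{};

    // Request repetition penalty first, then top-k, with temperature applied last.
    inputs.sampler_order[0] = KCPP_SAMPLER_REP_PEN;
    inputs.sampler_order[1] = KCPP_SAMPLER_TOP_K;
    inputs.sampler_order[2] = KCPP_SAMPLER_TEMP;
    inputs.sampler_len = 3; // only the first sampler_len entries are consulted

    // SampleLogits walks sampler_order[0..sampler_len) through the switch shown
    // above; when sampler_len == 0, temp <= 0, or mirostat is enabled, it keeps
    // the previous fixed pipeline with penalties applied up front.
    return 0;
}

The key design point in the diff is that the repetition penalty becomes just another entry in the order (KCPP_SAMPLER_REP_PEN), with apply_penalties() factored out so the old up-front behaviour is still used when no explicit order is supplied.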