added xtc sampler

Concedo 2024-08-21 23:57:15 +08:00
parent 1a7ecd55e6
commit 5bf527a6ae
5 changed files with 121 additions and 8 deletions

@@ -501,6 +501,50 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
candidates->size = last_idx;
}
void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng, size_t min_keep)
{
if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || candidates->size <= 1) {
return;
}
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
float roll = dist(rng);
if(roll>=xtc_probability) //if dice roll fails, skip xtc
{
return;
}
llama_sample_softmax(nullptr, candidates);
//calculate how many tokens cross the xtc threshold
size_t last_idx = candidates->size;
for (size_t i = 0; i < candidates->size; ++i) {
// Go until we reach a value under the threshold
float checkprob = candidates->data[i].p;
if (checkprob < xtc_threshold && i >= min_keep) {
last_idx = i;
break;
}
}
if(last_idx>1) //if there are 2 or more viable candidates
{
// drop all tokens except those above threshold
candidates->size = last_idx;
// then remove all other tokens EXCEPT the least likely one
for (size_t i = 0; i < candidates->size - 1; ++i) {
candidates->data[i].logit = -999.0f; //infinity gets wonky results downstream, this hack works well enough
}
candidates->sorted = false;
} //otherwise xtc does not do anything
// printf("\n\nCandidates: %d, Threshold: %f, LastIdx: %d",candidates->size,xtc_threshold,last_idx);
// printf("\nCandidates: %f %f %f %f\n",candidates->data[0].p,candidates->data[1].p,candidates->data[2].p,candidates->data[3].p);
}
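For reference, the logic above can be read in isolation: with probability xtc_probability, every token whose probability meets xtc_threshold is excluded except the least likely of them, which then becomes the only viable candidate. The following minimal standalone sketch (illustrative only, not part of this commit) shows that idea on a plain descending-sorted probability vector instead of a llama_token_data_array; the function and variable names here are made up for the example.

// Standalone illustration of the XTC ("exclude top choices") filter above.
// Assumes probs is already softmaxed and sorted in descending order.
#include <cstdio>
#include <random>
#include <vector>

std::vector<size_t> xtc_filter(const std::vector<float> & probs,
                               float xtc_threshold, float xtc_probability,
                               std::mt19937 & rng) {
    std::vector<size_t> kept(probs.size());
    for (size_t i = 0; i < probs.size(); ++i) kept[i] = i;

    if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || probs.size() <= 1) {
        return kept; // XTC disabled or nothing to do
    }
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    if (dist(rng) >= xtc_probability) {
        return kept; // dice roll failed, skip XTC this step
    }
    // find the first token that falls below the threshold
    size_t last_idx = probs.size();
    for (size_t i = 0; i < probs.size(); ++i) {
        if (probs[i] < xtc_threshold) { last_idx = i; break; }
    }
    if (last_idx > 1) {
        // two or more tokens crossed the threshold: keep only the least
        // likely of them, mirroring the -999 logit mask in sample_xtc
        kept.assign(1, last_idx - 1);
    }
    return kept;
}

int main() {
    std::mt19937 rng(42);
    std::vector<float> probs = {0.50f, 0.30f, 0.12f, 0.08f};
    for (size_t idx : xtc_filter(probs, 0.1f, 1.0f, rng)) {
        std::printf("kept token %zu (p=%.2f)\n", idx, probs[idx]);
    }
    return 0;
}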
void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) {
if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) {
return;
@@ -822,7 +866,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
{
int id = 0;
std::vector<llama_token_data> candidates;
@@ -843,6 +888,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}
//dry always first as logits cannot be resorted
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
//prefilter to top 5k tokens for improved speed
@@ -909,6 +955,8 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
break;
}
}
//xtc always last
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng, 1);
id = sample_token(&candidates_p, rng);
}
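The "xtc always last" ordering matters: sample_xtc leaves only the least likely above-threshold token with a usable logit (everything else is masked to -999), so any truncation sampler run after it would see an essentially one-hot distribution, while samplers run before it still operate on the full candidate set. The tiny sketch below (illustrative, not part of the commit; the surviving logit value is made up) shows what a later softmax would compute over such a masked candidate set.

// Illustrative only: softmax over a candidate set after XTC's -999 mask,
// showing that the single surviving token receives essentially all the mass.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // the two most likely tokens were masked to -999; the least likely
    // above-threshold token keeps its original logit
    std::vector<float> logits = {-999.0f, -999.0f, 1.2f};

    float max_logit = logits[0];
    for (float l : logits) { if (l > max_logit) max_logit = l; }

    float sum = 0.0f;
    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (size_t i = 0; i < logits.size(); ++i) {
        std::printf("token %zu: p = %g\n", i, probs[i] / sum);
    }
    return 0;
}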
@@ -2088,6 +2136,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_params->dry_base = inputs.dry_base;
kcpp_params->dry_allowed_length = inputs.dry_allowed_length;
kcpp_params->dry_penalty_last_n = inputs.dry_penalty_last_n;
kcpp_params->xtc_threshold = inputs.xtc_threshold;
kcpp_params->xtc_probability = inputs.xtc_probability;
kcpp_params->dynatemp_range = inputs.dynatemp_range;
kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent;
kcpp_params->n_ctx = inputs.max_context_length;
@@ -2662,7 +2712,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta,
kcpp_params->dry_multiplier, kcpp_params->dry_base,
kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n,
kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n, kcpp_params->xtc_threshold, kcpp_params->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
if (grammar != nullptr) {