Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-15 19:39:42 +00:00)
added xtc sampler

Commit: 5bf527a6ae
Parent: 1a7ecd55e6
5 changed files with 121 additions and 8 deletions
@@ -501,6 +501,50 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
     candidates->size = last_idx;
 }
 
+void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng, size_t min_keep)
+{
+    if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || candidates->size <= 1) {
+        return;
+    }
+
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+    float roll = dist(rng);
+    if (roll >= xtc_probability) //if dice roll fails, skip xtc
+    {
+        return;
+    }
+
+    llama_sample_softmax(nullptr, candidates);
+
+    //calculate how many tokens cross the xtc threshold
+    size_t last_idx = candidates->size;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        // Go until we reach a value under the threshold
+        float checkprob = candidates->data[i].p;
+        if (checkprob < xtc_threshold && i >= min_keep) {
+            last_idx = i;
+            break;
+        }
+    }
+
+    if (last_idx > 1) //if there are 2 or more viable candidates
+    {
+        // drop all tokens except those above threshold
+        candidates->size = last_idx;
+
+        // then remove all other tokens EXCEPT the least likely one
+        for (size_t i = 0; i < candidates->size - 1; ++i) {
+            candidates->data[i].logit = -999.0f; //infinity gets wonky results downstream, this hack works well enough
+        }
+        candidates->sorted = false;
+
+    } //otherwise xtc does not do anything
+
+    // printf("\n\nCandidates: %d, Threshold: %f, LastIdx: %d", candidates->size, xtc_threshold, last_idx);
+    // printf("\nCandidates: %f %f %f %f\n", candidates->data[0].p, candidates->data[1].p, candidates->data[2].p, candidates->data[3].p);
+
+}
+
 void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) {
     if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) {
         return;
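For illustration only (not part of the commit): a minimal, self-contained sketch of the same XTC logic, written against a toy candidate list instead of koboldcpp's llama_token_data_array. The TokenProb struct and apply_xtc function below are placeholder names, not identifiers from the repository; the sketch mirrors the behavior of the diff above, which truncates the list to the tokens at or above the threshold and then keeps only the least likely of those.

// Standalone sketch of the XTC ("exclude top choices") step, assuming candidates are
// already softmaxed and sorted by probability in descending order.
#include <cstdio>
#include <random>
#include <vector>

struct TokenProb {
    int id;
    float p; // probability after softmax
};

void apply_xtc(std::vector<TokenProb> & cands, float xtc_threshold, float xtc_probability,
               std::mt19937 & rng, size_t min_keep = 1) {
    if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || cands.size() <= 1) {
        return;
    }
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    if (dist(rng) >= xtc_probability) {
        return; // dice roll failed, XTC is skipped for this token
    }
    // count how many leading candidates are at or above the threshold
    size_t last_idx = cands.size();
    for (size_t i = 0; i < cands.size(); ++i) {
        if (cands[i].p < xtc_threshold && i >= min_keep) {
            last_idx = i;
            break;
        }
    }
    if (last_idx > 1) {
        // keep only the thresholded candidates, then drop all of them except the least likely one
        cands.resize(last_idx);
        cands.erase(cands.begin(), cands.end() - 1);
    }
}

int main() {
    // toy distribution, sorted by probability
    std::vector<TokenProb> cands = {{0, 0.45f}, {1, 0.30f}, {2, 0.15f}, {3, 0.10f}};
    std::mt19937 rng(42);
    apply_xtc(cands, /*xtc_threshold=*/0.2f, /*xtc_probability=*/1.0f, rng);
    for (const auto & c : cands) {
        std::printf("token %d p=%.2f\n", c.id, c.p); // only token 1 survives here
    }
    return 0;
}

In the example, tokens 0 and 1 are at or above the 0.2 threshold, so the most likely token (0) is excluded and token 1 remains. The commit achieves the same effect without erasing entries: it shrinks candidates->size to drop the below-threshold tail, forces the logits of the remaining top tokens to -999.0f (its own comment notes that true infinity misbehaves downstream), and marks the list unsorted.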
@@ -822,7 +866,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
 }
 
 int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
-int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
+int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
+const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
 {
     int id = 0;
     std::vector<llama_token_data> candidates;
@@ -843,6 +888,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
         sample_grammar(file_format, n_vocab, &candidates_p, grammar);
     }
 
+    //dry always first as logits cannot be resorted
     sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
 
     //prefilter to top 5k tokens for improved speed
@@ -909,6 +955,8 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
                 break;
             }
         }
+        //xtc always last
+        sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng, 1);
         id = sample_token(&candidates_p, rng);
     }
 
@@ -2088,6 +2136,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     kcpp_params->dry_base = inputs.dry_base;
     kcpp_params->dry_allowed_length = inputs.dry_allowed_length;
     kcpp_params->dry_penalty_last_n = inputs.dry_penalty_last_n;
+    kcpp_params->xtc_threshold = inputs.xtc_threshold;
+    kcpp_params->xtc_probability = inputs.xtc_probability;
     kcpp_params->dynatemp_range = inputs.dynatemp_range;
     kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent;
     kcpp_params->n_ctx = inputs.max_context_length;
@@ -2662,7 +2712,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
             kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta,
             kcpp_params->dry_multiplier, kcpp_params->dry_base,
-            kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n,
+            kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n, kcpp_params->xtc_threshold, kcpp_params->xtc_probability,
             sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
 
     if (grammar != nullptr) {