xtc fixes

This commit is contained in:
Concedo 2024-08-22 23:18:46 +08:00
parent 0b96097439
commit cca3c4c78b

View file

@ -501,9 +501,9 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
candidates->size = last_idx; candidates->size = last_idx;
} }
void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng, size_t min_keep) void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng)
{ {
if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || candidates->size <= 1) { if (xtc_threshold > 0.5f || xtc_probability <= 0.0f || candidates->size <= 1) {
return; return;
} }
@ -521,7 +521,7 @@ void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
// Go until we reach a value under the threshold // Go until we reach a value under the threshold
float checkprob = candidates->data[i].p; float checkprob = candidates->data[i].p;
if (checkprob < xtc_threshold && i >= min_keep) { if (checkprob < xtc_threshold) {
last_idx = i; last_idx = i;
break; break;
} }
@ -529,11 +529,8 @@ void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float
if(last_idx>1) //if there are 2 or more viable candidates if(last_idx>1) //if there are 2 or more viable candidates
{ {
// drop all tokens except those above threshold // then remove all other tokens above threshold EXCEPT the least likely one
candidates->size = last_idx; for (size_t i = 0; i < last_idx - 1; ++i) {
// then remove all other tokens EXCEPT the least likely one
for (size_t i = 0; i < candidates->size - 1; ++i) {
candidates->data[i].logit -= 999.0f; //infinity gets wonky results downstream, this hack works well enough candidates->data[i].logit -= 999.0f; //infinity gets wonky results downstream, this hack works well enough
} }
candidates->sorted = false; candidates->sorted = false;
@ -956,7 +953,7 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
} }
} }
//xtc always last //xtc always last
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng, 1); sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng);
id = sample_token(&candidates_p, rng); id = sample_token(&candidates_p, rng);
} }