From cca3c4c78b3d585741b2d61dac13a794368b54ab Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:18:46 +0800 Subject: [PATCH] xtc fixes --- gpttype_adapter.cpp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 63e2f67b1..dacfae5d7 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -501,9 +501,9 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) candidates->size = last_idx; } -void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng, size_t min_keep) +void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng) { - if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || candidates->size <= 1) { + if (xtc_threshold > 0.5f || xtc_probability <= 0.0f || candidates->size <= 1) { return; } @@ -521,7 +521,7 @@ void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float for (size_t i = 0; i < candidates->size; ++i) { // Go until we reach a value under the threshold float checkprob = candidates->data[i].p; - if (checkprob < xtc_threshold && i >= min_keep) { + if (checkprob < xtc_threshold) { last_idx = i; break; } @@ -529,11 +529,8 @@ void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float if(last_idx>1) //if there are 2 or more viable candidates { - // drop all tokens except those above threshold - candidates->size = last_idx; - - // then remove all other tokens EXCEPT the least likely one - for (size_t i = 0; i < candidates->size - 1; ++i) { + // then remove all other tokens above threshold EXCEPT the least likely one + for (size_t i = 0; i < last_idx - 1; ++i) { candidates->data[i].logit -= 999.0f; //infinity gets wonky results downstream, this hack works well enough } candidates->sorted = false; @@ -956,7 +953,7 @@ const std::vector & sampler_order, llama_grammar * grammar, float dyna } } //xtc always last - sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng, 1); + sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng); id = sample_token(&candidates_p, rng); }