diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 92d353285..0419027bb 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1433,12 +1433,11 @@ void sampler_typical(llama_token_data_array * cur_p, float p, size_t min_keep) { } void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) { - if (nsigma <= 0.0f || cur_p->size <= 1) { return; } // find max logit and calculate mean - float nsigmax = cur_p->data[0].logit; + float nsigmax = cur_p->data[0].logit; float logits_sum = 0; for (size_t i = 0; i < cur_p->size; ++i) { if (cur_p->data[i].logit > nsigmax) { @@ -1456,11 +1455,10 @@ void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) { float nsigstd = sqrt(nsigacc / cur_p->size); //apply mask - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit < nsigmax - (nsigma * nsigstd)) { - cur_p->data[i].logit -= 999.0f; - } - } + auto last = std::remove_if(cur_p->data, cur_p->data + cur_p->size, + [&](auto & tk) { return tk.logit < nsigmax - (nsigma * nsigstd); }); + cur_p->size = last - cur_p->data; + sample_softmax(cur_p); }