improve performance by actually applying nsigma's masking (#1602)

Merging; please report any issues.
This commit is contained in:
Reithan 2025-07-07 00:41:46 -07:00 committed by GitHub
parent 57ce374240
commit 0097de5c57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1433,12 +1433,11 @@ void sampler_typical(llama_token_data_array * cur_p, float p, size_t min_keep) {
}
void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
if (nsigma <= 0.0f || cur_p->size <= 1) {
return;
}
// find max logit and calculate mean
float nsigmax = cur_p->data[0].logit;
float nsigmax = cur_p->data[0].logit;
float logits_sum = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > nsigmax) {
@@ -1456,11 +1455,10 @@ void sample_top_n_sigma(llama_token_data_array * cur_p, float nsigma) {
float nsigstd = sqrt(nsigacc / cur_p->size);
//apply mask
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit < nsigmax - (nsigma * nsigstd)) {
cur_p->data[i].logit -= 999.0f;
}
}
auto last = std::remove_if(cur_p->data, cur_p->data + cur_p->size,
[&](auto & tk) { return tk.logit < nsigmax - (nsigma * nsigstd); });
cur_p->size = last - cur_p->data;
sample_softmax(cur_p);
}