Fuse the tail-free sampling (TFS) derivative loops into a single pass for performance (#1393)

This commit is contained in:
Reithan 2025-02-28 07:55:06 -08:00 committed by GitHub
parent dd6e4038ea
commit 202c029924
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1321,33 +1321,24 @@ void sample_tail_free(llama_token_data_array * cur_p, float z, size_t min_keep)
sample_softmax(cur_p); sample_softmax(cur_p);
// Compute the first and second derivatives // Compute the first and second derivatives
std::vector<float> first_derivatives(cur_p->size - 1);
std::vector<float> second_derivatives(cur_p->size - 2); std::vector<float> second_derivatives(cur_p->size - 2);
float second_derivatives_sum = 0.0f;
for (size_t i = 0; i < first_derivatives.size(); ++i) {
first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
}
for (size_t i = 0; i < second_derivatives.size(); ++i) { for (size_t i = 0; i < second_derivatives.size(); ++i) {
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; float first_derivatives_1 = cur_p->data[i].p - cur_p->data[i + 1].p;
} float first_derivatives_2 = cur_p->data[i + 1].p - cur_p->data[i + 2].p;
second_derivatives[i] = std::abs(first_derivatives_1 - first_derivatives_2);
// Calculate absolute value of second derivatives second_derivatives_sum += second_derivatives[i];
for (size_t i = 0; i < second_derivatives.size(); ++i) {
second_derivatives[i] = std::abs(second_derivatives[i]);
} }
// Normalize the second derivatives // Normalize the second derivatives
{ if (second_derivatives_sum > 1e-6f) {
const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); for (float & value : second_derivatives) {
value /= second_derivatives_sum;
if (second_derivatives_sum > 1e-6f) { }
for (float & value : second_derivatives) { } else {
value /= second_derivatives_sum; for (float & value : second_derivatives) {
} value = 1.0f / second_derivatives.size();
} else {
for (float & value : second_derivatives) {
value = 1.0f / second_derivatives.size();
}
} }
} }