perplexity : fix even more integer overflows (#23623)

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
fairydreaming 2026-05-25 07:12:39 +02:00 committed by GitHub
parent 28123a3937
commit 6d57c26ef8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -157,7 +157,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
break;
}
lock.unlock();
const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + size_t(i)*nv, tokens[i+1]);
local_nll += v;
local_nll2 += v*v;
}
@ -169,7 +169,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
for (auto & w : workers) {
w.join();
}
out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
out.write((const char *)log_probs.data(), size_t(n_token)*nv*sizeof(uint16_t));
}
struct kl_divergence_result {
@ -279,7 +279,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
break;
}
lock.unlock();
std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + size_t(i)*nv, tokens[i+1], local_kld);
kld_values[i] = (float)v.first;
p_diff_values[i] = v.second;
}