improved sampling for tts and fixed yet another bug. no patch release for this.

Concedo 2025-01-23 23:37:13 +08:00
parent cca4a934dd
commit fb1274e100
3 changed files with 54 additions and 12 deletions


@@ -369,9 +369,9 @@ std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_
 }
 //a very rudimentary all in one sampling function which has no dependencies
-int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng)
+int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng)
 {
-    if (temp <= 0 || top_k==1) {
+    if (temp <= 0) {
         // select the token with the highest logit directly
         float max_logit = logits[0];
         int32_t max_id = 0;
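
A hedged call-site sketch (not part of this commit) of how the widened signature might be driven from a TTS decode loop; pick_next_token, the 64-token window and the sampler values are illustrative assumptions only.

#include <cstdint>
#include <random>
#include <vector>

// declaration matching the new signature introduced above
int32_t kcpp_quick_sample(float * logits, const int n_logits,
                          const std::vector<int32_t> & last_n_tokens, float rep_pen,
                          float top_p, int top_k, float temp, std::mt19937 & rng);

static int32_t pick_next_token(float * logits, int n_vocab,
                               std::vector<int32_t> & last_n_tokens, std::mt19937 & rng)
{
    // temp <= 0 now always means greedy argmax; top_k==1 alone no longer forces it
    int32_t tok = kcpp_quick_sample(logits, n_vocab, last_n_tokens,
                                    /*rep_pen=*/1.1f, /*top_p=*/0.95f,
                                    /*top_k=*/40, /*temp=*/0.75f, rng);
    last_n_tokens.push_back(tok);
    if (last_n_tokens.size() > 64) {
        last_n_tokens.erase(last_n_tokens.begin()); // keep a sliding recent-token window
    }
    return tok;
}
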
@@ -392,8 +392,19 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
     //temperature sample
     const float scale = 1.0f/temp;
+    //sample rep pen
     for (int i = 0; i < n_logits; ++i) {
-        logits_id.push_back(std::make_pair(logits[i]*scale, i));
+        if (rep_pen>1.0f && std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+            // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+            if (logits[i] < 0.0f) {
+                logits_id.push_back(std::make_pair(logits[i]*scale*rep_pen, i));
+            } else {
+                logits_id.push_back(std::make_pair(logits[i]*scale/rep_pen, i));
+            }
+        } else {
+            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+        }
     }
     //sample top_k
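
Hedged sketch, not code from this commit: the penalty rule above factored into a standalone helper (apply_rep_pen and penalized_score are made-up names), together with the std::find-style membership test the diff applies to last_n_tokens.

#include <algorithm>
#include <cstdint>
#include <vector>

static float apply_rep_pen(float scaled_logit, bool was_recent, float rep_pen)
{
    if (!was_recent || rep_pen <= 1.0f) {
        return scaled_logit; // unpenalized path, same as the else branch above
    }
    // negative scores are multiplied, positive scores divided, so the penalty
    // always moves the score toward "less likely"
    return (scaled_logit < 0.0f) ? scaled_logit * rep_pen : scaled_logit / rep_pen;
}

// example use: penalize token id `tok` if it appears in the recent window
static float penalized_score(const float * logits, float scale, int32_t tok,
                             const std::vector<int32_t> & last_n_tokens, float rep_pen)
{
    const bool seen = std::find(last_n_tokens.begin(), last_n_tokens.end(), tok)
                      != last_n_tokens.end();
    return apply_rep_pen(logits[tok] * scale, seen, rep_pen);
}
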
@@ -421,6 +432,24 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
         p /= sum;
     }
+    //apply top p
+    if (top_p < 1.0) {
+        double cumsum = 0.0;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                probs.resize(i + 1);
+                logits_id.resize(i + 1);
+                break;
+            }
+        }
+    }
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
     std::discrete_distribution<> dist(probs.begin(), probs.end());
     int idx = dist(rng);
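
Toy, self-contained illustration of the top-p cut and the final draw, assuming probs is already sorted in descending order and normalized, as it is at this point in kcpp_quick_sample; the probability values and seed are made up. Note that std::discrete_distribution renormalizes its weights internally, so the surviving tokens keep their relative proportions after the cut.

#include <cstdio>
#include <random>
#include <vector>

int main()
{
    std::vector<double> probs = {0.40, 0.25, 0.20, 0.10, 0.05}; // sorted, sums to 1
    const double top_p = 0.80;

    double cumsum = 0.0;
    for (size_t i = 0; i < probs.size(); i++) {
        cumsum += probs[i];
        if (cumsum >= top_p) {       // keep tokens up to and including the one
            probs.resize(i + 1);     // that crosses the top_p mass (here: 3 tokens)
            break;
        }
    }

    std::mt19937 rng(1337);
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    std::printf("picked index %d\n", dist(rng));
    return 0;
}
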