mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
improved sampling for tts and fixed yet another bug. no patch release for this.
This commit is contained in:
parent
cca4a934dd
commit
fb1274e100
3 changed files with 54 additions and 12 deletions
|
@ -369,9 +369,9 @@ std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_
|
|||
}
|
||||
|
||||
//a very rudimentary all in one sampling function which has no dependencies
|
||||
int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng)
|
||||
int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng)
|
||||
{
|
||||
if (temp <= 0 || top_k==1) {
|
||||
if (temp <= 0) {
|
||||
// select the token with the highest logit directly
|
||||
float max_logit = logits[0];
|
||||
int32_t max_id = 0;
|
||||
|
@ -392,8 +392,19 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
|
|||
|
||||
//temperature sample
|
||||
const float scale = 1.0f/temp;
|
||||
|
||||
//sample rep pen
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
if (rep_pen>1.0f && std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
||||
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if (logits[i] < 0.0f) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale*rep_pen, i));
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale/rep_pen, i));
|
||||
}
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
}
|
||||
}
|
||||
|
||||
//sample top_k
|
||||
|
@ -421,6 +432,24 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
|
|||
p /= sum;
|
||||
}
|
||||
|
||||
//apply top p
|
||||
if (top_p < 1.0) {
|
||||
double cumsum = 0.0;
|
||||
for (int i = 0; i < (int) probs.size(); i++) {
|
||||
cumsum += probs[i];
|
||||
if (cumsum >= top_p) {
|
||||
probs.resize(i + 1);
|
||||
logits_id.resize(i + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalize the probs
|
||||
for (auto & p : probs) {
|
||||
p /= sum;
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue