improved sampling for tts and fixed yet another bug. no patch release for this.

This commit is contained in:
Concedo 2025-01-23 23:37:13 +08:00
parent cca4a934dd
commit fb1274e100
3 changed files with 54 additions and 12 deletions

View file

@ -754,8 +754,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
//use creative settings to generate speakers
const int topk = 20;
const float top_p = 1.0f;
const float temp = 1.2f;
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng);
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector<int32_t>(),1.0,top_p,topk,temp,speaker_rng);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id))
@ -876,7 +877,8 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
//use predictable settings to generate voice
const int topk = 4;
const float temp = 0.75f;
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng);
const float top_p = 1.0f;
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector<int32_t>(),1.0,top_p,topk,temp,speaker_rng);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id))
@ -933,7 +935,7 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
const int n_codes = codes.size();
if(n_codes<=1)
{
printf("\nWarning: TTS vocoder generated nothing!\n");
printf("\nWarning: No Audio Tokens Produced!\n");
last_generated_audio = "";
output.data = last_generated_audio.c_str();
output.status = 1;
@ -963,6 +965,8 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
//audio = resample_wav(audio,n_sr,t_sr); //resample to 16k
if(audio.size()>cutout+16)
{
for (int i = 0; i < cutout; ++i) {
audio[i] = 0.0f;
}
@ -970,6 +974,15 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
for (int i = 0; i < cutout; ++i) {
audio.push_back(0.0f);
}
}
else
{
printf("\nWarning: TTS vocoder generated nothing!\n");
last_generated_audio = "";
output.data = last_generated_audio.c_str();
output.status = 1;
return output;
}
last_generated_audio = save_wav16_base64(audio, t_sr);
ttstime = timer_check();

View file

@ -369,9 +369,9 @@ std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_
}
//a very rudimentary all in one sampling function which has no dependencies
int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng)
int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng)
{
if (temp <= 0 || top_k==1) {
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = logits[0];
int32_t max_id = 0;
@ -392,9 +392,20 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
//temperature sample
const float scale = 1.0f/temp;
//sample rep pen
for (int i = 0; i < n_logits; ++i) {
if (rep_pen>1.0f && std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (logits[i] < 0.0f) {
logits_id.push_back(std::make_pair(logits[i]*scale*rep_pen, i));
} else {
logits_id.push_back(std::make_pair(logits[i]*scale/rep_pen, i));
}
} else {
logits_id.push_back(std::make_pair(logits[i]*scale, i));
}
}
//sample top_k
std::partial_sort(
@ -421,6 +432,24 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t
p /= sum;
}
//apply top p
if (top_p < 1.0) {
double cumsum = 0.0;
for (int i = 0; i < (int) probs.size(); i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
probs.resize(i + 1);
logits_id.resize(i + 1);
break;
}
}
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);

View file

@ -63,7 +63,7 @@ std::string kcpp_base64_encode(const std::string &data);
std::string get_timestamp_str();
std::vector<float> resample_wav(const std::vector<float>& input, uint32_t input_rate, uint32_t output_rate);
int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng);
int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector<int32_t> & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng);
struct kcpp_embd_batch { //duplcated from llava_embd_batch
std::vector<int32_t> pos;