diff --git a/otherarch/tts_adapter.cpp b/otherarch/tts_adapter.cpp index 22d4c1997..897f8e401 100644 --- a/otherarch/tts_adapter.cpp +++ b/otherarch/tts_adapter.cpp @@ -754,8 +754,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs) //use creative settings to generate speakers const int topk = 20; + const float top_p = 1.0f; const float temp = 1.2f; - llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng); + llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector(),1.0,top_p,topk,temp,speaker_rng); //guide tokens help prevent hallucinations by forcing the TTS to use the correct word if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id)) @@ -876,7 +877,8 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs) //use predictable settings to generate voice const int topk = 4; const float temp = 0.75f; - llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng); + const float top_p = 1.0f; + llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector(),1.0,top_p,topk,temp,speaker_rng); //guide tokens help prevent hallucinations by forcing the TTS to use the correct word if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id)) @@ -933,7 +935,7 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs) const int n_codes = codes.size(); if(n_codes<=1) { - printf("\nWarning: TTS vocoder generated nothing!\n"); + printf("\nWarning: No Audio Tokens Produced!\n"); last_generated_audio = ""; output.data = last_generated_audio.c_str(); output.status = 1; @@ -963,12 +965,23 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs) //audio = resample_wav(audio,n_sr,t_sr); //resample to 16k - for (int i = 0; i < cutout; ++i) { - audio[i] = 0.0f; + if(audio.size()>cutout+16) + { + for (int i = 0; i < cutout; ++i) { + audio[i] = 0.0f; + } + //add some silence at the end + for (int i = 0; i < cutout; ++i) { + audio.push_back(0.0f); + } } - //add some silence at the end - for (int i = 0; i < cutout; ++i) { - audio.push_back(0.0f); + else + { + printf("\nWarning: TTS vocoder generated nothing!\n"); + last_generated_audio = ""; + output.data = last_generated_audio.c_str(); + output.status = 1; + return output; } last_generated_audio = save_wav16_base64(audio, t_sr); diff --git a/otherarch/utils.cpp b/otherarch/utils.cpp index 9ef515d11..410e7bb13 100644 --- a/otherarch/utils.cpp +++ b/otherarch/utils.cpp @@ -369,9 +369,9 @@ std::vector resample_wav(const std::vector& input, uint32_t input_ } //a very rudimentary all in one sampling function which has no dependencies -int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng) +int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng) { - if (temp <= 0 || top_k==1) { + if (temp <= 0) { // select the token with the highest logit directly float max_logit = logits[0]; int32_t max_id = 0; @@ -392,8 +392,19 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t //temperature sample const float scale = 1.0f/temp; + + //sample rep pen for (int i = 0; i < n_logits; ++i) { - logits_id.push_back(std::make_pair(logits[i]*scale, i)); + if (rep_pen>1.0f && std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if (logits[i] < 0.0f) { + logits_id.push_back(std::make_pair(logits[i]*scale*rep_pen, i)); + } else { + logits_id.push_back(std::make_pair(logits[i]*scale/rep_pen, i)); + } + } else { + logits_id.push_back(std::make_pair(logits[i]*scale, i)); + } } //sample top_k @@ -421,6 +432,24 @@ int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float t p /= sum; } + //apply top p + if (top_p < 1.0) { + double cumsum = 0.0; + for (int i = 0; i < (int) probs.size(); i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + probs.resize(i + 1); + logits_id.resize(i + 1); + break; + } + } + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + std::discrete_distribution<> dist(probs.begin(), probs.end()); int idx = dist(rng); diff --git a/otherarch/utils.h b/otherarch/utils.h index 1303a3f6c..78012fd6e 100644 --- a/otherarch/utils.h +++ b/otherarch/utils.h @@ -63,7 +63,7 @@ std::string kcpp_base64_encode(const std::string &data); std::string get_timestamp_str(); std::vector resample_wav(const std::vector& input, uint32_t input_rate, uint32_t output_rate); -int32_t kcpp_quick_sample(float * logits, const int n_logits, int top_k, float temp, std::mt19937 & rng); +int32_t kcpp_quick_sample(float * logits, const int n_logits, const std::vector & last_n_tokens, float rep_pen, float top_p, int top_k, float temp, std::mt19937 & rng); struct kcpp_embd_batch { //duplcated from llava_embd_batch std::vector pos;