added support for seeded tts voices

Parent: b3de1598e7
Commit: 62e33d0bf7
2 changed files with 211 additions and 44 deletions
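What the commit does, in brief: a first pass samples "speaker codes" for a fixed sample sentence using an RNG seeded from speaker_seed, caches them, and a second pass prepends that speaker snippet to the real prompt so the same seed always yields the same voice. Below is a minimal sketch of the underlying idea, a simplified stand-in for kcpp_quick_sample rather than the actual koboldcpp code; all names in it are hypothetical.

// Illustrative sketch only: shows why seeding the sampler RNG makes a
// speaker voice reproducible. Not the koboldcpp implementation.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

// pick one of the top-k logits, softmax-weighted at the given temperature
static int quick_sample_sketch(const std::vector<float>& logits, int topk, float temp, std::mt19937& rng) {
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + topk, idx.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    std::vector<float> w(topk);
    for (int i = 0; i < topk; ++i) w[i] = std::exp((logits[idx[i]] - logits[idx[0]]) / temp);
    std::discrete_distribution<int> dist(w.begin(), w.end());
    return idx[dist(rng)];
}

int main() {
    const std::vector<float> logits = {0.1f, 2.0f, 1.5f, 0.3f, 1.8f, 0.9f};
    for (int run = 0; run < 2; ++run) {
        std::mt19937 speaker_rng(12345); // fixed speaker seed -> identical draws
        for (int i = 0; i < 5; ++i) printf("%d ", quick_sample_sketch(logits, 4, 1.2f, speaker_rng));
        printf("\n"); // both runs print the same sequence of token ids
    }
    return 0;
}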
@@ -369,17 +369,56 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_model * model,
     // Add the last part
     std::string current_word = str.substr(start);
-    auto tmp = common_tokenize(model, current_word, false, true);
-    result.push_back(tmp[0]);
+    if(current_word!="")
+    {
+        auto tmp = common_tokenize(model, current_word, false, true);
+        if(tmp.size()>0){
+            result.push_back(tmp[0]);
+        }
+    }
     return result;
 }
 
+std::string trim_words(const std::string& input, const std::string& separator, size_t maxWords) {
+    // Split the input string by the separator
+    std::vector<std::string> words;
+    size_t start = 0, end;
+    while ((end = input.find(separator, start)) != std::string::npos) {
+        std::string last = input.substr(start, end - start);
+        if (last != "") {
+            words.push_back(last);
+        }
+        start = end + separator.length();
+    }
+    std::string last = input.substr(start);
+    if(last!="")
+    {
+        words.push_back(last); // Add the last word
+    }
+
+    // Ensure no more than maxWords are kept
+    if (words.size() > maxWords) {
+        words.resize(maxWords);
+    }
+
+    // Reconstruct the string with the separator
+    std::ostringstream result;
+    for (size_t i = 0; i < words.size(); ++i) {
+        if (i > 0) result << separator;
+        result << words[i];
+    }
+
+    return result.str();
+}
+
 static llama_context * ttc_ctx = nullptr; //text to codes ctx
 static llama_context * cts_ctx = nullptr; //codes to speech
 
 static int ttsdebugmode = 0;
 static std::string ttsplatformenv, ttsdeviceenv, ttsvulkandeviceenv;
 static std::string last_generated_audio = "";
+static std::vector<llama_token> last_speaker_codes; //will store cached speaker
+static int last_speaker_seed = -999;
 
 bool ttstype_load_model(const tts_load_model_inputs inputs)
 {
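Note on trim_words: words.resize(maxWords) keeps the first maxWords entries, while the call-site comment in a later hunk says it keeps "only the last 300 words"; one of the two presumably wants adjusting. A minimal standalone check of the behavior as written (illustrative only; it assumes trim_words above, with its <sstream>/<string>/<vector> includes, is visible in the same translation unit):

#include <cassert>
#include <string>

int main() {
    // keeps the FIRST two words, rejoined with the separator
    assert(trim_words("but<|text_sep|>that<|text_sep|>is", "<|text_sep|>", 2)
           == "but<|text_sep|>that");
    return 0;
}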
@@ -484,14 +523,11 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
     const llama_model * model_cts = &(cts_ctx->model);
     const int ttc_n_vocab = llama_n_vocab(model_ttc);
     std::string prompt = inputs.prompt;
 
     if(!inputs.quiet)
     {
         printf("\nTTS Generating... ");
     }
+    const std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is";
 
     // process prompt and generate voice codes
     llama_kv_cache_clear(ttc_ctx);
     llama_kv_cache_clear(cts_ctx);
     std::vector<llama_token> prompt_inp;
     prompt_init(prompt_inp, model_ttc);
     prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
@@ -501,39 +537,38 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
     if (speaker_seed <= 0 || speaker_seed==0xFFFFFFFF)
     {
         speaker_seed = (((uint32_t)time(NULL)) % 1000000u);
-        if(ttsdebugmode==1)
-        {
-            printf("\nUsing Speaker Seed: %d", speaker_seed);
-        }
     }
     if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
     {
         audio_seed = (((uint32_t)time(NULL)) % 1000000u);
-        if(ttsdebugmode==1)
-        {
-            printf("\nUsing Audio Seed: %d", audio_seed);
-        }
     }
+    if(ttsdebugmode==1)
+    {
+        printf("\nUsing Speaker Seed: %d", speaker_seed);
+        printf("\nUsing Audio Seed: %d", audio_seed);
+    }
 
     std::mt19937 tts_rng(audio_seed);
+    std::mt19937 speaker_rng(speaker_seed);
 
-    //add the speaker based on the seed
-    if(speaker_seed>0)
-    {
-        std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is<|text_sep|>";
-    }
+    int n_decode = 0;
+    int n_predict = 2048; //will be updated later
+    bool next_token_uses_guide_token = true;
 
     // convert the input text into the necessary format expected by OuteTTS
     std::string prompt_clean = process_text(prompt);
 
     //further clean it by keeping only the last 300 words
     prompt_clean = trim_words(prompt_clean,"<|text_sep|>",300);
 
     if(prompt_clean.size()==0)
     {
         //no input
         if(!inputs.quiet)
         {
             printf("\nTTS sent empty input.\n");
-            output.data = "";
+            last_generated_audio = "";
+            output.data = last_generated_audio.c_str();
             output.status = 1;
             return output;
         }
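Seed handling note: any seed that is unset (<= 0) or arrives as the -1 sentinel (which compares equal to 0xFFFFFFFF through the unsigned conversion) is replaced with a time-derived value modulo 1,000,000, while an explicit positive seed is kept verbatim; that is what makes a chosen voice reproducible across runs.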
@@ -544,19 +579,130 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
         printf("\nInput: %s\n", prompt_clean.c_str());
     }
 
+    //2 passes. first pass, we generate the speaker voice if required, then cache it for reuse
+    //second pass, we use the speaker snippet to align output voice to match the desired speaker
+    if(speaker_seed>0) //first pass
+    {
+        //if we have a cached speaker, reuse it
+        if(last_speaker_seed==speaker_seed && !last_speaker_codes.empty())
+        {
+            //able to proceed, do nothing
+            if(!inputs.quiet && ttsdebugmode==1)
+            {
+                printf("\nReuse speaker ID=%d (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
+            }
+        } else {
+            //generate the voice texture of our new speaker
+            last_speaker_codes.clear();
+            guide_tokens = prepare_guide_tokens(model_ttc,sampletext);
+            prompt_add(prompt_inp, model_ttc, sampletext, false, true);
+            prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
+            if(!inputs.quiet && ttsdebugmode==1)
+            {
+                printf("\nPrepare new speaker (%d input tokens)...", prompt_inp.size());
+            }
+            kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
+            auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
+            if (!evalok) {
+                printf("\nError: TTS prompt batch processing failed\n");
+                output.data = "";
+                output.status = 0;
+                return output;
+            }
+
+            while (n_decode <= n_predict)
+            {
+                float * logits = llama_get_logits(ttc_ctx);
+
+                //use creative settings to generate speakers
+                const int topk = 20;
+                const float temp = 1.2f;
+                llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng);
+
+                //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+                {
+                    if(!guide_tokens.empty())
+                    {
+                        llama_token guide_token = guide_tokens[0];
+                        guide_tokens.erase(guide_tokens.begin());
+                        new_token_id = guide_token; //ensure correct word fragment is used
+                    } else {
+                        n_decode = n_predict; //stop generation
+                    }
+                }
+
+                //this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);
+                last_speaker_codes.push_back(new_token_id);
+
+                // is it an end of generation? -> mark the stream as finished
+                if (llama_token_is_eog(model_ttc, new_token_id) || n_decode >= n_predict) {
+                    break;
+                }
+
+                n_decode += 1;
+                std::vector<llama_token> next = {new_token_id};
+                llama_batch batch = llama_batch_get_one(next.data(), next.size());
+
+                // evaluate the current batch with the transformer model
+                if (llama_decode(ttc_ctx, batch)) {
+                    printf("\nError: TTS code generation failed!\n");
+                    output.data = "";
+                    output.status = 0;
+                    return output;
+                }
+            }
+
+            //trim everything after final <|code_end|>
+            auto it = std::find(last_speaker_codes.rbegin(), last_speaker_codes.rend(), 151670);
+            if (it != last_speaker_codes.rend()) {
+                // Erase everything after the final <|code_end|> token (id 151670), inclusive
+                last_speaker_codes.erase(it.base(), last_speaker_codes.end());
+            }
+            last_speaker_seed = speaker_seed;
+            if(!inputs.quiet && ttsdebugmode==1)
+            {
+                printf("\nNew speaker ID=%d created (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
+                const std::string inp_txt = common_detokenize(ttc_ctx, last_speaker_codes, true);
+                printf("\n%s\n", inp_txt.c_str());
+            }
+        }
+        guide_tokens.clear();
+        llama_kv_cache_clear(ttc_ctx);
+        prompt_init(prompt_inp, model_ttc);
+        prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
+        next_token_uses_guide_token = true;
+    }
+
+    //second pass: add the speaker before the actual prompt
     guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
+    if(speaker_seed > 0)
+    {
+        prompt_clean = sampletext + "<|text_sep|>" + prompt_clean;
+    }
     prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
 
     if(!inputs.quiet)
     {
+        printf(" (%d input words)...", guide_tokens.size());
         printf("\nTTS Generating (%d input tokens)...", prompt_inp.size());
     }
 
-    prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
+    prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
 
+    if(!last_speaker_codes.empty() && speaker_seed > 0) //apply speaker voice output
+    {
+        prompt_add(prompt_inp, last_speaker_codes);
+    }
+
+    if(!inputs.quiet && ttsdebugmode==1)
+    {
+        printf("\nDUMP TTS PROMPT (%d tokens):\n", prompt_inp.size());
+        const std::string inp_txt = common_detokenize(ttc_ctx, prompt_inp, true);
+        printf("\n%s\n", inp_txt.c_str());
+    }
+
+    //create batch with tokens for decoding prompt processing
     llama_kv_cache_clear(ttc_ctx);
     llama_kv_cache_clear(cts_ctx);
     kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
 
     auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
 
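The first pass above is essentially a seed-keyed memo: the expensive speaker-code generation runs only when the requested seed differs from the cached one, or the cache is empty. A minimal sketch of that pattern follows, with hypothetical names standing in for the koboldcpp internals:

#include <vector>

using token = int;
static std::vector<token> cached_codes;   // mirrors last_speaker_codes
static int cached_seed = -999;            // sentinel: nothing cached yet

// generate() is a placeholder for the expensive TTC decode loop above
std::vector<token>& get_speaker_codes(int seed, std::vector<token> (*generate)(int)) {
    if (seed != cached_seed || cached_codes.empty()) {
        cached_codes = generate(seed); // only re-run the model on a new seed
        cached_seed = seed;
    }
    return cached_codes;
}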
@@ -568,28 +714,33 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
     }
 
     // main loop
-    int n_decode = 0;
-    int n_predict = 4096; //max 4096 tokens
-
-    bool next_token_uses_guide_token = true;
+    n_decode = 0;
+    n_predict = 4096; //max 4096 tokens
 
     while (n_decode <= n_predict)
     {
         float * logits = llama_get_logits(ttc_ctx);
 
-        llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,20,1.0,tts_rng);
+        //use predictable settings to generate voice
+        const int topk = 4;
+        const float temp = 0.75f;
+        llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng);
 
         //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
-        if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+        if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
         {
-            llama_token guide_token = guide_tokens[0];
-            guide_tokens.erase(guide_tokens.begin());
-            new_token_id = guide_token; //ensure correct word fragment is used
+            if(!guide_tokens.empty())
+            {
+                llama_token guide_token = guide_tokens[0];
+                guide_tokens.erase(guide_tokens.begin());
+                new_token_id = guide_token; //ensure correct word fragment is used
+            } else {
+                n_decode = n_predict; //end generation
+            }
         }
 
         //this is the token id that always precedes a new word
         next_token_uses_guide_token = (new_token_id == 198);
 
         codes.push_back(new_token_id);
 
         // is it an end of generation? -> mark the stream as finished
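Design note: the speaker pass samples with deliberately loose settings (top-k 20, temperature 1.2) so different seeds yield distinct voice textures, while this main pass is tightened to top-k 4, temperature 0.75 so the spoken content tracks the prompt predictably. Both passes rely on guide tokens: token 198, which per the comment always precedes a new word, re-arms the guide for the next word, and generation is cut short once the guide list is exhausted rather than letting the model free-run.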
@@ -613,7 +764,6 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
     if(!inputs.quiet && ttsdebugmode==1)
     {
         const std::string inp_txt = common_detokenize(ttc_ctx, codes, true);
-
         printf("\nGenerated %d Codes: '%s'\n",codes.size(), inp_txt.c_str());
     }
 
@@ -628,8 +778,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
     if(n_codes<=1)
     {
         printf("\nWarning: TTS vocoder generated nothing!\n");
-        output.data = "";
-        output.status = 0;
+        last_generated_audio = "";
+        output.data = last_generated_audio.c_str();
+        output.status = 1;
         return output;
     }
     kcpp_embd_batch codebatch = kcpp_embd_batch(codes,0,false,true);
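Note the changed failure semantics here: an empty vocoder result used to be reported as an error (status 0); it now returns status 1 with an empty audio string, matching the empty-input path earlier, so callers see a successful-but-silent generation rather than a hard failure.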
@@ -649,8 +800,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
 
     const int n_sr = 24000; // sampling rate
 
-    // zero out first 0.05 seconds
-    for (int i = 0; i < 24000/20; ++i) {
+    // zero out first 0.25 seconds or 0.05 depending on whether it's seeded
+    const int cutout = (speaker_seed>0?(24000/4):(24000/20));
+    for (int i = 0; i < cutout; ++i) {
         audio[i] = 0.0f;
     }
     //add some silence at the end
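Worked out at the 24 kHz sample rate: 24000/4 = 6000 zeroed samples is 0.25 s for seeded runs, versus 24000/20 = 1200 samples (0.05 s) otherwise; the longer cutout presumably masks leading artifacts from the prepended speaker snippet.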
|
Loading…
Add table
Add a link
Reference in a new issue