added support for seeded tts voices

This commit is contained in:
Concedo 2025-01-13 19:11:34 +08:00
parent b3de1598e7
commit 62e33d0bf7
2 changed files with 211 additions and 44 deletions

View file

@ -622,6 +622,16 @@ def bring_terminal_to_foreground():
ctypes.windll.user32.ShowWindow(ctypes.windll.kernel32.GetConsoleWindow(), 9)
ctypes.windll.user32.SetForegroundWindow(ctypes.windll.kernel32.GetConsoleWindow())
def simple_lcg_hash(input_string): #turns any string into a number between 10000 and 99999
a = 1664525
c = 1013904223
m = 89999 # Modulo
hash_value = 25343
for char in input_string:
hash_value = (a * hash_value + ord(char) + c) % m
hash_value += 10000
return hash_value
def string_has_overlap(str_a, str_b, maxcheck):
max_overlap = min(maxcheck, len(str_a), len(str_b))
for i in range(1, max_overlap + 1):
@ -1331,11 +1341,13 @@ def tts_load_model(ttc_model_filename,cts_model_filename):
def tts_generate(genparams):
global args
is_quiet = True if (args.quiet or args.debugmode == -1) else False
prompt = genparams.get("input", "")
prompt = genparams.get("input", genparams.get("text", ""))
prompt = prompt.strip()
voicestr = genparams.get("voice", genparams.get("speaker_wav", ""))
voice = simple_lcg_hash(voicestr) if voicestr else 1
inputs = tts_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.speaker_seed = 0
inputs.speaker_seed = voice
inputs.audio_seed = 0
inputs.quiet = is_quiet
ret = handle.tts_generate(inputs)
@ -2296,6 +2308,9 @@ Enter Prompt:<br>
elif self.path.endswith('/sdapi/v1/upscalers'):
response_body = (json.dumps([]).encode())
elif self.path.endswith(('/speakers_list')): #xtts compatible
response_body = (json.dumps(["kobo","bean","corn","spicy","lime","fire","metal","potato"]).encode()) #some random voices for them to enjoy
elif self.path.endswith(('/api/tags')): #ollama compatible
response_body = (json.dumps({"models":[{"name":"koboldcpp","model":friendlymodelname,"modified_at":"2024-07-19T15:26:55.6122841+08:00","size":394998579,"digest":"b5dc5e784f2a3ee1582373093acf69a2f4e2ac1710b253a001712b86a61f88bb","details":{"parent_model":"","format":"gguf","family":"koboldcpp","families":["koboldcpp"],"parameter_size":"128M","quantization_level":"Q4_0"}}]}).encode())
@ -2671,7 +2686,7 @@ Enter Prompt:<br>
if self.path.endswith('/api/extra/transcribe') or self.path.endswith('/v1/audio/transcriptions'):
is_transcribe = True
if self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech'):
if self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech') or self.path.endswith('/tts_to_audio'):
is_tts = True
if is_imggen or is_transcribe or is_tts or api_format > 0:

View file

@ -369,17 +369,56 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_model * model,
// Add the last part
std::string current_word = str.substr(start);
auto tmp = common_tokenize(model, current_word, false, true);
result.push_back(tmp[0]);
if(current_word!="")
{
auto tmp = common_tokenize(model, current_word, false, true);
if(tmp.size()>0){
result.push_back(tmp[0]);
}
}
return result;
}
std::string trim_words(const std::string& input, const std::string& separator, size_t maxWords) {
// Split the input string by the separator
std::vector<std::string> words;
size_t start = 0, end;
while ((end = input.find(separator, start)) != std::string::npos) {
std::string last = input.substr(start, end - start);
if (last != "") {
words.push_back(last);
}
start = end + separator.length();
}
std::string last = input.substr(start);
if(last!="")
{
words.push_back(last); // Add the last word
}
// Ensure no more than maxWords are kept
if (words.size() > maxWords) {
words.resize(maxWords);
}
// Reconstruct the string with the separator
std::ostringstream result;
for (size_t i = 0; i < words.size(); ++i) {
if (i > 0) result << separator;
result << words[i];
}
return result.str();
}
static llama_context * ttc_ctx = nullptr; //text to codes ctx
static llama_context * cts_ctx = nullptr; //codes to speech
static int ttsdebugmode = 0;
static std::string ttsplatformenv, ttsdeviceenv, ttsvulkandeviceenv;
static std::string last_generated_audio = "";
static std::vector<llama_token> last_speaker_codes; //will store cached speaker
static int last_speaker_seed = -999;
bool ttstype_load_model(const tts_load_model_inputs inputs)
{
@ -484,14 +523,11 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
const llama_model * model_cts = &(cts_ctx->model);
const int ttc_n_vocab = llama_n_vocab(model_ttc);
std::string prompt = inputs.prompt;
if(!inputs.quiet)
{
printf("\nTTS Generating... ");
}
const std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is";
// process prompt and generate voice codes
llama_kv_cache_clear(ttc_ctx);
llama_kv_cache_clear(cts_ctx);
std::vector<llama_token> prompt_inp;
prompt_init(prompt_inp, model_ttc);
prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
@ -501,39 +537,38 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
if (speaker_seed <= 0 || speaker_seed==0xFFFFFFFF)
{
speaker_seed = (((uint32_t)time(NULL)) % 1000000u);
if(ttsdebugmode==1)
{
printf("\nUsing Speaker Seed: %d", speaker_seed);
}
}
if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
{
audio_seed = (((uint32_t)time(NULL)) % 1000000u);
if(ttsdebugmode==1)
{
printf("\nUsing Audio Seed: %d", audio_seed);
}
}
if(ttsdebugmode==1)
{
printf("\nUsing Speaker Seed: %d", speaker_seed);
printf("\nUsing Audio Seed: %d", audio_seed);
}
std::mt19937 tts_rng(audio_seed);
std::mt19937 speaker_rng(speaker_seed);
//add the speaker based on the seed
if(speaker_seed>0)
{
std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is<|text_sep|>";
}
int n_decode = 0;
int n_predict = 2048; //will be updated later
bool next_token_uses_guide_token = true;
// convert the input text into the necessary format expected by OuteTTS
std::string prompt_clean = process_text(prompt);
//further clean it by keeping only the last 300 words
prompt_clean = trim_words(prompt_clean,"<|text_sep|>",300);
if(prompt_clean.size()==0)
{
//no input
if(!inputs.quiet)
{
printf("\nTTS sent empty input.\n");
output.data = "";
last_generated_audio = "";
output.data = last_generated_audio.c_str();
output.status = 1;
return output;
}
@ -544,19 +579,130 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
printf("\nInput: %s\n", prompt_clean.c_str());
}
//2 passes. first pass, we generate the speaker voice if required, then cache it for reuse
//second pass, we use the speaker snipper to align output voice to match the desired speaker
if(speaker_seed>0) //first pass
{
//if we have a cached speaker, reuse it
if(last_speaker_seed==speaker_seed && !last_speaker_codes.empty())
{
//able to proceed, do nothing
if(!inputs.quiet && ttsdebugmode==1)
{
printf("\nReuse speaker ID=%d (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
}
} else {
//generate the voice texture of our new speaker
last_speaker_codes.clear();
guide_tokens = prepare_guide_tokens(model_ttc,sampletext);
prompt_add(prompt_inp, model_ttc, sampletext, false, true);
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
if(!inputs.quiet && ttsdebugmode==1)
{
printf("\nPrepare new speaker (%d input tokens)...", prompt_inp.size());
}
kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
if (!evalok) {
printf("\nError: TTS prompt batch processing failed\n");
output.data = "";
output.status = 0;
return output;
}
while (n_decode <= n_predict)
{
float * logits = llama_get_logits(ttc_ctx);
//use creative settings to generate speakers
const int topk = 20;
const float temp = 1.2f;
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
{
if(!guide_tokens.empty())
{
llama_token guide_token = guide_tokens[0];
guide_tokens.erase(guide_tokens.begin());
new_token_id = guide_token; //ensure correct word fragment is used
} else {
n_decode = n_predict; //stop generation
}
}
//this is the token id that always precedes a new word
next_token_uses_guide_token = (new_token_id == 198);
last_speaker_codes.push_back(new_token_id);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model_ttc, new_token_id) || n_decode >= n_predict) {
break;
}
n_decode += 1;
std::vector<llama_token> next = {new_token_id};
llama_batch batch = llama_batch_get_one(next.data(), next.size());
// evaluate the current batch with the transformer model
if (llama_decode(ttc_ctx, batch)) {
printf("\nError: TTS code generation failed!\n");
output.data = "";
output.status = 0;
return output;
}
}
//trim everything after final <|code_end|>
auto it = std::find(last_speaker_codes.rbegin(), last_speaker_codes.rend(), 151670);
if (it != last_speaker_codes.rend()) {
// Erase elements after the found 999 (inclusive)
last_speaker_codes.erase(it.base(), last_speaker_codes.end());
}
last_speaker_seed = speaker_seed;
if(!inputs.quiet && ttsdebugmode==1)
{
printf("\nNew speaker ID=%d created (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
const std::string inp_txt = common_detokenize(ttc_ctx, last_speaker_codes, true);
printf("\n%s\n", inp_txt.c_str());
}
}
guide_tokens.clear();
llama_kv_cache_clear(ttc_ctx);
prompt_init(prompt_inp, model_ttc);
prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
next_token_uses_guide_token = true;
}
//second pass: add the speaker before the actual prompt
guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
if(speaker_seed > 0)
{
prompt_clean = sampletext + "<|text_sep|>" + prompt_clean;
}
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
if(!inputs.quiet)
{
printf(" (%d input words)...", guide_tokens.size());
printf("\nTTS Generating (%d input tokens)...", prompt_inp.size());
}
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
if(!last_speaker_codes.empty() && speaker_seed > 0) //apply speaker voice output
{
prompt_add(prompt_inp, last_speaker_codes);
}
if(!inputs.quiet && ttsdebugmode==1)
{
printf("\nDUMP TTS PROMPT (%d tokens):\n", prompt_inp.size());
const std::string inp_txt = common_detokenize(ttc_ctx, prompt_inp, true);
printf("\n%s\n", inp_txt.c_str());
}
//create batch with tokens for decoding prompt processing
llama_kv_cache_clear(ttc_ctx);
llama_kv_cache_clear(cts_ctx);
kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
@ -568,28 +714,33 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
}
// main loop
int n_decode = 0;
int n_predict = 4096; //max 4096 tokens
bool next_token_uses_guide_token = true;
n_decode = 0;
n_predict = 4096; //max 4096 tokens
while (n_decode <= n_predict)
{
float * logits = llama_get_logits(ttc_ctx);
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,20,1.0,tts_rng);
//use predictable settings to generate voice
const int topk = 4;
const float temp = 0.75f;
llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
{
llama_token guide_token = guide_tokens[0];
guide_tokens.erase(guide_tokens.begin());
new_token_id = guide_token; //ensure correct word fragment is used
if(!guide_tokens.empty())
{
llama_token guide_token = guide_tokens[0];
guide_tokens.erase(guide_tokens.begin());
new_token_id = guide_token; //ensure correct word fragment is used
} else {
n_decode = n_predict; //end generation
}
}
//this is the token id that always precedes a new word
next_token_uses_guide_token = (new_token_id == 198);
codes.push_back(new_token_id);
// is it an end of generation? -> mark the stream as finished
@ -613,7 +764,6 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
if(!inputs.quiet && ttsdebugmode==1)
{
const std::string inp_txt = common_detokenize(ttc_ctx, codes, true);
printf("\nGenerated %d Codes: '%s'\n",codes.size(), inp_txt.c_str());
}
@ -628,8 +778,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
if(n_codes<=1)
{
printf("\nWarning: TTS vocoder generated nothing!\n");
output.data = "";
output.status = 0;
last_generated_audio = "";
output.data = last_generated_audio.c_str();
output.status = 1;
return output;
}
kcpp_embd_batch codebatch = kcpp_embd_batch(codes,0,false,true);
@ -649,8 +800,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
const int n_sr = 24000; // sampling rate
// zero out first 0.05 seconds
for (int i = 0; i < 24000/20; ++i) {
// zero out first 0.25 seconds or 0.05 depending on whether its seeded
const int cutout = (speaker_seed>0?(24000/4):(24000/20));
for (int i = 0; i < cutout; ++i) {
audio[i] = 0.0f;
}
//add some silence at the end