From 62e33d0bf7adf5ff3c6c5957a20e4e4842d1fb6b Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Mon, 13 Jan 2025 19:11:34 +0800
Subject: [PATCH] added support for seeded tts voices
---
koboldcpp.py | 21 +++-
otherarch/tts_adapter.cpp | 234 +++++++++++++++++++++++++++++++-------
2 files changed, 211 insertions(+), 44 deletions(-)
diff --git a/koboldcpp.py b/koboldcpp.py
index 6f878191d..e70a28139 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -622,6 +622,16 @@ def bring_terminal_to_foreground():
ctypes.windll.user32.ShowWindow(ctypes.windll.kernel32.GetConsoleWindow(), 9)
ctypes.windll.user32.SetForegroundWindow(ctypes.windll.kernel32.GetConsoleWindow())
+def simple_lcg_hash(input_string): #turns any string into a number between 10000 and 99998
+ a = 1664525
+ c = 1013904223
+ m = 89999 # Modulo
+ hash_value = 25343
+ for char in input_string:
+ hash_value = (a * hash_value + ord(char) + c) % m
+ hash_value += 10000
+ return hash_value
+
def string_has_overlap(str_a, str_b, maxcheck):
max_overlap = min(maxcheck, len(str_a), len(str_b))
for i in range(1, max_overlap + 1):
@@ -1331,11 +1341,13 @@ def tts_load_model(ttc_model_filename,cts_model_filename):
def tts_generate(genparams):
global args
is_quiet = True if (args.quiet or args.debugmode == -1) else False
- prompt = genparams.get("input", "")
+ prompt = genparams.get("input", genparams.get("text", ""))
prompt = prompt.strip()
+ voicestr = genparams.get("voice", genparams.get("speaker_wav", ""))
+ voice = simple_lcg_hash(voicestr) if voicestr else 1
inputs = tts_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
- inputs.speaker_seed = 0
+ inputs.speaker_seed = voice
inputs.audio_seed = 0
inputs.quiet = is_quiet
ret = handle.tts_generate(inputs)
@@ -2296,6 +2308,9 @@ Enter Prompt:
elif self.path.endswith('/sdapi/v1/upscalers'):
response_body = (json.dumps([]).encode())
+ elif self.path.endswith(('/speakers_list')): #xtts compatible
+ response_body = (json.dumps(["kobo","bean","corn","spicy","lime","fire","metal","potato"]).encode()) #some random voices for them to enjoy
+
elif self.path.endswith(('/api/tags')): #ollama compatible
response_body = (json.dumps({"models":[{"name":"koboldcpp","model":friendlymodelname,"modified_at":"2024-07-19T15:26:55.6122841+08:00","size":394998579,"digest":"b5dc5e784f2a3ee1582373093acf69a2f4e2ac1710b253a001712b86a61f88bb","details":{"parent_model":"","format":"gguf","family":"koboldcpp","families":["koboldcpp"],"parameter_size":"128M","quantization_level":"Q4_0"}}]}).encode())
@@ -2671,7 +2686,7 @@ Enter Prompt:
if self.path.endswith('/api/extra/transcribe') or self.path.endswith('/v1/audio/transcriptions'):
is_transcribe = True
- if self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech'):
+ if self.path.endswith('/api/extra/tts') or self.path.endswith('/v1/audio/speech') or self.path.endswith('/tts_to_audio'):
is_tts = True
if is_imggen or is_transcribe or is_tts or api_format > 0:
diff --git a/otherarch/tts_adapter.cpp b/otherarch/tts_adapter.cpp
index faef59d9c..4311346ac 100644
--- a/otherarch/tts_adapter.cpp
+++ b/otherarch/tts_adapter.cpp
@@ -369,17 +369,56 @@ static std::vector prepare_guide_tokens(const llama_model * model,
// Add the last part
std::string current_word = str.substr(start);
- auto tmp = common_tokenize(model, current_word, false, true);
- result.push_back(tmp[0]);
+ if(current_word!="")
+ {
+ auto tmp = common_tokenize(model, current_word, false, true);
+ if(tmp.size()>0){
+ result.push_back(tmp[0]);
+ }
+ }
return result;
}
+std::string trim_words(const std::string& input, const std::string& separator, size_t maxWords) {
+ // Split the input string by the separator
+ std::vector words;
+ size_t start = 0, end;
+ while ((end = input.find(separator, start)) != std::string::npos) {
+ std::string last = input.substr(start, end - start);
+ if (last != "") {
+ words.push_back(last);
+ }
+ start = end + separator.length();
+ }
+ std::string last = input.substr(start);
+ if(last!="")
+ {
+ words.push_back(last); // Add the last word
+ }
+
+ // Ensure no more than maxWords are kept
+ if (words.size() > maxWords) {
+ words.resize(maxWords);
+ }
+
+ // Reconstruct the string with the separator
+ std::ostringstream result;
+ for (size_t i = 0; i < words.size(); ++i) {
+ if (i > 0) result << separator;
+ result << words[i];
+ }
+
+ return result.str();
+}
+
static llama_context * ttc_ctx = nullptr; //text to codes ctx
static llama_context * cts_ctx = nullptr; //codes to speech
static int ttsdebugmode = 0;
static std::string ttsplatformenv, ttsdeviceenv, ttsvulkandeviceenv;
static std::string last_generated_audio = "";
+static std::vector last_speaker_codes; //will store cached speaker
+static int last_speaker_seed = -999;
bool ttstype_load_model(const tts_load_model_inputs inputs)
{
@@ -484,14 +523,11 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
const llama_model * model_cts = &(cts_ctx->model);
const int ttc_n_vocab = llama_n_vocab(model_ttc);
std::string prompt = inputs.prompt;
-
- if(!inputs.quiet)
- {
- printf("\nTTS Generating... ");
- }
+ const std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is";
// process prompt and generate voice codes
-
+ llama_kv_cache_clear(ttc_ctx);
+ llama_kv_cache_clear(cts_ctx);
std::vector prompt_inp;
prompt_init(prompt_inp, model_ttc);
prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
@@ -501,39 +537,38 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
if (speaker_seed <= 0 || speaker_seed==0xFFFFFFFF)
{
speaker_seed = (((uint32_t)time(NULL)) % 1000000u);
- if(ttsdebugmode==1)
- {
- printf("\nUsing Speaker Seed: %d", speaker_seed);
- }
}
if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
{
audio_seed = (((uint32_t)time(NULL)) % 1000000u);
- if(ttsdebugmode==1)
- {
- printf("\nUsing Audio Seed: %d", audio_seed);
- }
+ }
+ if(ttsdebugmode==1)
+ {
+ printf("\nUsing Speaker Seed: %d", speaker_seed);
+ printf("\nUsing Audio Seed: %d", audio_seed);
}
std::mt19937 tts_rng(audio_seed);
std::mt19937 speaker_rng(speaker_seed);
- //add the speaker based on the seed
- if(speaker_seed>0)
- {
- std::string sampletext = "but<|text_sep|>that<|text_sep|>is<|text_sep|>what<|text_sep|>it<|text_sep|>is<|text_sep|>";
- }
+ int n_decode = 0;
+ int n_predict = 2048; //will be updated later
+ bool next_token_uses_guide_token = true;
// convert the input text into the necessary format expected by OuteTTS
std::string prompt_clean = process_text(prompt);
+ //further clean it by keeping only the last 300 words
+ prompt_clean = trim_words(prompt_clean,"<|text_sep|>",300);
+
if(prompt_clean.size()==0)
{
//no input
if(!inputs.quiet)
{
printf("\nTTS sent empty input.\n");
- output.data = "";
+ last_generated_audio = "";
+ output.data = last_generated_audio.c_str();
output.status = 1;
return output;
}
@@ -544,19 +579,130 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
printf("\nInput: %s\n", prompt_clean.c_str());
}
+ //2 passes. first pass, we generate the speaker voice if required, then cache it for reuse
+ //second pass, we use the speaker snippet to align output voice to match the desired speaker
+ if(speaker_seed>0) //first pass
+ {
+ //if we have a cached speaker, reuse it
+ if(last_speaker_seed==speaker_seed && !last_speaker_codes.empty())
+ {
+ //able to proceed, do nothing
+ if(!inputs.quiet && ttsdebugmode==1)
+ {
+ printf("\nReuse speaker ID=%d (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
+ }
+ } else {
+ //generate the voice texture of our new speaker
+ last_speaker_codes.clear();
+ guide_tokens = prepare_guide_tokens(model_ttc,sampletext);
+ prompt_add(prompt_inp, model_ttc, sampletext, false, true);
+ prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
+ if(!inputs.quiet && ttsdebugmode==1)
+ {
+ printf("\nPrepare new speaker (%d input tokens)...", prompt_inp.size());
+ }
+ kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
+ auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
+ if (!evalok) {
+ printf("\nError: TTS prompt batch processing failed\n");
+ output.data = "";
+ output.status = 0;
+ return output;
+ }
+
+ while (n_decode <= n_predict)
+ {
+ float * logits = llama_get_logits(ttc_ctx);
+
+ //use creative settings to generate speakers
+ const int topk = 20;
+ const float temp = 1.2f;
+ llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,speaker_rng);
+
+ //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+ if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+ {
+ if(!guide_tokens.empty())
+ {
+ llama_token guide_token = guide_tokens[0];
+ guide_tokens.erase(guide_tokens.begin());
+ new_token_id = guide_token; //ensure correct word fragment is used
+ } else {
+ n_decode = n_predict; //stop generation
+ }
+ }
+
+ //this is the token id that always precedes a new word
+ next_token_uses_guide_token = (new_token_id == 198);
+ last_speaker_codes.push_back(new_token_id);
+
+ // is it an end of generation? -> mark the stream as finished
+ if (llama_token_is_eog(model_ttc, new_token_id) || n_decode >= n_predict) {
+ break;
+ }
+
+ n_decode += 1;
+ std::vector next = {new_token_id};
+ llama_batch batch = llama_batch_get_one(next.data(), next.size());
+
+ // evaluate the current batch with the transformer model
+ if (llama_decode(ttc_ctx, batch)) {
+ printf("\nError: TTS code generation failed!\n");
+ output.data = "";
+ output.status = 0;
+ return output;
+ }
+ }
+
+ //trim everything after final <|code_end|>
+ auto it = std::find(last_speaker_codes.rbegin(), last_speaker_codes.rend(), 151670);
+ if (it != last_speaker_codes.rend()) {
+ // Erase elements after the found <|code_end|> token (151670), inclusive
+ last_speaker_codes.erase(it.base(), last_speaker_codes.end());
+ }
+ last_speaker_seed = speaker_seed;
+ if(!inputs.quiet && ttsdebugmode==1)
+ {
+ printf("\nNew speaker ID=%d created (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
+ const std::string inp_txt = common_detokenize(ttc_ctx, last_speaker_codes, true);
+ printf("\n%s\n", inp_txt.c_str());
+ }
+ }
+ guide_tokens.clear();
+ llama_kv_cache_clear(ttc_ctx);
+ prompt_init(prompt_inp, model_ttc);
+ prompt_add(prompt_inp, model_ttc, "<|text_start|>", false, true);
+ next_token_uses_guide_token = true;
+ }
+
+ //second pass: add the speaker before the actual prompt
guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
+ if(speaker_seed > 0)
+ {
+ prompt_clean = sampletext + "<|text_sep|>" + prompt_clean;
+ }
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
if(!inputs.quiet)
{
- printf(" (%d input words)...", guide_tokens.size());
+ printf("\nTTS Generating (%d input tokens)...", prompt_inp.size());
}
- prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true);
+ prompt_add(prompt_inp, model_ttc, "<|text_end|>\n<|audio_start|>\n", false, true);
+
+ if(!last_speaker_codes.empty() && speaker_seed > 0) //apply speaker voice output
+ {
+ prompt_add(prompt_inp, last_speaker_codes);
+ }
+
+ if(!inputs.quiet && ttsdebugmode==1)
+ {
+ printf("\nDUMP TTS PROMPT (%d tokens):\n", prompt_inp.size());
+ const std::string inp_txt = common_detokenize(ttc_ctx, prompt_inp, true);
+ printf("\n%s\n", inp_txt.c_str());
+ }
//create batch with tokens for decoding prompt processing
- llama_kv_cache_clear(ttc_ctx);
- llama_kv_cache_clear(cts_ctx);
kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, true);
auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
@@ -568,28 +714,33 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
}
// main loop
- int n_decode = 0;
- int n_predict = 4096; //max 4096 tokens
-
- bool next_token_uses_guide_token = true;
+ n_decode = 0;
+ n_predict = 4096; //max 4096 tokens
while (n_decode <= n_predict)
{
float * logits = llama_get_logits(ttc_ctx);
- llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,20,1.0,tts_rng);
+ //use predictable settings to generate voice
+ const int topk = 4;
+ const float temp = 0.75f;
+ llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,topk,temp,tts_rng);
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
- if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
+ if(next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
{
- llama_token guide_token = guide_tokens[0];
- guide_tokens.erase(guide_tokens.begin());
- new_token_id = guide_token; //ensure correct word fragment is used
+ if(!guide_tokens.empty())
+ {
+ llama_token guide_token = guide_tokens[0];
+ guide_tokens.erase(guide_tokens.begin());
+ new_token_id = guide_token; //ensure correct word fragment is used
+ } else {
+ n_decode = n_predict; //end generation
+ }
}
//this is the token id that always precedes a new word
next_token_uses_guide_token = (new_token_id == 198);
-
codes.push_back(new_token_id);
// is it an end of generation? -> mark the stream as finished
@@ -613,7 +764,6 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
if(!inputs.quiet && ttsdebugmode==1)
{
const std::string inp_txt = common_detokenize(ttc_ctx, codes, true);
-
printf("\nGenerated %d Codes: '%s'\n",codes.size(), inp_txt.c_str());
}
@@ -628,8 +778,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
if(n_codes<=1)
{
printf("\nWarning: TTS vocoder generated nothing!\n");
- output.data = "";
- output.status = 0;
+ last_generated_audio = "";
+ output.data = last_generated_audio.c_str();
+ output.status = 1;
return output;
}
kcpp_embd_batch codebatch = kcpp_embd_batch(codes,0,false,true);
@@ -649,8 +800,9 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
const int n_sr = 24000; // sampling rate
- // zero out first 0.05 seconds
- for (int i = 0; i < 24000/20; ++i) {
+ // zero out first 0.25 seconds or 0.05 depending on whether its seeded
+ const int cutout = (speaker_seed>0?(24000/4):(24000/20));
+ for (int i = 0; i < cutout; ++i) {
audio[i] = 0.0f;
}
//add some silence at the end