mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 08:00:25 +00:00
add some default voices for qwen3tts
This commit is contained in:
parent
2db018a1d7
commit
5536fb29f2
7 changed files with 209 additions and 130 deletions
|
|
@ -4090,6 +4090,7 @@ Current version indicated by LITEVER below.
|
|||
const koboldcpp_savedata_save_endpoint = "/api/extra/data/save";
|
||||
const koboldcpp_savedata_load_endpoint = "/api/extra/data/load";
|
||||
const koboldcpp_mcp_endpoint = "/mcp";
|
||||
const koboldcpp_voices_endpoint = "/speakers_list";
|
||||
|
||||
const oai_models_endpoint = "/models";
|
||||
const oai_submit_endpoint = "/completions";
|
||||
|
|
@ -4317,7 +4318,6 @@ Current version indicated by LITEVER below.
|
|||
var websearch_in_progress = false;
|
||||
var kcpp_tts_json = "";
|
||||
var avoidwelcome = false;
|
||||
var voicecloneb64 = "";
|
||||
|
||||
var localsettings = {
|
||||
my_api_key: "0000000000", //put here so it can be saved and loaded in persistent mode
|
||||
|
|
@ -5481,9 +5481,8 @@ Current version indicated by LITEVER below.
|
|||
indexeddb_load("savedcustomcss",""),
|
||||
indexeddb_load("savedusermod",""),
|
||||
indexeddb_load("usermodprops",""),
|
||||
indexeddb_load("samplerpresets",""),
|
||||
indexeddb_load("voiceclone","")
|
||||
]).then(([loadedsettingsjson, loadedstorycompressed, loadedbackgroundimg, currcss, currmod, modpropsstr, loadedsamplerpresetsjson, loadedvoiceclone]) => {
|
||||
indexeddb_load("samplerpresets","")
|
||||
]).then(([loadedsettingsjson, loadedstorycompressed, loadedbackgroundimg, currcss, currmod, modpropsstr, loadedsamplerpresetsjson]) => {
|
||||
try
|
||||
{
|
||||
if (loadedsettingsjson != null && loadedsettingsjson != "" && loadedstorycompressed != null && loadedstorycompressed != "") {
|
||||
|
|
@ -5531,10 +5530,6 @@ Current version indicated by LITEVER below.
|
|||
document.getElementById("enhancedchatinterface").classList.add("transparentbg");
|
||||
document.getElementById("enhancedchatinterface_inner").classList.add("transparentbg");
|
||||
}
|
||||
if(loadedvoiceclone && loadedvoiceclone!="")
|
||||
{
|
||||
voicecloneb64 = loadedvoiceclone;
|
||||
}
|
||||
loadok = true;
|
||||
} else {
|
||||
console.log("Skipped missing local save");
|
||||
|
|
@ -7877,6 +7872,10 @@ Current version indicated by LITEVER below.
|
|||
{
|
||||
fetch_xtts_voices(true,localsettings.tts_mode==XTTS_ID);
|
||||
}
|
||||
else if(localsettings.tts_mode==KCPP_TTS_ID)
|
||||
{
|
||||
fetch_kcpp_voices();
|
||||
}
|
||||
if(localsettings.generate_images_mode==2)
|
||||
{
|
||||
connect_to_a1111(true);
|
||||
|
|
@ -17391,35 +17390,6 @@ Current version indicated by LITEVER below.
|
|||
},true,true);
|
||||
}
|
||||
|
||||
function set_voice_clone()
|
||||
{
|
||||
let finput = document.getElementById('addimgfileinput');
|
||||
finput.click();
|
||||
finput.onchange = (event) => {
|
||||
if (event.target.files.length > 0 && event.target.files[0]) {
|
||||
const file = event.target.files[0];
|
||||
const fname = file.name;
|
||||
const reader = new FileReader();
|
||||
reader.onload = function(audio) {
|
||||
let origAudio = audio.target.result;
|
||||
convertAudioToCompressedBase64(origAudio,(newAudio,duration)=>{
|
||||
indexeddb_save("voiceclone", newAudio);
|
||||
voicecloneb64 = newAudio;
|
||||
adjust_kcpptts_controls();
|
||||
},64);
|
||||
}
|
||||
reader.readAsDataURL(file);
|
||||
}
|
||||
finput.value = "";
|
||||
};
|
||||
}
|
||||
function clear_voice_clone()
|
||||
{
|
||||
indexeddb_save("voiceclone", "");
|
||||
voicecloneb64 = "";
|
||||
adjust_kcpptts_controls();
|
||||
}
|
||||
|
||||
function restore_retried_text()
|
||||
{
|
||||
if(retry_in_progress)
|
||||
|
|
@ -17552,8 +17522,6 @@ Current version indicated by LITEVER below.
|
|||
indexeddb_save("savedusermod","");
|
||||
indexeddb_save("usermodprops","");
|
||||
indexeddb_save("savedcustomcss", "");
|
||||
indexeddb_save("voiceclone", "");
|
||||
voicecloneb64 = "";
|
||||
let styleElement = document.getElementById('custom_css');
|
||||
styleElement.innerHTML = "";
|
||||
show_welcome_panel();
|
||||
|
|
@ -18475,6 +18443,28 @@ Current version indicated by LITEVER below.
|
|||
}
|
||||
}
|
||||
|
||||
function fetch_kcpp_voices()
|
||||
{
|
||||
fetch(apply_proxy_url(custom_kobold_endpoint + koboldcpp_voices_endpoint), {
|
||||
method: 'GET', // or 'PUT'
|
||||
headers: get_kobold_header(),
|
||||
})
|
||||
.then(x => x.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
let dropdown = document.getElementById("kcpp_tts_voice");
|
||||
let selectionhtml = ``;
|
||||
for (var i = 0; i < data.length; ++i) {
|
||||
// Check for XTTS voices if set
|
||||
let sel = (localsettings.kcpp_tts_voice!=""&&localsettings.kcpp_tts_voice==data[i]);
|
||||
selectionhtml += `<option value="` + data[i] + `"`+(sel?" selected":"")+`>`+data[i]+`</option>`;
|
||||
}
|
||||
selectionhtml += `<option value="custom">custom</option><option value="voicejson">voicejson</option>`;
|
||||
dropdown.innerHTML = selectionhtml;
|
||||
}).catch((error) => {
|
||||
});
|
||||
}
|
||||
|
||||
function manual_tts()
|
||||
{
|
||||
let ssval = localsettings.tts_mode;
|
||||
|
|
@ -18515,7 +18505,7 @@ Current version indicated by LITEVER below.
|
|||
if (userinput != null && userinput!="" && ssval > 0) {
|
||||
tts_speak(userinput,downloadtts,embedtts,false);
|
||||
}
|
||||
},true);
|
||||
},true,true);
|
||||
}
|
||||
|
||||
function test_tts()
|
||||
|
|
@ -18568,6 +18558,7 @@ Current version indicated by LITEVER below.
|
|||
}
|
||||
else if(selectedTTS == KCPP_TTS_ID) {
|
||||
document.getElementById("kcpp_tts_container").classList.remove("hidden");
|
||||
fetch_kcpp_voices();
|
||||
if(is_using_kcpp_with_tts())
|
||||
{
|
||||
document.getElementById("nokcpptts").classList.add("hidden");
|
||||
|
|
@ -18636,19 +18627,6 @@ Current version indicated by LITEVER below.
|
|||
document.getElementById("kcpp_tts_voice_json").classList.add("hidden");
|
||||
}
|
||||
|
||||
document.getElementById("kcpp_tts_voice_clone").classList.add("hidden");
|
||||
document.getElementById("kcpp_tts_voice_clone_clear").classList.add("hidden");
|
||||
if (document.getElementById("kcpp_tts_voice").value == "voiceclone") {
|
||||
if(voicecloneb64=="")
|
||||
{
|
||||
document.getElementById("kcpp_tts_voice_clone").classList.remove("hidden");
|
||||
}
|
||||
else
|
||||
{
|
||||
document.getElementById("kcpp_tts_voice_clone_clear").classList.remove("hidden");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Update set_xtts_url to use the new fetch_rvc_voices function
|
||||
|
|
@ -18830,7 +18808,6 @@ Current version indicated by LITEVER below.
|
|||
} else {
|
||||
sub_endpt = apply_proxy_url(custom_kobold_endpoint + koboldcpp_tts_endpoint);
|
||||
let is_voicejson = (document.getElementById("kcpp_tts_voice").value == "voicejson");
|
||||
let is_voiceclone = (document.getElementById("kcpp_tts_voice").value == "voiceclone");
|
||||
let is_custom = (document.getElementById("kcpp_tts_voice").value == "custom");
|
||||
payload =
|
||||
{
|
||||
|
|
@ -18841,10 +18818,6 @@ Current version indicated by LITEVER below.
|
|||
{
|
||||
payload.speaker_json = vcjson;
|
||||
}
|
||||
if(is_voiceclone && voicecloneb64!="")
|
||||
{
|
||||
payload.reference_audio = voicecloneb64;
|
||||
}
|
||||
ttsheaders = get_kobold_header();
|
||||
}
|
||||
|
||||
|
|
@ -29849,14 +29822,11 @@ Current version indicated by LITEVER below.
|
|||
<option value="chatty">chatty</option>
|
||||
<option value="custom">custom</option>
|
||||
<option value="voicejson">voicejson</option>
|
||||
<option value="voiceclone">voiceclone</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<input type="text" value="" placeholder="(Name)" id="kcpp_tts_voice_custom" style="margin-left:3px; width:56px;">
|
||||
<button id="kcpp_tts_voice_json" type="button" class="btn btn-primary" style="margin-left:3px; width:56px;" onclick="set_voice_json()">Setup</button>
|
||||
<button id="kcpp_tts_voice_clone" type="button" class="btn btn-primary" style="margin-left:3px; width:56px;" onclick="set_voice_clone()">Load</button>
|
||||
<button id="kcpp_tts_voice_clone_clear" type="button" class="btn btn-primary bg_red" style="margin-left:3px; width:56px;" onclick="clear_voice_clone()">Clear</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
59
koboldcpp.py
59
koboldcpp.py
|
|
@ -115,6 +115,7 @@ importvars_in_progress = False
|
|||
has_multiplayer = False
|
||||
has_audio_support = False
|
||||
has_vision_support = False
|
||||
has_whisper = False
|
||||
cached_chat_template = None
|
||||
savedata_obj = None
|
||||
mcp_connections = [] #every element is linked to one mcp source, contains obj {"client":obj, "tools":[]}
|
||||
|
|
@ -136,6 +137,8 @@ embedded_kcpp_docs_gz = None
|
|||
embedded_kcpp_sdui = None
|
||||
embedded_kcpp_sdui_gz = None
|
||||
embedded_lcpp_ui_gz = None
|
||||
voicebank = {}
|
||||
voicelist = ["kobo","cheery","sleepy","shouty","chatty"]
|
||||
sslvalid = False
|
||||
nocertify = False
|
||||
start_time = time.time()
|
||||
|
|
@ -2245,15 +2248,14 @@ def tts_prepare_voice_json(jsonstr):
|
|||
return None
|
||||
|
||||
def tts_generate(genparams):
|
||||
global args
|
||||
global args, voicebank, voicelist
|
||||
prompt = genparams.get("input", genparams.get("text", ""))
|
||||
prompt = prompt.strip()
|
||||
voice = 1
|
||||
speaker_json = tts_prepare_voice_json(genparams.get("speaker_json","")) #handle custom json voices
|
||||
reference_audio = genparams.get("reference_audio","") #for cloned voices in qwen3tts
|
||||
voicestr = genparams.get("voice", genparams.get("speaker_wav", ""))
|
||||
oai_voicemap = ["alloy","onyx","echo","nova","shimmer"] # map to kcpp defaults
|
||||
voice_mapping = ["kobo","cheery","sleepy","shouty","chatty"]
|
||||
voice_mapping = voicelist
|
||||
normalized_voice = voicestr.strip().lower() if voicestr else ""
|
||||
if normalized_voice.endswith(".wav"):
|
||||
normalized_voice = normalized_voice[:-4]
|
||||
|
|
@ -2280,6 +2282,7 @@ def tts_generate(genparams):
|
|||
else:
|
||||
inputs.custom_speaker_text = "".encode("UTF-8")
|
||||
inputs.custom_speaker_data = "".encode("UTF-8")
|
||||
reference_audio = voicebank.get(voicestr,"") #for cloned voices in qwen3tts
|
||||
if reference_audio and reference_audio.startswith("data:audio"):
|
||||
reference_audio = reference_audio.split(",", 1)[1]
|
||||
inputs.reference_audio = reference_audio.encode("UTF-8")
|
||||
|
|
@ -3856,7 +3859,7 @@ Change Mode<br>
|
|||
def do_GET(self):
|
||||
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui, embedded_kailite_gz, embedded_kcpp_docs_gz, embedded_kcpp_sdui_gz, embedded_lcpp_ui_gz
|
||||
global last_req_time, start_time, cached_chat_template, has_vision_support, has_audio_support, has_whisper, friendlymodelname
|
||||
global savedata_obj, has_multiplayer, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, maxctx, maxhordelen, friendlymodelname, lastuploadedcomfyimg, lastgeneratedcomfyimg, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, password, friendlyembeddingsmodelname
|
||||
global savedata_obj, has_multiplayer, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, maxctx, maxhordelen, friendlymodelname, lastuploadedcomfyimg, lastgeneratedcomfyimg, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, password, friendlyembeddingsmodelname, voicelist
|
||||
|
||||
clean_path = self.path.split("?")[0] #for cases where we do not want query params
|
||||
if clean_path=="/lcpp": #fix for svelte redirect issues, browser path needs to end with slash
|
||||
|
|
@ -4043,11 +4046,14 @@ Change Mode<br>
|
|||
pass
|
||||
|
||||
elif clean_path.endswith('/speakers_list'): #xtts compatible
|
||||
response_body = (json.dumps(["kobo","cheery","sleepy","shouty","chatty"]).encode()) #some random voices for them to enjoy
|
||||
response_body = (json.dumps(voicelist).encode()) #some random voices for them to enjoy
|
||||
elif clean_path.endswith('/speakers'): #xtts compatible
|
||||
response_body = (json.dumps([{"name":"kobo","voice_id":"kobo","preview_url":""},{"name":"cheery","voice_id":"cheery","preview_url":""},{"name":"sleepy","voice_id":"sleepy","preview_url":""},{"name":"shouty","voice_id":"shouty","preview_url":""},{"name":"chatty","voice_id":"chatty","preview_url":""}]).encode()) #some random voices for them to enjoy
|
||||
tmplist = []
|
||||
for itm in voicelist:
|
||||
tmplist.append({"name":itm,"voice_id":itm,"preview_url":""})
|
||||
response_body = (json.dumps(tmplist).encode()) #some random voices for them to enjoy
|
||||
elif clean_path.endswith('/v1/audio/voices') or clean_path=='/audio/voices':
|
||||
response_body = (json.dumps({"status":"ok","voices":["kobo","cheery","sleepy","shouty","chatty"]}).encode()) #some random voices for them to enjoy
|
||||
response_body = (json.dumps({"status":"ok","voices":voicelist}).encode()) #some random voices for them to enjoy
|
||||
elif clean_path.endswith('/get_tts_settings'): #xtts compatible
|
||||
response_body = (json.dumps({"temperature":0.75,"speed":1,"length_penalty":1,"repetition_penalty":1,"top_p":1,"top_k":4,"enable_text_splitting":True,"stream_chunk_size":100}).encode()) #some random voices for them to enjoy
|
||||
|
||||
|
|
@ -5662,6 +5668,7 @@ def show_gui():
|
|||
ttsgpu_var = ctk.IntVar(value=0)
|
||||
tts_threads_var = ctk.StringVar(value=str(default_threads))
|
||||
ttsmaxlen_var = ctk.StringVar(value=str(default_ttsmaxlen))
|
||||
tts_dir_var = ctk.StringVar()
|
||||
|
||||
embeddings_model_var = ctk.StringVar()
|
||||
embeddings_ctx_var = ctk.StringVar(value=str(""))
|
||||
|
|
@ -6459,6 +6466,8 @@ def show_gui():
|
|||
ttsgpu_var.trace_add("write", gui_changed_modelfile)
|
||||
makefileentry(audio_tab, "WavTokenizer Model (Required for some models):", "Select WavTokenizer GGUF Model File", wavtokenizer_var, 11, width=280, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select a WavTokenizer GGUF model file on disk to be loaded for Narration.")
|
||||
wavtokenizer_var.trace_add("write", gui_changed_modelfile)
|
||||
makefileentry(audio_tab, "TTS Voices Dir:", "Select directory containing voices for voice cloning", tts_dir_var, 20, width=280, singlerow=True, dialog_type=2, tooltiptxt="Select directory containing voices for voice cloning")
|
||||
|
||||
|
||||
admin_tab = tabcontent["Admin"]
|
||||
def toggleadmin(a,b,c):
|
||||
|
|
@ -6771,6 +6780,7 @@ def show_gui():
|
|||
args.ttswavtokenizer = wavtokenizer_var.get()
|
||||
args.ttsgpu = (ttsgpu_var.get()==1)
|
||||
args.ttsmaxlen = (default_ttsmaxlen if ttsmaxlen_var.get()=="" else int(ttsmaxlen_var.get()))
|
||||
args.ttsdir = tts_dir_var.get()
|
||||
|
||||
args.admin = (admin_var.get()==1 and not args.cli)
|
||||
args.admindir = admin_dir_var.get()
|
||||
|
|
@ -7012,6 +7022,7 @@ def show_gui():
|
|||
wavtokenizer_var.set(dict["ttswavtokenizer"] if ("ttswavtokenizer" in dict and dict["ttswavtokenizer"]) else "")
|
||||
ttsgpu_var.set(dict["ttsgpu"] if ("ttsgpu" in dict) else 0)
|
||||
ttsmaxlen_var.set(str(dict["ttsmaxlen"]) if ("ttsmaxlen" in dict and dict["ttsmaxlen"]) else str(default_ttsmaxlen))
|
||||
tts_dir_var.set(dict["ttsdir"] if ("ttsdir" in dict and dict["ttsdir"]) else "")
|
||||
|
||||
embeddings_model_var.set(dict["embeddingsmodel"] if ("embeddingsmodel" in dict and dict["embeddingsmodel"]) else "")
|
||||
embeddings_ctx_var.set(str(dict["embeddingsmaxctx"]) if ("embeddingsmaxctx" in dict and dict["embeddingsmaxctx"]) else "")
|
||||
|
|
@ -8703,6 +8714,39 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
except Exception:
|
||||
print("Could not find Embedded llama.cpp UI.")
|
||||
|
||||
# load all TTS audio files
|
||||
if args.ttsdir and args.ttsmodel and os.path.isdir(args.ttsdir):
|
||||
try:
|
||||
global voicebank, voicelist
|
||||
voicebank = {}
|
||||
voicecount = 0
|
||||
voicelist = []
|
||||
|
||||
voicelist.append("kobo")
|
||||
voicebank["kobo"] = ""
|
||||
voicelist.append("cheery")
|
||||
voicebank["cheery"] = ""
|
||||
voicelist.append("sleepy")
|
||||
voicebank["sleepy"] = ""
|
||||
voicelist.append("shouty")
|
||||
voicebank["shouty"] = ""
|
||||
voicelist.append("chatty")
|
||||
voicebank["chatty"] = ""
|
||||
voicelist.append("random")
|
||||
voicebank["random"] = ""
|
||||
|
||||
for filename in os.listdir(args.ttsdir):
|
||||
if filename.lower().endswith((".mp3", ".wav")):
|
||||
full_path = os.path.join(args.ttsdir, filename)
|
||||
with open(full_path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("utf-8")
|
||||
voicebank[filename] = encoded
|
||||
voicecount += 1
|
||||
voicelist.append(os.path.basename(filename))
|
||||
print(f"Loaded {voicecount} TTS voices.")
|
||||
except Exception:
|
||||
print("Could not load TTS voices.")
|
||||
|
||||
if args.mcpfile and isinstance(args.mcpfile, str):
|
||||
threading.Thread(target=load_mcp_async, args=(args,), daemon=True).start()
|
||||
time.sleep(0.2) # short delay to allow get_capabilities to work
|
||||
|
|
@ -9065,6 +9109,7 @@ if __name__ == '__main__':
|
|||
ttsparsergroup.add_argument("--ttsgpu", help="Use the GPU for TTS.", action='store_true')
|
||||
ttsparsergroup.add_argument("--ttsmaxlen", help="Limit number of audio tokens generated with TTS.", type=int, default=default_ttsmaxlen)
|
||||
ttsparsergroup.add_argument("--ttsthreads", metavar=('[threads]'), help="Use a different number of threads for TTS if specified. Otherwise, has the same value as --threads.", type=int, default=0)
|
||||
ttsparsergroup.add_argument("--ttsdir", metavar=('[directory]'), help="Select directory containing voices for voice cloning.", default="")
|
||||
|
||||
embeddingsparsergroup = parser.add_argument_group('Embeddings Model Commands')
|
||||
embeddingsparsergroup.add_argument("--embeddingsmodel", metavar=('[filename]'), help="Specify an embeddings model to be loaded for generating embedding vectors.", default="")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,11 @@ Qwen3TTS::Qwen3TTS() = default;
|
|||
|
||||
Qwen3TTS::~Qwen3TTS() = default;
|
||||
|
||||
void Qwen3TTS::set_seed(int seed)
|
||||
{
|
||||
this->transformer_.set_seed(seed);
|
||||
}
|
||||
|
||||
bool Qwen3TTS::load_models(const std::string & model_dir) {
|
||||
// Construct model paths
|
||||
std::string tts_model_path = model_dir + "/qwen3-tts-0.6b-f16.gguf";
|
||||
|
|
@ -197,9 +202,10 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
|
|||
return synthesize_with_voice(text, ref_samples.data(), (int32_t)ref_samples.size(), params);
|
||||
}
|
||||
|
||||
static std::vector<float> speaker_embedding;
|
||||
tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
|
||||
const float * ref_samples, int32_t n_ref_samples,
|
||||
const tts_params & params) {
|
||||
const tts_params & params, bool regenerate) {
|
||||
tts_result result;
|
||||
|
||||
if (!models_loaded_) {
|
||||
|
|
@ -226,11 +232,14 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
|
|||
}
|
||||
|
||||
int64_t t_encode_start = get_time_ms();
|
||||
std::vector<float> speaker_embedding;
|
||||
|
||||
if (!audio_encoder_.encode(ref_samples, n_ref_samples, speaker_embedding)) {
|
||||
result.error_msg = "Failed to extract speaker embedding: " + audio_encoder_.get_error();
|
||||
return result;
|
||||
if(speaker_embedding.size()==0 || regenerate)
|
||||
{
|
||||
speaker_embedding.clear();
|
||||
if (!audio_encoder_.encode(ref_samples, n_ref_samples, speaker_embedding)) {
|
||||
result.error_msg = "Failed to extract speaker embedding: " + audio_encoder_.get_error();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
result.t_encode_ms = get_time_ms() - t_encode_start;
|
||||
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ public:
|
|||
Qwen3TTS();
|
||||
~Qwen3TTS();
|
||||
|
||||
void set_seed(int seed);
|
||||
|
||||
// Load all models from directory
|
||||
// model_dir should contain: transformer.gguf, tokenizer.gguf, vocoder.gguf
|
||||
bool load_models(const std::string & model_dir);
|
||||
|
|
@ -107,7 +109,7 @@ public:
|
|||
// params: generation parameters
|
||||
tts_result synthesize_with_voice(const std::string & text,
|
||||
const float * ref_samples, int32_t n_ref_samples,
|
||||
const tts_params & params = tts_params());
|
||||
const tts_params & params = tts_params(), bool regenerate=true);
|
||||
|
||||
// Set progress callback
|
||||
void set_progress_callback(tts_progress_callback_t callback);
|
||||
|
|
|
|||
|
|
@ -49,6 +49,15 @@ void TTSTransformer::unload_model() {
|
|||
embd_row_fp16_scratch_.clear();
|
||||
}
|
||||
|
||||
void TTSTransformer::set_seed(int seed)
|
||||
{
|
||||
if (seed <= 0 || seed==0xFFFFFFFF)
|
||||
{
|
||||
seed = (((uint32_t)time(NULL)) % 1000000u);
|
||||
}
|
||||
this->rng_ = std::mt19937(seed);
|
||||
}
|
||||
|
||||
bool TTSTransformer::load_model(const std::string & model_path) {
|
||||
unload_model();
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ struct tts_transformer_config {
|
|||
// Text embedding
|
||||
int32_t text_vocab_size = 151936;
|
||||
int32_t text_embd_dim = 2048;
|
||||
|
||||
|
||||
// Talker transformer
|
||||
int32_t hidden_size = 1024;
|
||||
int32_t n_layers = 28;
|
||||
|
|
@ -69,18 +69,18 @@ struct tts_transformer_config {
|
|||
int32_t head_dim = 128;
|
||||
float rms_norm_eps = 1e-6f;
|
||||
float rope_theta = 1000000.0f;
|
||||
|
||||
|
||||
// M-RoPE sections [time, freq, channel] = [24, 20, 20]
|
||||
int32_t mrope_section[3] = {24, 20, 20};
|
||||
|
||||
|
||||
// Codec vocabulary
|
||||
int32_t codec_vocab_size = 3072; // talker.codec_embd/codec_head
|
||||
int32_t n_codebooks = 16;
|
||||
|
||||
|
||||
// Code predictor
|
||||
int32_t code_pred_layers = 5;
|
||||
int32_t code_pred_vocab_size = 2048; // Per-codebook vocab
|
||||
|
||||
|
||||
// Special codec tokens
|
||||
int32_t codec_pad_id = 2148;
|
||||
int32_t codec_bos_id = 2149;
|
||||
|
|
@ -101,16 +101,16 @@ struct tts_transformer_config {
|
|||
// Transformer layer weights
|
||||
struct transformer_layer {
|
||||
struct ggml_tensor * attn_norm = nullptr;
|
||||
|
||||
|
||||
struct ggml_tensor * attn_q = nullptr;
|
||||
struct ggml_tensor * attn_k = nullptr;
|
||||
struct ggml_tensor * attn_v = nullptr;
|
||||
struct ggml_tensor * attn_output = nullptr;
|
||||
struct ggml_tensor * attn_q_norm = nullptr;
|
||||
struct ggml_tensor * attn_k_norm = nullptr;
|
||||
|
||||
|
||||
struct ggml_tensor * ffn_norm = nullptr;
|
||||
|
||||
|
||||
struct ggml_tensor * ffn_gate = nullptr;
|
||||
struct ggml_tensor * ffn_up = nullptr;
|
||||
struct ggml_tensor * ffn_down = nullptr;
|
||||
|
|
@ -119,42 +119,42 @@ struct transformer_layer {
|
|||
// TTS Transformer model weights
|
||||
struct tts_transformer_model {
|
||||
tts_transformer_config config;
|
||||
|
||||
|
||||
// Text embedding and projection
|
||||
struct ggml_tensor * text_embd = nullptr; // [text_embd_dim, text_vocab_size]
|
||||
struct ggml_tensor * text_proj_fc1 = nullptr; // [text_embd_dim, text_embd_dim]
|
||||
struct ggml_tensor * text_proj_fc1_bias = nullptr;
|
||||
struct ggml_tensor * text_proj_fc2 = nullptr; // [text_embd_dim, hidden_size]
|
||||
struct ggml_tensor * text_proj_fc2_bias = nullptr;
|
||||
|
||||
|
||||
// Codec embedding (for autoregressive input)
|
||||
struct ggml_tensor * codec_embd = nullptr; // [hidden_size, codec_vocab_size]
|
||||
|
||||
|
||||
// Talker transformer layers
|
||||
std::vector<transformer_layer> layers;
|
||||
|
||||
|
||||
// Final RMSNorm
|
||||
struct ggml_tensor * output_norm = nullptr; // [hidden_size]
|
||||
|
||||
|
||||
// Codec head (for first codebook prediction)
|
||||
struct ggml_tensor * codec_head = nullptr; // [hidden_size, codec_vocab_size]
|
||||
|
||||
|
||||
// Code predictor layers
|
||||
std::vector<transformer_layer> code_pred_layers;
|
||||
|
||||
|
||||
// Code predictor output norm (final RMS norm before lm_head)
|
||||
struct ggml_tensor * code_pred_output_norm = nullptr; // [hidden_size]
|
||||
|
||||
|
||||
// Code predictor per-codebook embeddings and heads (15 codebooks, 0 uses talker output)
|
||||
std::vector<struct ggml_tensor *> code_pred_embd; // [hidden_size, code_pred_vocab_size] x 15
|
||||
std::vector<struct ggml_tensor *> code_pred_head; // [hidden_size, code_pred_vocab_size] x 15
|
||||
|
||||
|
||||
// GGML context for tensor metadata
|
||||
struct ggml_context * ctx = nullptr;
|
||||
|
||||
|
||||
// Backend buffer for weights
|
||||
ggml_backend_buffer_t buffer = nullptr;
|
||||
|
||||
|
||||
// Tensor name to tensor mapping
|
||||
std::map<std::string, struct ggml_tensor *> tensors;
|
||||
};
|
||||
|
|
@ -163,10 +163,10 @@ struct tts_transformer_model {
|
|||
struct tts_kv_cache {
|
||||
std::vector<struct ggml_tensor *> k_cache;
|
||||
std::vector<struct ggml_tensor *> v_cache;
|
||||
|
||||
|
||||
struct ggml_context * ctx = nullptr;
|
||||
ggml_backend_buffer_t buffer = nullptr;
|
||||
|
||||
|
||||
int32_t n_ctx = 0;
|
||||
int32_t n_used = 0;
|
||||
int32_t head_dim = 128;
|
||||
|
|
@ -179,9 +179,9 @@ struct tts_transformer_state {
|
|||
ggml_backend_t backend = nullptr;
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
ggml_backend_sched_t sched = nullptr;
|
||||
|
||||
|
||||
std::vector<uint8_t> compute_meta;
|
||||
|
||||
|
||||
tts_kv_cache cache; // Talker KV cache (28 layers)
|
||||
tts_kv_cache code_pred_cache; // Code predictor KV cache (5 layers)
|
||||
};
|
||||
|
|
@ -191,25 +191,27 @@ class TTSTransformer {
|
|||
public:
|
||||
TTSTransformer();
|
||||
~TTSTransformer();
|
||||
|
||||
|
||||
void set_seed(int seed);
|
||||
|
||||
// Load model from GGUF file
|
||||
bool load_model(const std::string & model_path);
|
||||
|
||||
// Release all model/runtime resources
|
||||
void unload_model();
|
||||
|
||||
|
||||
// Initialize KV cache
|
||||
bool init_kv_cache(int32_t n_ctx);
|
||||
|
||||
|
||||
// Clear KV cache
|
||||
void clear_kv_cache();
|
||||
|
||||
|
||||
// Initialize code predictor KV cache (5 layers, max 16 context)
|
||||
bool init_code_pred_kv_cache(int32_t n_ctx);
|
||||
|
||||
|
||||
// Clear code predictor KV cache
|
||||
void clear_code_pred_kv_cache();
|
||||
|
||||
|
||||
// Forward pass for text tokens (prefill phase)
|
||||
// text_tokens: input text token IDs [n_tokens]
|
||||
// speaker_embd: speaker embedding [hidden_size] (optional, can be nullptr)
|
||||
|
|
@ -222,7 +224,7 @@ public:
|
|||
bool forward_prefill(const float * prefill_embd, int32_t n_tokens,
|
||||
int32_t n_past, std::vector<float> & output,
|
||||
std::vector<float> * logits_out = nullptr);
|
||||
|
||||
|
||||
// Forward pass for codec tokens (generation phase)
|
||||
// codec_token: single codec token for first codebook
|
||||
// n_past: number of tokens already in KV cache
|
||||
|
|
@ -233,26 +235,26 @@ public:
|
|||
bool forward_step(const float * step_embd, int32_t n_past,
|
||||
std::vector<float> & output,
|
||||
std::vector<float> * hidden_out = nullptr);
|
||||
|
||||
|
||||
// Get hidden states from last forward pass (for code predictor)
|
||||
bool get_hidden_states(std::vector<float> & hidden) const;
|
||||
|
||||
|
||||
// Run code predictor to get all 16 codebook predictions
|
||||
// hidden: hidden states from talker [hidden_size]
|
||||
// prev_codes: previous codes for codebooks 1-15 (can be nullptr for first step)
|
||||
// output: logits for all 16 codebooks [16, code_pred_vocab_size]
|
||||
bool predict_codes(const float * hidden, const int32_t * prev_codes,
|
||||
std::vector<float> & output);
|
||||
|
||||
|
||||
// Run code predictor autoregressively to generate 15 codes (codebooks 1-15)
|
||||
// hidden: hidden states from talker [hidden_size]
|
||||
// codebook_0_token: the codebook 0 token (used to create 2-token prefill input)
|
||||
// output: generated codes for codebooks 1-15 [15]
|
||||
bool predict_codes_autoregressive(const float * hidden, int32_t codebook_0_token,
|
||||
bool predict_codes_autoregressive(const float * hidden, int32_t codebook_0_token,
|
||||
std::vector<int32_t> & output,
|
||||
float temperature = 0.9f,
|
||||
int32_t top_k = 50);
|
||||
|
||||
|
||||
// Generate speech codes autoregressively
|
||||
// text_tokens: input text token IDs [n_tokens]
|
||||
// speaker_embd: speaker embedding [hidden_size]
|
||||
|
|
@ -265,20 +267,20 @@ public:
|
|||
float repetition_penalty = 1.05f,
|
||||
float temperature = 0.9f,
|
||||
int32_t top_k = 50);
|
||||
|
||||
|
||||
const tts_transformer_config & get_config() const { return model_.config; }
|
||||
|
||||
|
||||
const std::string & get_error() const { return error_msg_; }
|
||||
|
||||
|
||||
// Legacy interface for compatibility
|
||||
bool forward(const int32_t * tokens, int32_t n_tokens, int32_t n_past,
|
||||
std::vector<float> & output);
|
||||
|
||||
|
||||
bool forward_with_audio(const int32_t * tokens, int32_t n_tokens,
|
||||
const float * audio_embd, int32_t n_audio,
|
||||
int32_t audio_start_pos, int32_t n_past,
|
||||
std::vector<float> & output);
|
||||
|
||||
|
||||
private:
|
||||
bool try_init_coreml_code_predictor(const std::string & model_path);
|
||||
bool predict_codes_autoregressive_coreml(const float * hidden, int32_t codebook_0_token,
|
||||
|
|
@ -304,32 +306,32 @@ private:
|
|||
const char * output_name, std::vector<float> & output);
|
||||
bool lookup_single_embedding_row(struct ggml_tensor * embedding, int32_t token_id,
|
||||
float * out_row);
|
||||
|
||||
|
||||
// Build computation graph for code predictor
|
||||
struct ggml_cgraph * build_code_pred_graph(int32_t n_prev_codes);
|
||||
|
||||
|
||||
// Build computation graph for single-step autoregressive code predictor
|
||||
// n_past: number of tokens already in KV cache (0-14)
|
||||
// generation_step: which codebook we're predicting (0-14)
|
||||
struct ggml_cgraph * build_code_pred_step_graph(int32_t n_past, int32_t generation_step);
|
||||
|
||||
|
||||
// Build computation graph for 2-token prefill of code predictor
|
||||
// Processes [past_hidden, codec_embd(codebook_0_token)] together
|
||||
struct ggml_cgraph * build_code_pred_prefill_graph();
|
||||
|
||||
|
||||
// Parse hyperparameters from GGUF
|
||||
bool parse_config(struct gguf_context * ctx);
|
||||
|
||||
|
||||
// Create tensor structures
|
||||
bool create_tensors(struct gguf_context * ctx);
|
||||
|
||||
|
||||
// Load tensor data from file
|
||||
bool load_tensor_data(const std::string & path, struct gguf_context * ctx);
|
||||
|
||||
|
||||
tts_transformer_model model_;
|
||||
tts_transformer_state state_;
|
||||
std::string error_msg_;
|
||||
|
||||
|
||||
// Cached hidden states from last forward pass
|
||||
std::vector<float> last_hidden_;
|
||||
std::vector<ggml_fp16_t> embd_row_fp16_scratch_;
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue