add some default voices for qwen3tts

This commit is contained in:
Concedo 2026-02-21 23:45:15 +08:00
parent 2db018a1d7
commit 5536fb29f2
7 changed files with 209 additions and 130 deletions

View file

@ -4090,6 +4090,7 @@ Current version indicated by LITEVER below.
const koboldcpp_savedata_save_endpoint = "/api/extra/data/save";
const koboldcpp_savedata_load_endpoint = "/api/extra/data/load";
const koboldcpp_mcp_endpoint = "/mcp";
const koboldcpp_voices_endpoint = "/speakers_list";
const oai_models_endpoint = "/models";
const oai_submit_endpoint = "/completions";
@ -4317,7 +4318,6 @@ Current version indicated by LITEVER below.
var websearch_in_progress = false;
var kcpp_tts_json = "";
var avoidwelcome = false;
var voicecloneb64 = "";
var localsettings = {
my_api_key: "0000000000", //put here so it can be saved and loaded in persistent mode
@ -5481,9 +5481,8 @@ Current version indicated by LITEVER below.
indexeddb_load("savedcustomcss",""),
indexeddb_load("savedusermod",""),
indexeddb_load("usermodprops",""),
indexeddb_load("samplerpresets",""),
indexeddb_load("voiceclone","")
]).then(([loadedsettingsjson, loadedstorycompressed, loadedbackgroundimg, currcss, currmod, modpropsstr, loadedsamplerpresetsjson, loadedvoiceclone]) => {
indexeddb_load("samplerpresets","")
]).then(([loadedsettingsjson, loadedstorycompressed, loadedbackgroundimg, currcss, currmod, modpropsstr, loadedsamplerpresetsjson]) => {
try
{
if (loadedsettingsjson != null && loadedsettingsjson != "" && loadedstorycompressed != null && loadedstorycompressed != "") {
@ -5531,10 +5530,6 @@ Current version indicated by LITEVER below.
document.getElementById("enhancedchatinterface").classList.add("transparentbg");
document.getElementById("enhancedchatinterface_inner").classList.add("transparentbg");
}
if(loadedvoiceclone && loadedvoiceclone!="")
{
voicecloneb64 = loadedvoiceclone;
}
loadok = true;
} else {
console.log("Skipped missing local save");
@ -7877,6 +7872,10 @@ Current version indicated by LITEVER below.
{
fetch_xtts_voices(true,localsettings.tts_mode==XTTS_ID);
}
else if(localsettings.tts_mode==KCPP_TTS_ID)
{
fetch_kcpp_voices();
}
if(localsettings.generate_images_mode==2)
{
connect_to_a1111(true);
@ -17391,35 +17390,6 @@ Current version indicated by LITEVER below.
},true,true);
}
function set_voice_clone()
{
let finput = document.getElementById('addimgfileinput');
finput.click();
finput.onchange = (event) => {
if (event.target.files.length > 0 && event.target.files[0]) {
const file = event.target.files[0];
const fname = file.name;
const reader = new FileReader();
reader.onload = function(audio) {
let origAudio = audio.target.result;
convertAudioToCompressedBase64(origAudio,(newAudio,duration)=>{
indexeddb_save("voiceclone", newAudio);
voicecloneb64 = newAudio;
adjust_kcpptts_controls();
},64);
}
reader.readAsDataURL(file);
}
finput.value = "";
};
}
function clear_voice_clone()
{
indexeddb_save("voiceclone", "");
voicecloneb64 = "";
adjust_kcpptts_controls();
}
function restore_retried_text()
{
if(retry_in_progress)
@ -17552,8 +17522,6 @@ Current version indicated by LITEVER below.
indexeddb_save("savedusermod","");
indexeddb_save("usermodprops","");
indexeddb_save("savedcustomcss", "");
indexeddb_save("voiceclone", "");
voicecloneb64 = "";
let styleElement = document.getElementById('custom_css');
styleElement.innerHTML = "";
show_welcome_panel();
@ -18475,6 +18443,28 @@ Current version indicated by LITEVER below.
}
}
function fetch_kcpp_voices()
{
fetch(apply_proxy_url(custom_kobold_endpoint + koboldcpp_voices_endpoint), {
method: 'GET', // or 'PUT'
headers: get_kobold_header(),
})
.then(x => x.json())
.then(data => {
console.log(data);
let dropdown = document.getElementById("kcpp_tts_voice");
let selectionhtml = ``;
for (var i = 0; i < data.length; ++i) {
// Check for XTTS voices if set
let sel = (localsettings.kcpp_tts_voice!=""&&localsettings.kcpp_tts_voice==data[i]);
selectionhtml += `<option value="` + data[i] + `"`+(sel?" selected":"")+`>`+data[i]+`</option>`;
}
selectionhtml += `<option value="custom">custom</option><option value="voicejson">voicejson</option>`;
dropdown.innerHTML = selectionhtml;
}).catch((error) => {
});
}
function manual_tts()
{
let ssval = localsettings.tts_mode;
@ -18515,7 +18505,7 @@ Current version indicated by LITEVER below.
if (userinput != null && userinput!="" && ssval > 0) {
tts_speak(userinput,downloadtts,embedtts,false);
}
},true);
},true,true);
}
function test_tts()
@ -18568,6 +18558,7 @@ Current version indicated by LITEVER below.
}
else if(selectedTTS == KCPP_TTS_ID) {
document.getElementById("kcpp_tts_container").classList.remove("hidden");
fetch_kcpp_voices();
if(is_using_kcpp_with_tts())
{
document.getElementById("nokcpptts").classList.add("hidden");
@ -18636,19 +18627,6 @@ Current version indicated by LITEVER below.
document.getElementById("kcpp_tts_voice_json").classList.add("hidden");
}
document.getElementById("kcpp_tts_voice_clone").classList.add("hidden");
document.getElementById("kcpp_tts_voice_clone_clear").classList.add("hidden");
if (document.getElementById("kcpp_tts_voice").value == "voiceclone") {
if(voicecloneb64=="")
{
document.getElementById("kcpp_tts_voice_clone").classList.remove("hidden");
}
else
{
document.getElementById("kcpp_tts_voice_clone_clear").classList.remove("hidden");
}
}
}
// Update set_xtts_url to use the new fetch_rvc_voices function
@ -18830,7 +18808,6 @@ Current version indicated by LITEVER below.
} else {
sub_endpt = apply_proxy_url(custom_kobold_endpoint + koboldcpp_tts_endpoint);
let is_voicejson = (document.getElementById("kcpp_tts_voice").value == "voicejson");
let is_voiceclone = (document.getElementById("kcpp_tts_voice").value == "voiceclone");
let is_custom = (document.getElementById("kcpp_tts_voice").value == "custom");
payload =
{
@ -18841,10 +18818,6 @@ Current version indicated by LITEVER below.
{
payload.speaker_json = vcjson;
}
if(is_voiceclone && voicecloneb64!="")
{
payload.reference_audio = voicecloneb64;
}
ttsheaders = get_kobold_header();
}
@ -29849,14 +29822,11 @@ Current version indicated by LITEVER below.
<option value="chatty">chatty</option>
<option value="custom">custom</option>
<option value="voicejson">voicejson</option>
<option value="voiceclone">voiceclone</option>
</select>
</div>
<div>
<input type="text" value="" placeholder="(Name)" id="kcpp_tts_voice_custom" style="margin-left:3px; width:56px;">
<button id="kcpp_tts_voice_json" type="button" class="btn btn-primary" style="margin-left:3px; width:56px;" onclick="set_voice_json()">Setup</button>
<button id="kcpp_tts_voice_clone" type="button" class="btn btn-primary" style="margin-left:3px; width:56px;" onclick="set_voice_clone()">Load</button>
<button id="kcpp_tts_voice_clone_clear" type="button" class="btn btn-primary bg_red" style="margin-left:3px; width:56px;" onclick="clear_voice_clone()">Clear</button>
</div>
</div>
</div>

View file

@ -115,6 +115,7 @@ importvars_in_progress = False
has_multiplayer = False
has_audio_support = False
has_vision_support = False
has_whisper = False
cached_chat_template = None
savedata_obj = None
mcp_connections = [] #every element is linked to one mcp source, contains obj {"client":obj, "tools":[]}
@ -136,6 +137,8 @@ embedded_kcpp_docs_gz = None
embedded_kcpp_sdui = None
embedded_kcpp_sdui_gz = None
embedded_lcpp_ui_gz = None
voicebank = {}
voicelist = ["kobo","cheery","sleepy","shouty","chatty"]
sslvalid = False
nocertify = False
start_time = time.time()
@ -2245,15 +2248,14 @@ def tts_prepare_voice_json(jsonstr):
return None
def tts_generate(genparams):
global args
global args, voicebank, voicelist
prompt = genparams.get("input", genparams.get("text", ""))
prompt = prompt.strip()
voice = 1
speaker_json = tts_prepare_voice_json(genparams.get("speaker_json","")) #handle custom json voices
reference_audio = genparams.get("reference_audio","") #for cloned voices in qwen3tts
voicestr = genparams.get("voice", genparams.get("speaker_wav", ""))
oai_voicemap = ["alloy","onyx","echo","nova","shimmer"] # map to kcpp defaults
voice_mapping = ["kobo","cheery","sleepy","shouty","chatty"]
voice_mapping = voicelist
normalized_voice = voicestr.strip().lower() if voicestr else ""
if normalized_voice.endswith(".wav"):
normalized_voice = normalized_voice[:-4]
@ -2280,6 +2282,7 @@ def tts_generate(genparams):
else:
inputs.custom_speaker_text = "".encode("UTF-8")
inputs.custom_speaker_data = "".encode("UTF-8")
reference_audio = voicebank.get(voicestr,"") #for cloned voices in qwen3tts
if reference_audio and reference_audio.startswith("data:audio"):
reference_audio = reference_audio.split(",", 1)[1]
inputs.reference_audio = reference_audio.encode("UTF-8")
@ -3856,7 +3859,7 @@ Change Mode<br>
def do_GET(self):
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui, embedded_kailite_gz, embedded_kcpp_docs_gz, embedded_kcpp_sdui_gz, embedded_lcpp_ui_gz
global last_req_time, start_time, cached_chat_template, has_vision_support, has_audio_support, has_whisper, friendlymodelname
global savedata_obj, has_multiplayer, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, maxctx, maxhordelen, friendlymodelname, lastuploadedcomfyimg, lastgeneratedcomfyimg, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, password, friendlyembeddingsmodelname
global savedata_obj, has_multiplayer, multiplayer_turn_major, multiplayer_turn_minor, multiplayer_story_data_compressed, multiplayer_dataformat, multiplayer_lastactive, maxctx, maxhordelen, friendlymodelname, lastuploadedcomfyimg, lastgeneratedcomfyimg, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, password, friendlyembeddingsmodelname, voicelist
clean_path = self.path.split("?")[0] #for cases where we do not want query params
if clean_path=="/lcpp": #fix for svelte redirect issues, browser path needs to end with slash
@ -4043,11 +4046,14 @@ Change Mode<br>
pass
elif clean_path.endswith('/speakers_list'): #xtts compatible
response_body = (json.dumps(["kobo","cheery","sleepy","shouty","chatty"]).encode()) #some random voices for them to enjoy
response_body = (json.dumps(voicelist).encode()) #some random voices for them to enjoy
elif clean_path.endswith('/speakers'): #xtts compatible
response_body = (json.dumps([{"name":"kobo","voice_id":"kobo","preview_url":""},{"name":"cheery","voice_id":"cheery","preview_url":""},{"name":"sleepy","voice_id":"sleepy","preview_url":""},{"name":"shouty","voice_id":"shouty","preview_url":""},{"name":"chatty","voice_id":"chatty","preview_url":""}]).encode()) #some random voices for them to enjoy
tmplist = []
for itm in voicelist:
tmplist.append({"name":itm,"voice_id":itm,"preview_url":""})
response_body = (json.dumps(tmplist).encode()) #some random voices for them to enjoy
elif clean_path.endswith('/v1/audio/voices') or clean_path=='/audio/voices':
response_body = (json.dumps({"status":"ok","voices":["kobo","cheery","sleepy","shouty","chatty"]}).encode()) #some random voices for them to enjoy
response_body = (json.dumps({"status":"ok","voices":voicelist}).encode()) #some random voices for them to enjoy
elif clean_path.endswith('/get_tts_settings'): #xtts compatible
response_body = (json.dumps({"temperature":0.75,"speed":1,"length_penalty":1,"repetition_penalty":1,"top_p":1,"top_k":4,"enable_text_splitting":True,"stream_chunk_size":100}).encode()) #some random voices for them to enjoy
@ -5662,6 +5668,7 @@ def show_gui():
ttsgpu_var = ctk.IntVar(value=0)
tts_threads_var = ctk.StringVar(value=str(default_threads))
ttsmaxlen_var = ctk.StringVar(value=str(default_ttsmaxlen))
tts_dir_var = ctk.StringVar()
embeddings_model_var = ctk.StringVar()
embeddings_ctx_var = ctk.StringVar(value=str(""))
@ -6459,6 +6466,8 @@ def show_gui():
ttsgpu_var.trace_add("write", gui_changed_modelfile)
makefileentry(audio_tab, "WavTokenizer Model (Required for some models):", "Select WavTokenizer GGUF Model File", wavtokenizer_var, 11, width=280, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select a WavTokenizer GGUF model file on disk to be loaded for Narration.")
wavtokenizer_var.trace_add("write", gui_changed_modelfile)
makefileentry(audio_tab, "TTS Voices Dir:", "Select directory containing voices for voice cloning", tts_dir_var, 20, width=280, singlerow=True, dialog_type=2, tooltiptxt="Select directory containing voices for voice cloning")
admin_tab = tabcontent["Admin"]
def toggleadmin(a,b,c):
@ -6771,6 +6780,7 @@ def show_gui():
args.ttswavtokenizer = wavtokenizer_var.get()
args.ttsgpu = (ttsgpu_var.get()==1)
args.ttsmaxlen = (default_ttsmaxlen if ttsmaxlen_var.get()=="" else int(ttsmaxlen_var.get()))
args.ttsdir = tts_dir_var.get()
args.admin = (admin_var.get()==1 and not args.cli)
args.admindir = admin_dir_var.get()
@ -7012,6 +7022,7 @@ def show_gui():
wavtokenizer_var.set(dict["ttswavtokenizer"] if ("ttswavtokenizer" in dict and dict["ttswavtokenizer"]) else "")
ttsgpu_var.set(dict["ttsgpu"] if ("ttsgpu" in dict) else 0)
ttsmaxlen_var.set(str(dict["ttsmaxlen"]) if ("ttsmaxlen" in dict and dict["ttsmaxlen"]) else str(default_ttsmaxlen))
tts_dir_var.set(dict["ttsdir"] if ("ttsdir" in dict and dict["ttsdir"]) else "")
embeddings_model_var.set(dict["embeddingsmodel"] if ("embeddingsmodel" in dict and dict["embeddingsmodel"]) else "")
embeddings_ctx_var.set(str(dict["embeddingsmaxctx"]) if ("embeddingsmaxctx" in dict and dict["embeddingsmaxctx"]) else "")
@ -8703,6 +8714,39 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
except Exception:
print("Could not find Embedded llama.cpp UI.")
# load all TTS audio files
if args.ttsdir and args.ttsmodel and os.path.isdir(args.ttsdir):
try:
global voicebank, voicelist
voicebank = {}
voicecount = 0
voicelist = []
voicelist.append("kobo")
voicebank["kobo"] = ""
voicelist.append("cheery")
voicebank["cheery"] = ""
voicelist.append("sleepy")
voicebank["sleepy"] = ""
voicelist.append("shouty")
voicebank["shouty"] = ""
voicelist.append("chatty")
voicebank["chatty"] = ""
voicelist.append("random")
voicebank["random"] = ""
for filename in os.listdir(args.ttsdir):
if filename.lower().endswith((".mp3", ".wav")):
full_path = os.path.join(args.ttsdir, filename)
with open(full_path, "rb") as f:
encoded = base64.b64encode(f.read()).decode("utf-8")
voicebank[filename] = encoded
voicecount += 1
voicelist.append(os.path.basename(filename))
print(f"Loaded {voicecount} TTS voices.")
except Exception:
print("Could not load TTS voices.")
if args.mcpfile and isinstance(args.mcpfile, str):
threading.Thread(target=load_mcp_async, args=(args,), daemon=True).start()
time.sleep(0.2) # short delay to allow get_capabilities to work
@ -9065,6 +9109,7 @@ if __name__ == '__main__':
ttsparsergroup.add_argument("--ttsgpu", help="Use the GPU for TTS.", action='store_true')
ttsparsergroup.add_argument("--ttsmaxlen", help="Limit number of audio tokens generated with TTS.", type=int, default=default_ttsmaxlen)
ttsparsergroup.add_argument("--ttsthreads", metavar=('[threads]'), help="Use a different number of threads for TTS if specified. Otherwise, has the same value as --threads.", type=int, default=0)
ttsparsergroup.add_argument("--ttsdir", metavar=('[directory]'), help="Select directory containing voices for voice cloning.", default="")
embeddingsparsergroup = parser.add_argument_group('Embeddings Model Commands')
embeddingsparsergroup.add_argument("--embeddingsmodel", metavar=('[filename]'), help="Specify an embeddings model to be loaded for generating embedding vectors.", default="")

View file

@ -68,6 +68,11 @@ Qwen3TTS::Qwen3TTS() = default;
Qwen3TTS::~Qwen3TTS() = default;
void Qwen3TTS::set_seed(int seed)
{
this->transformer_.set_seed(seed);
}
bool Qwen3TTS::load_models(const std::string & model_dir) {
// Construct model paths
std::string tts_model_path = model_dir + "/qwen3-tts-0.6b-f16.gguf";
@ -197,9 +202,10 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
return synthesize_with_voice(text, ref_samples.data(), (int32_t)ref_samples.size(), params);
}
static std::vector<float> speaker_embedding;
tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
const float * ref_samples, int32_t n_ref_samples,
const tts_params & params) {
const tts_params & params, bool regenerate) {
tts_result result;
if (!models_loaded_) {
@ -226,11 +232,14 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
}
int64_t t_encode_start = get_time_ms();
std::vector<float> speaker_embedding;
if (!audio_encoder_.encode(ref_samples, n_ref_samples, speaker_embedding)) {
result.error_msg = "Failed to extract speaker embedding: " + audio_encoder_.get_error();
return result;
if(speaker_embedding.size()==0 || regenerate)
{
speaker_embedding.clear();
if (!audio_encoder_.encode(ref_samples, n_ref_samples, speaker_embedding)) {
result.error_msg = "Failed to extract speaker embedding: " + audio_encoder_.get_error();
return result;
}
}
result.t_encode_ms = get_time_ms() - t_encode_start;

View file

@ -81,6 +81,8 @@ public:
Qwen3TTS();
~Qwen3TTS();
void set_seed(int seed);
// Load all models from directory
// model_dir should contain: transformer.gguf, tokenizer.gguf, vocoder.gguf
bool load_models(const std::string & model_dir);
@ -107,7 +109,7 @@ public:
// params: generation parameters
tts_result synthesize_with_voice(const std::string & text,
const float * ref_samples, int32_t n_ref_samples,
const tts_params & params = tts_params());
const tts_params & params = tts_params(), bool regenerate=true);
// Set progress callback
void set_progress_callback(tts_progress_callback_t callback);

View file

@ -49,6 +49,15 @@ void TTSTransformer::unload_model() {
embd_row_fp16_scratch_.clear();
}
void TTSTransformer::set_seed(int seed)
{
if (seed <= 0 || seed==0xFFFFFFFF)
{
seed = (((uint32_t)time(NULL)) % 1000000u);
}
this->rng_ = std::mt19937(seed);
}
bool TTSTransformer::load_model(const std::string & model_path) {
unload_model();

View file

@ -59,7 +59,7 @@ struct tts_transformer_config {
// Text embedding
int32_t text_vocab_size = 151936;
int32_t text_embd_dim = 2048;
// Talker transformer
int32_t hidden_size = 1024;
int32_t n_layers = 28;
@ -69,18 +69,18 @@ struct tts_transformer_config {
int32_t head_dim = 128;
float rms_norm_eps = 1e-6f;
float rope_theta = 1000000.0f;
// M-RoPE sections [time, freq, channel] = [24, 20, 20]
int32_t mrope_section[3] = {24, 20, 20};
// Codec vocabulary
int32_t codec_vocab_size = 3072; // talker.codec_embd/codec_head
int32_t n_codebooks = 16;
// Code predictor
int32_t code_pred_layers = 5;
int32_t code_pred_vocab_size = 2048; // Per-codebook vocab
// Special codec tokens
int32_t codec_pad_id = 2148;
int32_t codec_bos_id = 2149;
@ -101,16 +101,16 @@ struct tts_transformer_config {
// Transformer layer weights
struct transformer_layer {
struct ggml_tensor * attn_norm = nullptr;
struct ggml_tensor * attn_q = nullptr;
struct ggml_tensor * attn_k = nullptr;
struct ggml_tensor * attn_v = nullptr;
struct ggml_tensor * attn_output = nullptr;
struct ggml_tensor * attn_q_norm = nullptr;
struct ggml_tensor * attn_k_norm = nullptr;
struct ggml_tensor * ffn_norm = nullptr;
struct ggml_tensor * ffn_gate = nullptr;
struct ggml_tensor * ffn_up = nullptr;
struct ggml_tensor * ffn_down = nullptr;
@ -119,42 +119,42 @@ struct transformer_layer {
// TTS Transformer model weights
struct tts_transformer_model {
tts_transformer_config config;
// Text embedding and projection
struct ggml_tensor * text_embd = nullptr; // [text_embd_dim, text_vocab_size]
struct ggml_tensor * text_proj_fc1 = nullptr; // [text_embd_dim, text_embd_dim]
struct ggml_tensor * text_proj_fc1_bias = nullptr;
struct ggml_tensor * text_proj_fc2 = nullptr; // [text_embd_dim, hidden_size]
struct ggml_tensor * text_proj_fc2_bias = nullptr;
// Codec embedding (for autoregressive input)
struct ggml_tensor * codec_embd = nullptr; // [hidden_size, codec_vocab_size]
// Talker transformer layers
std::vector<transformer_layer> layers;
// Final RMSNorm
struct ggml_tensor * output_norm = nullptr; // [hidden_size]
// Codec head (for first codebook prediction)
struct ggml_tensor * codec_head = nullptr; // [hidden_size, codec_vocab_size]
// Code predictor layers
std::vector<transformer_layer> code_pred_layers;
// Code predictor output norm (final RMS norm before lm_head)
struct ggml_tensor * code_pred_output_norm = nullptr; // [hidden_size]
// Code predictor per-codebook embeddings and heads (15 codebooks, 0 uses talker output)
std::vector<struct ggml_tensor *> code_pred_embd; // [hidden_size, code_pred_vocab_size] x 15
std::vector<struct ggml_tensor *> code_pred_head; // [hidden_size, code_pred_vocab_size] x 15
// GGML context for tensor metadata
struct ggml_context * ctx = nullptr;
// Backend buffer for weights
ggml_backend_buffer_t buffer = nullptr;
// Tensor name to tensor mapping
std::map<std::string, struct ggml_tensor *> tensors;
};
@ -163,10 +163,10 @@ struct tts_transformer_model {
struct tts_kv_cache {
std::vector<struct ggml_tensor *> k_cache;
std::vector<struct ggml_tensor *> v_cache;
struct ggml_context * ctx = nullptr;
ggml_backend_buffer_t buffer = nullptr;
int32_t n_ctx = 0;
int32_t n_used = 0;
int32_t head_dim = 128;
@ -179,9 +179,9 @@ struct tts_transformer_state {
ggml_backend_t backend = nullptr;
ggml_backend_t backend_cpu = nullptr;
ggml_backend_sched_t sched = nullptr;
std::vector<uint8_t> compute_meta;
tts_kv_cache cache; // Talker KV cache (28 layers)
tts_kv_cache code_pred_cache; // Code predictor KV cache (5 layers)
};
@ -191,25 +191,27 @@ class TTSTransformer {
public:
TTSTransformer();
~TTSTransformer();
void set_seed(int seed);
// Load model from GGUF file
bool load_model(const std::string & model_path);
// Release all model/runtime resources
void unload_model();
// Initialize KV cache
bool init_kv_cache(int32_t n_ctx);
// Clear KV cache
void clear_kv_cache();
// Initialize code predictor KV cache (5 layers, max 16 context)
bool init_code_pred_kv_cache(int32_t n_ctx);
// Clear code predictor KV cache
void clear_code_pred_kv_cache();
// Forward pass for text tokens (prefill phase)
// text_tokens: input text token IDs [n_tokens]
// speaker_embd: speaker embedding [hidden_size] (optional, can be nullptr)
@ -222,7 +224,7 @@ public:
bool forward_prefill(const float * prefill_embd, int32_t n_tokens,
int32_t n_past, std::vector<float> & output,
std::vector<float> * logits_out = nullptr);
// Forward pass for codec tokens (generation phase)
// codec_token: single codec token for first codebook
// n_past: number of tokens already in KV cache
@ -233,26 +235,26 @@ public:
bool forward_step(const float * step_embd, int32_t n_past,
std::vector<float> & output,
std::vector<float> * hidden_out = nullptr);
// Get hidden states from last forward pass (for code predictor)
bool get_hidden_states(std::vector<float> & hidden) const;
// Run code predictor to get all 16 codebook predictions
// hidden: hidden states from talker [hidden_size]
// prev_codes: previous codes for codebooks 1-15 (can be nullptr for first step)
// output: logits for all 16 codebooks [16, code_pred_vocab_size]
bool predict_codes(const float * hidden, const int32_t * prev_codes,
std::vector<float> & output);
// Run code predictor autoregressively to generate 15 codes (codebooks 1-15)
// hidden: hidden states from talker [hidden_size]
// codebook_0_token: the codebook 0 token (used to create 2-token prefill input)
// output: generated codes for codebooks 1-15 [15]
bool predict_codes_autoregressive(const float * hidden, int32_t codebook_0_token,
bool predict_codes_autoregressive(const float * hidden, int32_t codebook_0_token,
std::vector<int32_t> & output,
float temperature = 0.9f,
int32_t top_k = 50);
// Generate speech codes autoregressively
// text_tokens: input text token IDs [n_tokens]
// speaker_embd: speaker embedding [hidden_size]
@ -265,20 +267,20 @@ public:
float repetition_penalty = 1.05f,
float temperature = 0.9f,
int32_t top_k = 50);
const tts_transformer_config & get_config() const { return model_.config; }
const std::string & get_error() const { return error_msg_; }
// Legacy interface for compatibility
bool forward(const int32_t * tokens, int32_t n_tokens, int32_t n_past,
std::vector<float> & output);
bool forward_with_audio(const int32_t * tokens, int32_t n_tokens,
const float * audio_embd, int32_t n_audio,
int32_t audio_start_pos, int32_t n_past,
std::vector<float> & output);
private:
bool try_init_coreml_code_predictor(const std::string & model_path);
bool predict_codes_autoregressive_coreml(const float * hidden, int32_t codebook_0_token,
@ -304,32 +306,32 @@ private:
const char * output_name, std::vector<float> & output);
bool lookup_single_embedding_row(struct ggml_tensor * embedding, int32_t token_id,
float * out_row);
// Build computation graph for code predictor
struct ggml_cgraph * build_code_pred_graph(int32_t n_prev_codes);
// Build computation graph for single-step autoregressive code predictor
// n_past: number of tokens already in KV cache (0-14)
// generation_step: which codebook we're predicting (0-14)
struct ggml_cgraph * build_code_pred_step_graph(int32_t n_past, int32_t generation_step);
// Build computation graph for 2-token prefill of code predictor
// Processes [past_hidden, codec_embd(codebook_0_token)] together
struct ggml_cgraph * build_code_pred_prefill_graph();
// Parse hyperparameters from GGUF
bool parse_config(struct gguf_context * ctx);
// Create tensor structures
bool create_tensors(struct gguf_context * ctx);
// Load tensor data from file
bool load_tensor_data(const std::string & path, struct gguf_context * ctx);
tts_transformer_model model_;
tts_transformer_state state_;
std::string error_msg_;
// Cached hidden states from last forward pass
std::vector<float> last_hidden_;
std::vector<ggml_fp16_t> embd_row_fp16_scratch_;

File diff suppressed because one or more lines are too long