added CustomVoice support

This commit is contained in:
Concedo 2026-03-23 18:50:08 +08:00
parent abe55fa424
commit 0d50cafd8b
8 changed files with 111 additions and 48 deletions

View file

@ -298,6 +298,7 @@ struct tts_generation_inputs
const char * custom_speaker_text = "";
const char * custom_speaker_data = "";
const char * reference_audio = "";
const char * speaker_instruction = "";
};
struct tts_generation_outputs
{

View file

@ -424,7 +424,8 @@ class tts_generation_inputs(ctypes.Structure):
("custom_speaker_voice", ctypes.c_char_p),
("custom_speaker_text", ctypes.c_char_p),
("custom_speaker_data", ctypes.c_char_p),
("reference_audio", ctypes.c_char_p)]
("reference_audio", ctypes.c_char_p),
("speaker_instruction", ctypes.c_char_p)]
class tts_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
@ -2538,6 +2539,14 @@ def tts_prepare_voice_json(jsonstr):
except Exception:
return None
def tts_extract_instruction(x):
match = re.match(r'^\[([^\]]+)\]\s*(.+)$', x)
if match:
instruction = match.group(1)
x1 = match.group(2)
return x1, instruction
return x, ""
def tts_generate(genparams):
global args, voicebank, voicelist
prompt = genparams.get("input", genparams.get("text", ""))
@ -2558,6 +2567,11 @@ def tts_generate(genparams):
voice = simple_lcg_hash(voicestr.strip()) if voicestr else 1
inputs = tts_generation_inputs()
inputs.custom_speaker_voice = normalized_voice.encode("UTF-8")
ttsinstruction = genparams.get("instruction", "")
# if no instruction provided, extract from text
if not genparams.get("instruction", ""):
prompt, ttsinstruction = tts_extract_instruction(prompt)
inputs.speaker_instruction = ttsinstruction.encode("UTF-8")
inputs.prompt = prompt.encode("UTF-8")
inputs.speaker_seed = voice
aseed = -1
@ -9738,6 +9752,8 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
voicelist.append("random")
voicebank["random"] = ""
voicelist.append("instruct")
voicebank["instruct"] = ""
if args.ttsdir and os.path.isdir(args.ttsdir):
for filename in os.listdir(args.ttsdir):

View file

@ -30,13 +30,13 @@ int main(int argc, char ** argv) {
std::string text;
std::string output_file = "output.wav";
std::string reference_audio;
qwen3_tts::tts_params params;
// Parse arguments
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
print_usage(argv[0]);
return 0;
@ -106,63 +106,63 @@ int main(int argc, char ** argv) {
return 1;
}
}
// Validate required arguments
if (model_dir.empty()) {
fprintf(stderr, "Error: model directory is required\n");
print_usage(argv[0]);
return 1;
}
if (text.empty()) {
fprintf(stderr, "Error: text is required\n");
print_usage(argv[0]);
return 1;
}
// Initialize TTS
qwen3_tts::Qwen3TTS tts;
fprintf(stderr, "Loading models from: %s\n", model_dir.c_str());
if (!tts.load_models(model_dir)) {
fprintf(stderr, "Error: %s\n", tts.get_error().c_str());
return 1;
}
// Set progress callback
tts.set_progress_callback([](int tokens, int max_tokens) {
fprintf(stderr, "\rGenerating: %d/%d tokens", tokens, max_tokens);
});
// Generate speech
qwen3_tts::tts_result result;
if (reference_audio.empty()) {
fprintf(stderr, "Synthesizing: \"%s\"\n", text.c_str());
result = tts.synthesize(text, params);
result = tts.synthesize(text,"", params);
} else {
fprintf(stderr, "Synthesizing with voice cloning: \"%s\"\n", text.c_str());
fprintf(stderr, "Reference audio: %s\n", reference_audio.c_str());
result = tts.synthesize_with_voice(text, reference_audio, params);
}
if (!result.success) {
fprintf(stderr, "\nError: %s\n", result.error_msg.c_str());
return 1;
}
fprintf(stderr, "\n");
// Save output
if (!qwen3_tts::save_audio_file(output_file, result.audio, result.sample_rate)) {
fprintf(stderr, "Error: failed to save output file: %s\n", output_file.c_str());
return 1;
}
fprintf(stderr, "Output saved to: %s\n", output_file.c_str());
fprintf(stderr, "Audio duration: %.2f seconds\n",
fprintf(stderr, "Audio duration: %.2f seconds\n",
(float)result.audio.size() / result.sample_rate);
// Print timing
if (params.print_timing) {
fprintf(stderr, "\nTiming:\n");
@ -173,6 +173,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, " Decode: %6lld ms\n", (long long)result.t_decode_ms);
fprintf(stderr, " Total: %6lld ms\n", (long long)result.t_total_ms);
}
return 0;
}

View file

@ -163,7 +163,7 @@ bool Qwen3TTS::load_models(const std::string & tts_model_path, const std::string
return true;
}
tts_result Qwen3TTS::synthesize(const std::string & text,
tts_result Qwen3TTS::synthesize(const std::string & text, const std::string & instruction,
const tts_params & params) {
tts_result result;
@ -176,7 +176,7 @@ tts_result Qwen3TTS::synthesize(const std::string & text,
// This will use the model's default voice characteristics
std::vector<float> zero_embedding(transformer_.get_config().hidden_size, 0.0f);
return synthesize_internal(text, zero_embedding.data(), params, result);
return synthesize_internal(text, instruction, zero_embedding.data(), params, result);
}
tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
@ -260,10 +260,10 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
fprintf(stderr, "Speaker embedding extracted: %zu floats\n", speaker_embedding.size());
}
return synthesize_internal(text, speaker_embedding.data(), params, result);
return synthesize_internal(text, "", speaker_embedding.data(), params, result);
}
tts_result Qwen3TTS::synthesize_internal(const std::string & text,
tts_result Qwen3TTS::synthesize_internal(const std::string & text, const std::string & instruction,
const float * speaker_embedding,
const tts_params & params,
tts_result & result) {
@ -311,11 +311,21 @@ tts_result Qwen3TTS::synthesize_internal(const std::string & text,
}
transformer_.clear_kv_cache();
std::vector<int32_t> alignment_instruct_tokens;
int instruct_tok_count = 0;
int32_t * instruct_tok_data = nullptr;
if(instruction!="")
{
alignment_instruct_tokens = tokenizer_.encode_instruct(instruction);
instruct_tok_data = alignment_instruct_tokens.data();
instruct_tok_count = alignment_instruct_tokens.size();
}
std::vector<int32_t> speech_codes;
if (!transformer_.generate(text_tokens.data(), (int32_t)text_tokens.size(),
speaker_embedding, params.max_audio_tokens, speech_codes,
2050, params.repetition_penalty,
params.temperature, params.top_k)) {
params.temperature, params.top_k, -1, instruct_tok_data, instruct_tok_count)) {
result.error_msg = "Failed to generate speech codes: " + transformer_.get_error();
return result;
}

View file

@ -91,7 +91,7 @@ public:
// Generate speech from text
// text: input text to synthesize
// params: generation parameters
tts_result synthesize(const std::string & text,
tts_result synthesize(const std::string & text, const std::string & instruction,
const tts_params & params = tts_params());
// Generate speech with voice cloning
@ -121,7 +121,7 @@ public:
bool is_loaded() const { return models_loaded_; }
private:
tts_result synthesize_internal(const std::string & text,
tts_result synthesize_internal(const std::string & text, const std::string & instruction,
const float * speaker_embedding,
const tts_params & params,
tts_result & result);

View file

@ -290,6 +290,37 @@ std::vector<int32_t> TextTokenizer::encode(const std::string & text) const {
return tokens;
}
std::vector<int32_t> TextTokenizer::encode_instruct(const std::string & instruct) const {
if (!loaded_ || instruct.empty()) {
return {};
}
// Format: <|im_start|>user\n{instruct}<|im_end|>\n
std::vector<int32_t> tokens;
// <|im_start|>
tokens.push_back(config_.bos_token_id);
// user
int user_token_id_ = 872;
tokens.push_back(user_token_id_);
// \n
tokens.push_back(newline_token_id_);
// Encode the instruct
auto text_tokens = encode(instruct);
tokens.insert(tokens.end(), text_tokens.begin(), text_tokens.end());
// <|im_end|>
tokens.push_back(config_.eos_token_id);
// \n
tokens.push_back(newline_token_id_);
return tokens;
}
std::vector<int32_t> TextTokenizer::encode_for_tts(const std::string & text) const {
if (!loaded_) {
return {};

View file

@ -22,64 +22,66 @@ class TextTokenizer {
public:
TextTokenizer();
~TextTokenizer();
// Load tokenizer from GGUF file
bool load_from_gguf(struct gguf_context * ctx);
// Encode text to token IDs
std::vector<int32_t> encode(const std::string & text) const;
// Encode with TTS format: <|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n
std::vector<int32_t> encode_for_tts(const std::string & text) const;
std::vector<int32_t> encode_instruct(const std::string & instruct) const;
// Decode token IDs to text
std::string decode(const std::vector<int32_t> & tokens) const;
// Decode single token
std::string decode_token(int32_t token_id) const;
// Get configuration
const tokenizer_config & get_config() const { return config_; }
// Get error message
const std::string & get_error() const { return error_msg_; }
// Check if loaded
bool is_loaded() const { return loaded_; }
// Get special token IDs
int32_t bos_token_id() const { return config_.bos_token_id; }
int32_t eos_token_id() const { return config_.eos_token_id; }
int32_t pad_token_id() const { return config_.pad_token_id; }
private:
tokenizer_config config_;
std::string error_msg_;
bool loaded_ = false;
// Vocabulary: token string -> token ID
std::unordered_map<std::string, int32_t> vocab_;
// Reverse vocabulary: token ID -> token string
std::vector<std::string> id_to_token_;
// BPE merges: pair -> rank (lower rank = higher priority)
std::map<std::pair<std::string, std::string>, int32_t> bpe_ranks_;
// Special token for "assistant" and newline
int32_t assistant_token_id_ = 77091;
int32_t newline_token_id_ = 198; // '\n' encoded
// Helper: convert bytes to unicode (GPT-2 style byte encoding)
static std::string bytes_to_unicode(const std::string & text);
static std::string unicode_to_bytes(const std::string & text);
// Helper: get UTF-8 character length
static size_t utf8_len(char c);
// BPE encoding for a single word
std::vector<std::string> bpe(const std::string & token) const;
// Find the pair with lowest rank in a sequence
std::pair<std::string, std::string> get_min_pair(
const std::vector<std::string> & word) const;

View file

@ -1184,7 +1184,7 @@ static tts_generation_outputs ttstype_generate_qwen3tts(const tts_generation_inp
qwen3_tts::tts_params qwen3tts_params;
std::string custom_reference_audio_str = inputs.reference_audio;
std::vector<float> custom_reference_audio_pcmf32;
std::string speakerstr = inputs.custom_speaker_voice;
std::string speaker_instruction = inputs.speaker_instruction;
int audio_seed = inputs.audio_seed;
if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
@ -1194,7 +1194,7 @@ static tts_generation_outputs ttstype_generate_qwen3tts(const tts_generation_inp
if(ttsdebugmode==1 && !tts_is_quiet)
{
printf("\nUsing Audio Seed: %d, Speaker: %s", audio_seed, speakerstr.c_str());
printf("\nUsing Audio Seed: %d", audio_seed);
}
qwen3tts_runner.set_seed(audio_seed);
@ -1221,8 +1221,11 @@ static tts_generation_outputs ttstype_generate_qwen3tts(const tts_generation_inp
qwen3tts_params.print_progress = true;
}
if (custom_reference_audio_pcmf32.empty()) {
result = qwen3tts_runner.synthesize(prompt, qwen3tts_params);
if (speaker_instruction!="" || custom_reference_audio_pcmf32.empty()) {
if (speaker_instruction != "" && !tts_is_quiet) {
printf("\nApply VoiceDesign Instruction: %s", speaker_instruction.c_str());
}
result = qwen3tts_runner.synthesize(prompt, speaker_instruction, qwen3tts_params);
} else {
std::size_t reuse_hash_value = std::hash<std::string>{}(custom_reference_audio_str);