mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 16:31:59 +00:00
added CustomVoice support
This commit is contained in:
parent
abe55fa424
commit
0d50cafd8b
8 changed files with 111 additions and 48 deletions
1
expose.h
1
expose.h
|
|
@ -298,6 +298,7 @@ struct tts_generation_inputs
|
|||
const char * custom_speaker_text = "";
|
||||
const char * custom_speaker_data = "";
|
||||
const char * reference_audio = "";
|
||||
const char * speaker_instruction = "";
|
||||
};
|
||||
struct tts_generation_outputs
|
||||
{
|
||||
|
|
|
|||
18
koboldcpp.py
18
koboldcpp.py
|
|
@ -424,7 +424,8 @@ class tts_generation_inputs(ctypes.Structure):
|
|||
("custom_speaker_voice", ctypes.c_char_p),
|
||||
("custom_speaker_text", ctypes.c_char_p),
|
||||
("custom_speaker_data", ctypes.c_char_p),
|
||||
("reference_audio", ctypes.c_char_p)]
|
||||
("reference_audio", ctypes.c_char_p),
|
||||
("speaker_instruction", ctypes.c_char_p)]
|
||||
|
||||
class tts_generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
|
|
@ -2538,6 +2539,14 @@ def tts_prepare_voice_json(jsonstr):
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def tts_extract_instruction(x):
    """Split a leading "[instruction]" prefix off a TTS prompt.

    A prompt of the form "[some instruction] text to speak" is separated into
    the text to speak and the bracketed instruction.

    Args:
        x: the raw prompt string.

    Returns:
        A tuple (text, instruction). When the prompt does not start with a
        non-empty bracketed prefix followed by at least one character, the
        prompt is returned unchanged and the instruction is "".
    """
    # DOTALL lets the spoken text span several lines; without it a prompt with
    # an embedded newline never matched, so the instruction was dropped and the
    # bracketed prefix was left inside the text. The lazy body followed by an
    # optional final newline keeps the single-line result unchanged (one
    # trailing newline is still not part of the returned text).
    match = re.match(r'^\[([^\]]+)\]\s*(.+?)\n?$', x, re.DOTALL)
    if match:
        instruction = match.group(1)
        remaining_text = match.group(2)
        return remaining_text, instruction
    return x, ""
|
||||
|
||||
def tts_generate(genparams):
|
||||
global args, voicebank, voicelist
|
||||
prompt = genparams.get("input", genparams.get("text", ""))
|
||||
|
|
@ -2558,6 +2567,11 @@ def tts_generate(genparams):
|
|||
voice = simple_lcg_hash(voicestr.strip()) if voicestr else 1
|
||||
inputs = tts_generation_inputs()
|
||||
inputs.custom_speaker_voice = normalized_voice.encode("UTF-8")
|
||||
ttsinstruction = genparams.get("instruction", "")
|
||||
# if no instruction provided, extract from text
|
||||
if not genparams.get("instruction", ""):
|
||||
prompt, ttsinstruction = tts_extract_instruction(prompt)
|
||||
inputs.speaker_instruction = ttsinstruction.encode("UTF-8")
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
inputs.speaker_seed = voice
|
||||
aseed = -1
|
||||
|
|
@ -9738,6 +9752,8 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False):
|
|||
|
||||
voicelist.append("random")
|
||||
voicebank["random"] = ""
|
||||
voicelist.append("instruct")
|
||||
voicebank["instruct"] = ""
|
||||
|
||||
if args.ttsdir and os.path.isdir(args.ttsdir):
|
||||
for filename in os.listdir(args.ttsdir):
|
||||
|
|
|
|||
|
|
@ -30,13 +30,13 @@ int main(int argc, char ** argv) {
|
|||
std::string text;
|
||||
std::string output_file = "output.wav";
|
||||
std::string reference_audio;
|
||||
|
||||
|
||||
qwen3_tts::tts_params params;
|
||||
|
||||
|
||||
// Parse arguments
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
print_usage(argv[0]);
|
||||
return 0;
|
||||
|
|
@ -106,63 +106,63 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Validate required arguments
|
||||
if (model_dir.empty()) {
|
||||
fprintf(stderr, "Error: model directory is required\n");
|
||||
print_usage(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
if (text.empty()) {
|
||||
fprintf(stderr, "Error: text is required\n");
|
||||
print_usage(argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
// Initialize TTS
|
||||
qwen3_tts::Qwen3TTS tts;
|
||||
|
||||
|
||||
fprintf(stderr, "Loading models from: %s\n", model_dir.c_str());
|
||||
if (!tts.load_models(model_dir)) {
|
||||
fprintf(stderr, "Error: %s\n", tts.get_error().c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
// Set progress callback
|
||||
tts.set_progress_callback([](int tokens, int max_tokens) {
|
||||
fprintf(stderr, "\rGenerating: %d/%d tokens", tokens, max_tokens);
|
||||
});
|
||||
|
||||
|
||||
// Generate speech
|
||||
qwen3_tts::tts_result result;
|
||||
|
||||
|
||||
if (reference_audio.empty()) {
|
||||
fprintf(stderr, "Synthesizing: \"%s\"\n", text.c_str());
|
||||
result = tts.synthesize(text, params);
|
||||
result = tts.synthesize(text,"", params);
|
||||
} else {
|
||||
fprintf(stderr, "Synthesizing with voice cloning: \"%s\"\n", text.c_str());
|
||||
fprintf(stderr, "Reference audio: %s\n", reference_audio.c_str());
|
||||
result = tts.synthesize_with_voice(text, reference_audio, params);
|
||||
}
|
||||
|
||||
|
||||
if (!result.success) {
|
||||
fprintf(stderr, "\nError: %s\n", result.error_msg.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
|
||||
// Save output
|
||||
if (!qwen3_tts::save_audio_file(output_file, result.audio, result.sample_rate)) {
|
||||
fprintf(stderr, "Error: failed to save output file: %s\n", output_file.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
fprintf(stderr, "Output saved to: %s\n", output_file.c_str());
|
||||
fprintf(stderr, "Audio duration: %.2f seconds\n",
|
||||
fprintf(stderr, "Audio duration: %.2f seconds\n",
|
||||
(float)result.audio.size() / result.sample_rate);
|
||||
|
||||
|
||||
// Print timing
|
||||
if (params.print_timing) {
|
||||
fprintf(stderr, "\nTiming:\n");
|
||||
|
|
@ -173,6 +173,6 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, " Decode: %6lld ms\n", (long long)result.t_decode_ms);
|
||||
fprintf(stderr, " Total: %6lld ms\n", (long long)result.t_total_ms);
|
||||
}
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ bool Qwen3TTS::load_models(const std::string & tts_model_path, const std::string
|
|||
return true;
|
||||
}
|
||||
|
||||
tts_result Qwen3TTS::synthesize(const std::string & text,
|
||||
tts_result Qwen3TTS::synthesize(const std::string & text, const std::string & instruction,
|
||||
const tts_params & params) {
|
||||
tts_result result;
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ tts_result Qwen3TTS::synthesize(const std::string & text,
|
|||
// This will use the model's default voice characteristics
|
||||
std::vector<float> zero_embedding(transformer_.get_config().hidden_size, 0.0f);
|
||||
|
||||
return synthesize_internal(text, zero_embedding.data(), params, result);
|
||||
return synthesize_internal(text, instruction, zero_embedding.data(), params, result);
|
||||
}
|
||||
|
||||
tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
|
||||
|
|
@ -260,10 +260,10 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
|
|||
fprintf(stderr, "Speaker embedding extracted: %zu floats\n", speaker_embedding.size());
|
||||
}
|
||||
|
||||
return synthesize_internal(text, speaker_embedding.data(), params, result);
|
||||
return synthesize_internal(text, "", speaker_embedding.data(), params, result);
|
||||
}
|
||||
|
||||
tts_result Qwen3TTS::synthesize_internal(const std::string & text,
|
||||
tts_result Qwen3TTS::synthesize_internal(const std::string & text, const std::string & instruction,
|
||||
const float * speaker_embedding,
|
||||
const tts_params & params,
|
||||
tts_result & result) {
|
||||
|
|
@ -311,11 +311,21 @@ tts_result Qwen3TTS::synthesize_internal(const std::string & text,
|
|||
}
|
||||
transformer_.clear_kv_cache();
|
||||
|
||||
std::vector<int32_t> alignment_instruct_tokens;
|
||||
int instruct_tok_count = 0;
|
||||
int32_t * instruct_tok_data = nullptr;
|
||||
if(instruction!="")
|
||||
{
|
||||
alignment_instruct_tokens = tokenizer_.encode_instruct(instruction);
|
||||
instruct_tok_data = alignment_instruct_tokens.data();
|
||||
instruct_tok_count = alignment_instruct_tokens.size();
|
||||
}
|
||||
|
||||
std::vector<int32_t> speech_codes;
|
||||
if (!transformer_.generate(text_tokens.data(), (int32_t)text_tokens.size(),
|
||||
speaker_embedding, params.max_audio_tokens, speech_codes,
|
||||
2050, params.repetition_penalty,
|
||||
params.temperature, params.top_k)) {
|
||||
params.temperature, params.top_k, -1, instruct_tok_data, instruct_tok_count)) {
|
||||
result.error_msg = "Failed to generate speech codes: " + transformer_.get_error();
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ public:
|
|||
// Generate speech from text
|
||||
// text: input text to synthesize
|
||||
// params: generation parameters
|
||||
tts_result synthesize(const std::string & text,
|
||||
tts_result synthesize(const std::string & text, const std::string & instruction,
|
||||
const tts_params & params = tts_params());
|
||||
|
||||
// Generate speech with voice cloning
|
||||
|
|
@ -121,7 +121,7 @@ public:
|
|||
bool is_loaded() const { return models_loaded_; }
|
||||
|
||||
private:
|
||||
tts_result synthesize_internal(const std::string & text,
|
||||
tts_result synthesize_internal(const std::string & text, const std::string & instruction,
|
||||
const float * speaker_embedding,
|
||||
const tts_params & params,
|
||||
tts_result & result);
|
||||
|
|
|
|||
|
|
@ -290,6 +290,37 @@ std::vector<int32_t> TextTokenizer::encode(const std::string & text) const {
|
|||
return tokens;
|
||||
}
|
||||
|
||||
std::vector<int32_t> TextTokenizer::encode_instruct(const std::string & instruct) const {
|
||||
if (!loaded_ || instruct.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Format: <|im_start|>user\n{instruct}<|im_end|>\n
|
||||
std::vector<int32_t> tokens;
|
||||
|
||||
// <|im_start|>
|
||||
tokens.push_back(config_.bos_token_id);
|
||||
|
||||
// user
|
||||
int user_token_id_ = 872;
|
||||
tokens.push_back(user_token_id_);
|
||||
|
||||
// \n
|
||||
tokens.push_back(newline_token_id_);
|
||||
|
||||
// Encode the instruct
|
||||
auto text_tokens = encode(instruct);
|
||||
tokens.insert(tokens.end(), text_tokens.begin(), text_tokens.end());
|
||||
|
||||
// <|im_end|>
|
||||
tokens.push_back(config_.eos_token_id);
|
||||
|
||||
// \n
|
||||
tokens.push_back(newline_token_id_);
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
std::vector<int32_t> TextTokenizer::encode_for_tts(const std::string & text) const {
|
||||
if (!loaded_) {
|
||||
return {};
|
||||
|
|
|
|||
|
|
@ -22,64 +22,66 @@ class TextTokenizer {
|
|||
public:
|
||||
TextTokenizer();
|
||||
~TextTokenizer();
|
||||
|
||||
|
||||
// Load tokenizer from GGUF file
|
||||
bool load_from_gguf(struct gguf_context * ctx);
|
||||
|
||||
|
||||
// Encode text to token IDs
|
||||
std::vector<int32_t> encode(const std::string & text) const;
|
||||
|
||||
|
||||
// Encode with TTS format: <|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n
|
||||
std::vector<int32_t> encode_for_tts(const std::string & text) const;
|
||||
|
||||
|
||||
std::vector<int32_t> encode_instruct(const std::string & instruct) const;
|
||||
|
||||
// Decode token IDs to text
|
||||
std::string decode(const std::vector<int32_t> & tokens) const;
|
||||
|
||||
|
||||
// Decode single token
|
||||
std::string decode_token(int32_t token_id) const;
|
||||
|
||||
|
||||
// Get configuration
|
||||
const tokenizer_config & get_config() const { return config_; }
|
||||
|
||||
|
||||
// Get error message
|
||||
const std::string & get_error() const { return error_msg_; }
|
||||
|
||||
|
||||
// Check if loaded
|
||||
bool is_loaded() const { return loaded_; }
|
||||
|
||||
|
||||
// Get special token IDs
|
||||
int32_t bos_token_id() const { return config_.bos_token_id; }
|
||||
int32_t eos_token_id() const { return config_.eos_token_id; }
|
||||
int32_t pad_token_id() const { return config_.pad_token_id; }
|
||||
|
||||
|
||||
private:
|
||||
tokenizer_config config_;
|
||||
std::string error_msg_;
|
||||
bool loaded_ = false;
|
||||
|
||||
|
||||
// Vocabulary: token string -> token ID
|
||||
std::unordered_map<std::string, int32_t> vocab_;
|
||||
|
||||
|
||||
// Reverse vocabulary: token ID -> token string
|
||||
std::vector<std::string> id_to_token_;
|
||||
|
||||
|
||||
// BPE merges: pair -> rank (lower rank = higher priority)
|
||||
std::map<std::pair<std::string, std::string>, int32_t> bpe_ranks_;
|
||||
|
||||
|
||||
// Special token for "assistant" and newline
|
||||
int32_t assistant_token_id_ = 77091;
|
||||
int32_t newline_token_id_ = 198; // '\n' encoded
|
||||
|
||||
|
||||
// Helper: convert bytes to unicode (GPT-2 style byte encoding)
|
||||
static std::string bytes_to_unicode(const std::string & text);
|
||||
static std::string unicode_to_bytes(const std::string & text);
|
||||
|
||||
|
||||
// Helper: get UTF-8 character length
|
||||
static size_t utf8_len(char c);
|
||||
|
||||
|
||||
// BPE encoding for a single word
|
||||
std::vector<std::string> bpe(const std::string & token) const;
|
||||
|
||||
|
||||
// Find the pair with lowest rank in a sequence
|
||||
std::pair<std::string, std::string> get_min_pair(
|
||||
const std::vector<std::string> & word) const;
|
||||
|
|
|
|||
|
|
@ -1184,7 +1184,7 @@ static tts_generation_outputs ttstype_generate_qwen3tts(const tts_generation_inp
|
|||
qwen3_tts::tts_params qwen3tts_params;
|
||||
std::string custom_reference_audio_str = inputs.reference_audio;
|
||||
std::vector<float> custom_reference_audio_pcmf32;
|
||||
std::string speakerstr = inputs.custom_speaker_voice;
|
||||
std::string speaker_instruction = inputs.speaker_instruction;
|
||||
|
||||
int audio_seed = inputs.audio_seed;
|
||||
if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
|
||||
|
|
@ -1194,7 +1194,7 @@ static tts_generation_outputs ttstype_generate_qwen3tts(const tts_generation_inp
|
|||
|
||||
if(ttsdebugmode==1 && !tts_is_quiet)
|
||||
{
|
||||
printf("\nUsing Audio Seed: %d, Speaker: %s", audio_seed, speakerstr.c_str());
|
||||
printf("\nUsing Audio Seed: %d", audio_seed);
|
||||
}
|
||||
qwen3tts_runner.set_seed(audio_seed);
|
||||
|
||||
|
|
@ -1221,8 +1221,11 @@ static tts_generation_outputs ttstype_generate_qwen3tts(const tts_generation_inp
|
|||
qwen3tts_params.print_progress = true;
|
||||
}
|
||||
|
||||
if (custom_reference_audio_pcmf32.empty()) {
|
||||
result = qwen3tts_runner.synthesize(prompt, qwen3tts_params);
|
||||
if (speaker_instruction!="" || custom_reference_audio_pcmf32.empty()) {
|
||||
if (speaker_instruction != "" && !tts_is_quiet) {
|
||||
printf("\nApply VoiceDesign Instruction: %s", speaker_instruction.c_str());
|
||||
}
|
||||
result = qwen3tts_runner.synthesize(prompt, speaker_instruction, qwen3tts_params);
|
||||
} else {
|
||||
std::size_t reuse_hash_value = std::hash<std::string>{}(custom_reference_audio_str);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue