mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-08 09:59:50 +00:00
adjust ace step, still wip on caption rework
This commit is contained in:
parent
9ddd74111f
commit
45c74da08b
4 changed files with 37 additions and 15 deletions
|
|
@ -171,6 +171,8 @@ struct AcePrompt {
|
|||
|
||||
static std::mt19937 acestep_lm_rng;
|
||||
static bool acestep_lm_dbg = false;
|
||||
static std::vector<int32_t> forced_tokens;
|
||||
static std::vector<int> caption_tokens = std::vector<int>(); //will be filled with caption tokens
|
||||
|
||||
//
|
||||
// CoT parsing (extract metadata + lyrics from LLM Phase1 output)
|
||||
|
|
@ -347,8 +349,8 @@ static std::string build_cot_yaml(const AcePrompt & prompt) {
|
|||
std::string yaml;
|
||||
if (prompt.bpm > 0)
|
||||
yaml += "bpm: " + std::to_string(prompt.bpm) + "\n";
|
||||
if (!prompt.caption.empty())
|
||||
yaml += yaml_wrap("caption", prompt.caption);
|
||||
// if (!prompt.caption.empty())
|
||||
// yaml += yaml_wrap("caption", prompt.caption);
|
||||
if (prompt.duration > 0)
|
||||
yaml += "duration: " + std::to_string((int)prompt.duration) + "\n";
|
||||
if (!prompt.keyscale.empty())
|
||||
|
|
@ -529,10 +531,10 @@ struct MetadataFSM {
|
|||
for (int v = 30; v <= 300; v++) vals.push_back(std::to_string(v));
|
||||
build_value_tree(bpe, bpm_tree, "bpm:", vals);
|
||||
}
|
||||
// Duration 10-600
|
||||
// Duration 10-300
|
||||
{
|
||||
std::vector<std::string> vals;
|
||||
for (int v = 10; v <= 600; v++) vals.push_back(std::to_string(v));
|
||||
for (int v = 10; v <= 300; v++) vals.push_back(std::to_string(v));
|
||||
build_value_tree(bpe, duration_tree, "duration:", vals);
|
||||
}
|
||||
// Keyscale
|
||||
|
|
@ -674,7 +676,14 @@ struct MetadataFSM {
|
|||
if (name_pos >= (int)name->size()) {
|
||||
switch (state) {
|
||||
case BPM_NAME: state = BPM_VALUE; break;
|
||||
case CAPTION_NAME: state = CAPTION_VALUE; break;
|
||||
case CAPTION_NAME:
|
||||
state = CAPTION_VALUE;
|
||||
if(caption_tokens.size()>0)
|
||||
{
|
||||
forced_tokens.clear();
|
||||
forced_tokens = caption_tokens;
|
||||
}
|
||||
break;
|
||||
case DURATION_NAME: state = DURATION_VALUE; break;
|
||||
case KEYSCALE_NAME: state = KEYSCALE_VALUE; break;
|
||||
case LANGUAGE_NAME: state = LANGUAGE_VALUE; break;
|
||||
|
|
@ -780,6 +789,7 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
|
||||
int V = m->cfg.vocab_size;
|
||||
bool use_cfg = cfg_scale > 1.0f && uncond_tokens && !uncond_tokens->empty();
|
||||
forced_tokens.clear();
|
||||
|
||||
// KV sets: cond [0..N-1], uncond [N..2N-1] if CFG
|
||||
for (int i = 0; i < N; i++) qw3lm_reset_kv(m, i);
|
||||
|
|
@ -836,7 +846,7 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
if (fsm_template && fsm_template->enabled)
|
||||
seqs[i].fsm.apply_mask(lg.data());
|
||||
|
||||
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.04f,top_p,30,temperature,acestep_lm_rng);
|
||||
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,40,temperature,acestep_lm_rng);
|
||||
|
||||
if (tok == TOKEN_IM_END) {
|
||||
seqs[i].done = true;
|
||||
|
|
@ -881,7 +891,6 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
if (seqs[i].done) n_active--;
|
||||
|
||||
std::vector<int32_t> quicklastntoks;
|
||||
std::vector<int32_t> forced_tokens;
|
||||
|
||||
for (int step = 0; step < max_new_tokens && n_active > 0; step++) {
|
||||
for (int i = 0; i < N; i++)
|
||||
|
|
@ -929,7 +938,7 @@ static std::vector<std::string> generate_phase1_batch(
|
|||
if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
||||
}
|
||||
|
||||
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.04f,top_p,30,temperature,acestep_lm_rng);
|
||||
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,40,temperature,acestep_lm_rng);
|
||||
quicklastntoks.push_back(tok);
|
||||
if (quicklastntoks.size()>32) {
|
||||
quicklastntoks.erase(quicklastntoks.begin());
|
||||
|
|
@ -1032,6 +1041,12 @@ static std::vector<std::string> run_phase2_batch(
|
|||
Timer t_prefill;
|
||||
std::vector<std::vector<float>> prefill_logits_vec(N, std::vector<float>(V));
|
||||
|
||||
if(acestep_lm_dbg)
|
||||
{
|
||||
std::string tks = bpe_decode(bpe,prompts[0]);
|
||||
printf("\nPhase2: UseCFG:%d, Promptsiz:%d, Prompt: %s",use_cfg,prompts[0].size(),tks.c_str());
|
||||
}
|
||||
|
||||
if (shared_prompt) {
|
||||
qw3lm_forward(m, prompts[0].data(), (int)prompts[0].size(), 0, prefill_logits_vec[0].data());
|
||||
for (int i = 1; i < N; i++) {
|
||||
|
|
@ -1086,7 +1101,7 @@ static std::vector<std::string> run_phase2_batch(
|
|||
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
||||
if (v != TOKEN_IM_END) lg[v] = -1e9f;
|
||||
|
||||
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,30,temperature,acestep_lm_rng);
|
||||
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,40,temperature,acestep_lm_rng);
|
||||
seqs[i].last_token = tok;
|
||||
|
||||
if (tok == TOKEN_IM_END) {
|
||||
|
|
@ -1158,7 +1173,7 @@ static std::vector<std::string> run_phase2_batch(
|
|||
for (int v = 0; v < AUDIO_CODE_BASE; v++)
|
||||
if (v != TOKEN_IM_END) lc[v] = -1e9f;
|
||||
|
||||
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,30,temperature,acestep_lm_rng);
|
||||
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,40,temperature,acestep_lm_rng);
|
||||
quicklastntoks.push_back(tok);
|
||||
if (quicklastntoks.size()>32) {
|
||||
quicklastntoks.erase(quicklastntoks.begin());
|
||||
|
|
@ -1576,6 +1591,12 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
std::vector<int> prompt;
|
||||
std::vector<AcePrompt> aces; // populated by Phase 1 (simple or partial)
|
||||
|
||||
caption_tokens.clear();
|
||||
// if(ace.caption!="")
|
||||
// {
|
||||
// caption_tokens = bpe_encode(&acestep_bpe, ace.caption+"\n", false);
|
||||
// }
|
||||
|
||||
// Preprocessor: simple mode generates lyrics + metas from caption
|
||||
if (is_simple) {
|
||||
fprintf(stderr, "[Simple] Inspiration\n");
|
||||
|
|
@ -1584,7 +1605,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
"# Instruction\n"
|
||||
"Expand the user's input into a more detailed"
|
||||
" and specific musical description:\n";
|
||||
std::string user_msg = ace.caption + "\n\ninstrumental: "
|
||||
std::string user_msg = "# Caption\n"+ace.caption + "\n\ninstrumental: "
|
||||
+ std::string(req.instrumental ? "true" : "false");
|
||||
prompt = build_custom_prompt(acestep_bpe, sys, user_msg.c_str());
|
||||
|
||||
|
|
@ -1631,6 +1652,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
|
|||
|
||||
for (int i = 0; i < 2 * batch_size; i++) qw3lm_reset_kv(&acestep_llm, i);
|
||||
}
|
||||
fsm.reset();
|
||||
|
||||
// Guarantee aces is populated (all-metas: single shared ace for prefill optimization)
|
||||
if (aces.empty()) {
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs
|
|||
|
||||
if (inputs.is_planner_mode && musicgen_llm_loaded) {
|
||||
if (!music_is_quiet) {
|
||||
printf("\nMusic Gen Generating Codes...");
|
||||
printf("\nMusic Gen Generating Codes...\n");
|
||||
}
|
||||
music_output_json_str = acestep_prepare_request(inputs);
|
||||
if(music_output_json_str=="")
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ void request_init(AceRequest * r) {
|
|||
r->task_type = "text2music";
|
||||
r->seed = -1;
|
||||
r->thinking = false;
|
||||
r->lm_temperature = 0.85f;
|
||||
r->lm_temperature = 1.0f;
|
||||
r->lm_cfg_scale = 2.0f;
|
||||
r->lm_top_p = 0.9f;
|
||||
r->lm_top_p = 0.95f;
|
||||
r->lm_top_k = 0;
|
||||
r->lm_negative_prompt = "";
|
||||
r->audio_codes = "";
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
|
|||
|
||||
if(speaker_embedding.size()==0)
|
||||
{
|
||||
printf("Creating Voice Embedding ID=%u... (Warning, lengthy sample audio will be very slow. Use short clips!)\n",reuse_hash_val);
|
||||
printf("Creating Voice Embedding ID=%u... (Warning, lengthy sample audio will take longer to load. Short clips recommended!)\n",reuse_hash_val);
|
||||
if (!audio_encoder_.encode(ref_samples, n_ref_samples, speaker_embedding)) {
|
||||
result.error_msg = "Failed to extract speaker embedding: " + audio_encoder_.get_error();
|
||||
return result;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue