Adjust ACE-Step settings; caption rework is still WIP

This commit is contained in:
Concedo 2026-03-09 00:11:48 +08:00
parent 9ddd74111f
commit 45c74da08b
4 changed files with 37 additions and 15 deletions

View file

@ -171,6 +171,8 @@ struct AcePrompt {
static std::mt19937 acestep_lm_rng;
static bool acestep_lm_dbg = false;
static std::vector<int32_t> forced_tokens;
static std::vector<int> caption_tokens = std::vector<int>(); //will be filled with caption tokens
//
// CoT parsing (extract metadata + lyrics from LLM Phase1 output)
@ -347,8 +349,8 @@ static std::string build_cot_yaml(const AcePrompt & prompt) {
std::string yaml;
if (prompt.bpm > 0)
yaml += "bpm: " + std::to_string(prompt.bpm) + "\n";
if (!prompt.caption.empty())
yaml += yaml_wrap("caption", prompt.caption);
// if (!prompt.caption.empty())
// yaml += yaml_wrap("caption", prompt.caption);
if (prompt.duration > 0)
yaml += "duration: " + std::to_string((int)prompt.duration) + "\n";
if (!prompt.keyscale.empty())
@ -529,10 +531,10 @@ struct MetadataFSM {
for (int v = 30; v <= 300; v++) vals.push_back(std::to_string(v));
build_value_tree(bpe, bpm_tree, "bpm:", vals);
}
// Duration 10-600
// Duration 10-300
{
std::vector<std::string> vals;
for (int v = 10; v <= 600; v++) vals.push_back(std::to_string(v));
for (int v = 10; v <= 300; v++) vals.push_back(std::to_string(v));
build_value_tree(bpe, duration_tree, "duration:", vals);
}
// Keyscale
@ -674,7 +676,14 @@ struct MetadataFSM {
if (name_pos >= (int)name->size()) {
switch (state) {
case BPM_NAME: state = BPM_VALUE; break;
case CAPTION_NAME: state = CAPTION_VALUE; break;
case CAPTION_NAME:
state = CAPTION_VALUE;
if(caption_tokens.size()>0)
{
forced_tokens.clear();
forced_tokens = caption_tokens;
}
break;
case DURATION_NAME: state = DURATION_VALUE; break;
case KEYSCALE_NAME: state = KEYSCALE_VALUE; break;
case LANGUAGE_NAME: state = LANGUAGE_VALUE; break;
@ -780,6 +789,7 @@ static std::vector<std::string> generate_phase1_batch(
int V = m->cfg.vocab_size;
bool use_cfg = cfg_scale > 1.0f && uncond_tokens && !uncond_tokens->empty();
forced_tokens.clear();
// KV sets: cond [0..N-1], uncond [N..2N-1] if CFG
for (int i = 0; i < N; i++) qw3lm_reset_kv(m, i);
@ -836,7 +846,7 @@ static std::vector<std::string> generate_phase1_batch(
if (fsm_template && fsm_template->enabled)
seqs[i].fsm.apply_mask(lg.data());
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.04f,top_p,30,temperature,acestep_lm_rng);
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,40,temperature,acestep_lm_rng);
if (tok == TOKEN_IM_END) {
seqs[i].done = true;
@ -881,7 +891,6 @@ static std::vector<std::string> generate_phase1_batch(
if (seqs[i].done) n_active--;
std::vector<int32_t> quicklastntoks;
std::vector<int32_t> forced_tokens;
for (int step = 0; step < max_new_tokens && n_active > 0; step++) {
for (int i = 0; i < N; i++)
@ -929,7 +938,7 @@ static std::vector<std::string> generate_phase1_batch(
if (v != TOKEN_IM_END) lc[v] = -1e9f;
}
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.04f,top_p,30,temperature,acestep_lm_rng);
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,40,temperature,acestep_lm_rng);
quicklastntoks.push_back(tok);
if (quicklastntoks.size()>32) {
quicklastntoks.erase(quicklastntoks.begin());
@ -1032,6 +1041,12 @@ static std::vector<std::string> run_phase2_batch(
Timer t_prefill;
std::vector<std::vector<float>> prefill_logits_vec(N, std::vector<float>(V));
if(acestep_lm_dbg)
{
std::string tks = bpe_decode(bpe,prompts[0]);
printf("\nPhase2: UseCFG:%d, Promptsiz:%d, Prompt: %s",use_cfg,prompts[0].size(),tks.c_str());
}
if (shared_prompt) {
qw3lm_forward(m, prompts[0].data(), (int)prompts[0].size(), 0, prefill_logits_vec[0].data());
for (int i = 1; i < N; i++) {
@ -1086,7 +1101,7 @@ static std::vector<std::string> run_phase2_batch(
for (int v = 0; v < AUDIO_CODE_BASE; v++)
if (v != TOKEN_IM_END) lg[v] = -1e9f;
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,30,temperature,acestep_lm_rng);
int tok = kcpp_quick_sample(lg.data(),V,std::vector<int32_t>(),1.03f,top_p,40,temperature,acestep_lm_rng);
seqs[i].last_token = tok;
if (tok == TOKEN_IM_END) {
@ -1158,7 +1173,7 @@ static std::vector<std::string> run_phase2_batch(
for (int v = 0; v < AUDIO_CODE_BASE; v++)
if (v != TOKEN_IM_END) lc[v] = -1e9f;
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,30,temperature,acestep_lm_rng);
int tok = kcpp_quick_sample(lc,V,quicklastntoks,1.03f,top_p,40,temperature,acestep_lm_rng);
quicklastntoks.push_back(tok);
if (quicklastntoks.size()>32) {
quicklastntoks.erase(quicklastntoks.begin());
@ -1576,6 +1591,12 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
std::vector<int> prompt;
std::vector<AcePrompt> aces; // populated by Phase 1 (simple or partial)
caption_tokens.clear();
// if(ace.caption!="")
// {
// caption_tokens = bpe_encode(&acestep_bpe, ace.caption+"\n", false);
// }
// Preprocessor: simple mode generates lyrics + metas from caption
if (is_simple) {
fprintf(stderr, "[Simple] Inspiration\n");
@ -1584,7 +1605,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
"# Instruction\n"
"Expand the user's input into a more detailed"
" and specific musical description:\n";
std::string user_msg = ace.caption + "\n\ninstrumental: "
std::string user_msg = "# Caption\n"+ace.caption + "\n\ninstrumental: "
+ std::string(req.instrumental ? "true" : "false");
prompt = build_custom_prompt(acestep_bpe, sys, user_msg.c_str());
@ -1631,6 +1652,7 @@ std::string acestep_prepare_request(const music_generation_inputs inputs)
for (int i = 0; i < 2 * batch_size; i++) qw3lm_reset_kv(&acestep_llm, i);
}
fsm.reset();
// Guarantee aces is populated (all-metas: single shared ace for prefill optimization)
if (aces.empty()) {

View file

@ -113,7 +113,7 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs
if (inputs.is_planner_mode && musicgen_llm_loaded) {
if (!music_is_quiet) {
printf("\nMusic Gen Generating Codes...");
printf("\nMusic Gen Generating Codes...\n");
}
music_output_json_str = acestep_prepare_request(inputs);
if(music_output_json_str=="")

View file

@ -23,9 +23,9 @@ void request_init(AceRequest * r) {
r->task_type = "text2music";
r->seed = -1;
r->thinking = false;
r->lm_temperature = 0.85f;
r->lm_temperature = 1.0f;
r->lm_cfg_scale = 2.0f;
r->lm_top_p = 0.9f;
r->lm_top_p = 0.95f;
r->lm_top_k = 0;
r->lm_negative_prompt = "";
r->audio_codes = "";

View file

@ -244,7 +244,7 @@ tts_result Qwen3TTS::synthesize_with_voice(const std::string & text,
if(speaker_embedding.size()==0)
{
printf("Creating Voice Embedding ID=%u... (Warning, lengthy sample audio will be very slow. Use short clips!)\n",reuse_hash_val);
printf("Creating Voice Embedding ID=%u... (Warning, lengthy sample audio will take longer to load. Short clips recommended!)\n",reuse_hash_val);
if (!audio_encoder_.encode(ref_samples, n_ref_samples, speaker_embedding)) {
result.error_msg = "Failed to extract speaker embedding: " + audio_encoder_.get_error();
return result;