diff --git a/otherarch/acestep/cond.h b/otherarch/acestep/cond.h index 9581806dd..d0f0542b0 100644 --- a/otherarch/acestep/cond.h +++ b/otherarch/acestep/cond.h @@ -59,6 +59,12 @@ struct CondGGML { struct ggml_tensor * timbre_embed_b; // [2048] struct ggml_tensor * timbre_norm; // [2048] + // XL models prepend a learned CLS token to the timbre sequence. + // The CLS output (position 0) aggregates timbre info across all frames. + // 2B models skip this (the token exists in the GGUF but is unused). + struct ggml_tensor * timbre_cls; // [2048, 1, 1] or NULL + bool use_timbre_cls; + // Text projector: Linear(1024->2048) no bias struct ggml_tensor * text_proj_w; // [2048, 1024] ggml @@ -93,12 +99,16 @@ static bool cond_ggml_load(CondGGML * m, const char * gguf_path) { return false; } + // XL models have encoder_hidden_size in GGUF metadata (2B models omit it). + // When present, the timbre encoder prepends a learned CLS token. + m->use_timbre_cls = (gf_get_u32(gf, "acestep.encoder_hidden_size") > 0); + // Count tensors: // lyric: embed_w(1) + embed_b(1) + 8 layers x 11(88) + norm(1) = 91 - // timbre: embed_w(1) + embed_b(1) + 4 layers x 11(44) + norm(1) = 47 + // timbre: embed_w(1) + embed_b(1) + 4 layers x 11(44) + norm(1) + cls(0 or 1) // text_proj(1) + null_cond(1) = 2 - // Total: 140 - int n_tensors = 91 + 47 + 2; + int n_tensors = 91 + 47 + 2 + (m->use_timbre_cls ? 
1 : 0); + wctx_init(&m->wctx, n_tensors); // Lyric encoder @@ -123,6 +133,12 @@ static bool cond_ggml_load(CondGGML * m, const char * gguf_path) { qwen3_load_layer(&m->wctx, gf, &m->timbre_layers[i], prefix, i); } + // Timbre CLS token (XL only) + m->timbre_cls = NULL; + if (m->use_timbre_cls) { + m->timbre_cls = gf_load_tensor_f32(&m->wctx, gf, "encoder.timbre_encoder.special_token"); + } + // Text projector + null condition m->text_proj_w = gf_load_tensor(&m->wctx, gf, "encoder.text_projector.weight"); m->null_cond_emb = gf_load_tensor(&m->wctx, gf, "null_condition_emb"); @@ -133,8 +149,8 @@ static bool cond_ggml_load(CondGGML * m, const char * gguf_path) { } gf_close(&gf); - fprintf(stderr, "[Load] CondEncoder: lyric(%dL), timbre(%dL), text_proj, null_cond\n", - m->lyric_cfg.n_layers, m->timbre_cfg.n_layers); + fprintf(stderr, "[Load] CondEncoder: lyric(%dL), timbre(%dL%s), text_proj, null_cond\n", m->lyric_cfg.n_layers, + m->timbre_cfg.n_layers, m->use_timbre_cls ? ", CLS" : ""); return true; } @@ -159,6 +175,7 @@ static void cond_ggml_forward(CondGGML * m, int * out_enc_S) { int H = 2048; bool has_timbre = (timbre_feats != NULL && S_ref > 0); + int S_timbre = has_timbre ? S_ref + (m->use_timbre_cls ? 
1 : 0) : 0; // Graph context (generous fixed allocation) size_t ctx_size = 4096 * ggml_tensor_overhead() + ggml_graph_overhead(); @@ -215,7 +232,7 @@ static void cond_ggml_forward(CondGGML * m, struct ggml_tensor * timbre_slide_mask = NULL; if (has_timbre) { - timbre_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S_ref); + timbre_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S_timbre); ggml_set_name(timbre_pos, "timbre_pos"); ggml_set_input(timbre_pos); @@ -223,12 +240,19 @@ static void cond_ggml_forward(CondGGML * m, ggml_set_name(t_timbre_in, "timbre_in"); ggml_set_input(t_timbre_in); - // Linear embed: [64, S_ref] -> [2048, S_ref] + // Linear embed: [64, S_ref] -> [H, S_ref] struct ggml_tensor * timbre_h = qwen3_linear_bias(ctx, m->timbre_embed_w, m->timbre_embed_b, t_timbre_in); + // XL: prepend learned CLS token -> [H, S_ref+1] + // The CLS output at position 0 aggregates timbre across all frames. + if (m->use_timbre_cls) { + struct ggml_tensor * cls = ggml_reshape_2d(ctx, m->timbre_cls, H, 1); + timbre_h = ggml_concat(ctx, cls, timbre_h, 1); + } + // Bidirectional sliding window mask for even layers - timbre_slide_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S_ref, S_ref); + timbre_slide_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S_timbre, S_timbre); ggml_set_name(timbre_slide_mask, "timbre_slide_mask"); ggml_set_input(timbre_slide_mask); @@ -236,11 +260,12 @@ static void cond_ggml_forward(CondGGML * m, for (int i = 0; i < m->timbre_cfg.n_layers; i++) { struct ggml_tensor * layer_mask = (i % 2 == 0) ? timbre_slide_mask : NULL; timbre_h = qwen3_build_layer(ctx, m->timbre_cfg, &m->timbre_layers[i], - timbre_h, timbre_pos, layer_mask, S_ref); + timbre_h, timbre_pos, layer_mask, S_timbre); } timbre_h = qwen3_rms_norm(ctx, timbre_h, m->timbre_norm, m->timbre_cfg.rms_norm_eps); - // Take first frame: [2048, S_ref] -> view [2048, 1] + // Take first position: [H, S_timbre] -> view [H, 1] + // 2B: first audio frame. XL: CLS token (aggregated timbre). 
timbre_out = ggml_view_2d(ctx, timbre_h, H, 1, timbre_h->nb[1], 0); ggml_set_name(timbre_out, "timbre_out"); @@ -282,22 +307,27 @@ static void cond_ggml_forward(CondGGML * m, if (has_timbre) { ggml_backend_tensor_set(t_timbre_in, timbre_feats, 0, 64 * S_ref * sizeof(float)); - std::vector<int> pos(S_ref); - for (int i = 0; i < S_ref; i++) pos[i] = i; - ggml_backend_tensor_set(timbre_pos, pos.data(), 0, S_ref * sizeof(int)); + std::vector<int> pos(S_timbre); + for (int i = 0; i < S_timbre; i++) { + pos[i] = i; + } + ggml_backend_tensor_set(timbre_pos, pos.data(), 0, S_timbre * sizeof(int)); // Timbre sliding window mask: bidirectional, |i-j| <= 128 - const int W = 128; - std::vector<uint16_t> mask_data(S_ref * S_ref); - for (int i = 0; i < S_ref; i++) { - for (int j = 0; j < S_ref; j++) { - int d = i - j; if (d < 0) d = -d; - mask_data[i * S_ref + j] = ggml_fp32_to_fp16(d <= W ? 0.0f : -INFINITY); + const int W = 128; + std::vector<uint16_t> mask_data(S_timbre * S_timbre); + for (int i = 0; i < S_timbre; i++) { + for (int j = 0; j < S_timbre; j++) { + int d = i - j; + if (d < 0) { + d = -d; + } + mask_data[i * S_timbre + j] = ggml_fp32_to_fp16(d <= W ? 0.0f : -INFINITY); } } - ggml_backend_tensor_set(timbre_slide_mask, mask_data.data(), 0, - S_ref * S_ref * sizeof(uint16_t)); - fprintf(stderr, "[CondEnc] Timbre sliding mask: %dx%d, window=%d\n", S_ref, S_ref, W); + ggml_backend_tensor_set(timbre_slide_mask, mask_data.data(), 0, S_timbre * S_timbre * sizeof(uint16_t)); + fprintf(stderr, "[CondEnc] Timbre sliding mask: %dx%d, window=%d%s\n", S_timbre, S_timbre, W, + m->use_timbre_cls ?
" (CLS)" : ""); } // Compute diff --git a/otherarch/acestep/dit.h b/otherarch/acestep/dit.h index 8fd57e775..52f498f41 100644 --- a/otherarch/acestep/dit.h +++ b/otherarch/acestep/dit.h @@ -98,9 +98,9 @@ struct DiTGGML { struct ggml_tensor * proj_in_w; // [in_ch*P, H] pre-permuted F32 struct ggml_tensor * proj_in_b; // [hidden] - // condition_embedder: Linear(hidden, hidden) - struct ggml_tensor * cond_emb_w; // [hidden, hidden] - struct ggml_tensor * cond_emb_b; // [hidden] + // condition_embedder: Linear(encoder_H, decoder_H) + struct ggml_tensor * cond_emb_w; // [encoder_H, decoder_H] projects encoder to decoder space + struct ggml_tensor * cond_emb_b; // [decoder_H] // Layers DiTGGMLLayer layers[DIT_GGML_MAX_LAYERS]; @@ -851,8 +851,11 @@ static struct ggml_cgraph * dit_ggml_build_graph( ggml_set_input(input); *p_input = input; - // Encoder hidden states: [H, enc_S, N] - struct ggml_tensor * enc_hidden = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, enc_S, N); + // Encoder hidden states: [H_enc, enc_S, N] + // H_enc comes from the condition_embedder input dimension (2048 for both 2B and XL). + // The condition_embedder projects H_enc -> H (decoder) via cond_emb_w. 
+ int H_enc = (int) m->cond_emb_w->ne[0]; + struct ggml_tensor * enc_hidden = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H_enc, enc_S, N); ggml_set_name(enc_hidden, "enc_hidden"); ggml_set_input(enc_hidden); @@ -1069,7 +1072,7 @@ static void apg_forward( // // noise: [N * T * Oc] N contiguous [T, Oc] noise blocks // context_latents: [N * T * ctx_ch] N contiguous context blocks -// enc_hidden: [enc_S * H] SINGLE encoder output (shared, will be broadcast to N) +// enc_hidden: [enc_S * H_enc * N] per-batch encoder outputs (caller-stacked) // schedule: array of num_steps timestep values // output: [N * T * Oc] generated latents (caller-allocated) static void dit_ggml_generate( @@ -1095,7 +1098,6 @@ static void dit_ggml_generate( int S = T / c.patch_size; int n_per = T * Oc; // elements per sample int n_total = N * n_per; // total output elements - int H = c.hidden_size; fprintf(stderr, "[DiT] Batch N=%d, T=%d, S=%d, enc_S=%d\n", N, T, S, enc_S); @@ -1118,6 +1120,7 @@ static void dit_ggml_generate( fprintf(stderr, "[DiT] Graph: %d nodes\n", ggml_graph_n_nodes(gf)); struct ggml_tensor * t_enc = ggml_graph_get_tensor(gf, "enc_hidden"); + int H_enc = (int) t_enc->ne[0]; // encoder hidden size (from condition_embedder) // Allocate compute buffers. // Critical: reset FIRST (clears old state), THEN force inputs to GPU, THEN alloc. 
@@ -1196,17 +1199,19 @@ ggml_backend_tensor_get(model->null_condition_emb, null_emb.data(), 0, emb_n * sizeof(float)); } - // Broadcast [H] to [enc_S, H] then to N copies [H, enc_S, N] - std::vector<float> null_enc_single(H * enc_S); - for (int s = 0; s < enc_S; s++) - memcpy(&null_enc_single[s * H], null_emb.data(), H * sizeof(float)); - null_enc_buf.resize(H * enc_S * N); - for (int b = 0; b < N; b++) - memcpy(null_enc_buf.data() + b * enc_S * H, null_enc_single.data(), enc_S * H * sizeof(float)); + // Broadcast [H_enc] to [enc_S, H_enc] then to N copies [H_enc, enc_S, N] + std::vector<float> null_enc_single(H_enc * enc_S); + for (int s = 0; s < enc_S; s++) { + memcpy(&null_enc_single[s * H_enc], null_emb.data(), H_enc * sizeof(float)); + } + null_enc_buf.resize(H_enc * enc_S * N); + for (int b = 0; b < N; b++) { + memcpy(null_enc_buf.data() + b * enc_S * H_enc, null_enc_single.data(), enc_S * H_enc * sizeof(float)); + } if (dbg && dbg->enabled) { debug_dump_1d(dbg, "null_condition_emb", null_emb.data(), emb_n); - debug_dump_2d(dbg, "null_enc_hidden", null_enc_single.data(), enc_S, H); + debug_dump_2d(dbg, "null_enc_hidden", null_enc_single.data(), enc_S, H_enc); } apg_mbufs.resize(N); @@ -1235,9 +1240,9 @@ static void dit_ggml_generate( ctx_ch * sizeof(float)); // Pre-allocate enc_buf once (avoids heap alloc per step) - std::vector<float> enc_buf(H * enc_S * N); + std::vector<float> enc_buf(H_enc * enc_S * N); for (int b = 0; b < N; b++) - memcpy(enc_buf.data() + b * enc_S * H, enc_hidden_data, enc_S * H * sizeof(float)); + memcpy(enc_buf.data() + b * enc_S * H_enc, enc_hidden_data, enc_S * H_enc * sizeof(float)); ggml_backend_tensor_set(t_enc, enc_buf.data(), 0, enc_buf.size() * sizeof(float)); struct ggml_tensor * t_t = ggml_graph_get_tensor(gf, "t"); @@ -1340,7 +1345,7 @@ static void dit_ggml_generate( } // Unconditional pass: re-upload all inputs (scheduler clobbers input buffers during compute) - ggml_backend_tensor_set(t_enc, null_enc_buf.data(), 0, H *
enc_S * N * sizeof(float)); + ggml_backend_tensor_set(t_enc, null_enc_buf.data(), 0, H_enc * enc_S * N * sizeof(float)); ggml_backend_tensor_set(t_input, input_buf.data(), 0, in_ch * T * N * sizeof(float)); if (t_t) ggml_backend_tensor_set(t_t, &t_curr, 0, sizeof(float)); if (t_tr) ggml_backend_tensor_set(t_tr, &t_curr, 0, sizeof(float));