ACE-Step XL tentative changes (not yet working)

This commit is contained in:
Concedo 2026-04-08 18:00:39 +08:00
parent d9ed4b444b
commit 4b478b70fa
2 changed files with 75 additions and 40 deletions

View file

@@ -59,6 +59,12 @@ struct CondGGML {
struct ggml_tensor * timbre_embed_b; // [2048]
struct ggml_tensor * timbre_norm; // [2048]
// XL models prepend a learned CLS token to the timbre sequence.
// The CLS output (position 0) aggregates timbre info across all frames.
// 2B models skip this (the token exists in the GGUF but is unused).
struct ggml_tensor * timbre_cls; // [2048, 1, 1] or NULL
bool use_timbre_cls;
// Text projector: Linear(1024->2048) no bias
struct ggml_tensor * text_proj_w; // [2048, 1024] ggml
@@ -93,12 +99,16 @@ static bool cond_ggml_load(CondGGML * m, const char * gguf_path) {
return false;
}
// XL models have encoder_hidden_size in GGUF metadata (2B models omit it).
// When present, the timbre encoder prepends a learned CLS token.
m->use_timbre_cls = (gf_get_u32(gf, "acestep.encoder_hidden_size") > 0);
// Count tensors:
// lyric: embed_w(1) + embed_b(1) + 8 layers x 11(88) + norm(1) = 91
// timbre: embed_w(1) + embed_b(1) + 4 layers x 11(44) + norm(1) = 47
// timbre: embed_w(1) + embed_b(1) + 4 layers x 11(44) + norm(1) + cls(0 or 1)
// text_proj(1) + null_cond(1) = 2
// Total: 140
int n_tensors = 91 + 47 + 2;
int n_tensors = 91 + 47 + 2 + (m->use_timbre_cls ? 1 : 0);
wctx_init(&m->wctx, n_tensors);
// Lyric encoder
@@ -123,6 +133,12 @@ static bool cond_ggml_load(CondGGML * m, const char * gguf_path) {
qwen3_load_layer(&m->wctx, gf, &m->timbre_layers[i], prefix, i);
}
// Timbre CLS token (XL only)
m->timbre_cls = NULL;
if (m->use_timbre_cls) {
m->timbre_cls = gf_load_tensor_f32(&m->wctx, gf, "encoder.timbre_encoder.special_token");
}
// Text projector + null condition
m->text_proj_w = gf_load_tensor(&m->wctx, gf, "encoder.text_projector.weight");
m->null_cond_emb = gf_load_tensor(&m->wctx, gf, "null_condition_emb");
@@ -133,8 +149,8 @@ static bool cond_ggml_load(CondGGML * m, const char * gguf_path) {
}
gf_close(&gf);
fprintf(stderr, "[Load] CondEncoder: lyric(%dL), timbre(%dL), text_proj, null_cond\n",
m->lyric_cfg.n_layers, m->timbre_cfg.n_layers);
fprintf(stderr, "[Load] CondEncoder: lyric(%dL), timbre(%dL%s), text_proj, null_cond\n", m->lyric_cfg.n_layers,
m->timbre_cfg.n_layers, m->use_timbre_cls ? ", CLS" : "");
return true;
}
@@ -159,6 +175,7 @@ static void cond_ggml_forward(CondGGML * m,
int * out_enc_S) {
int H = 2048;
bool has_timbre = (timbre_feats != NULL && S_ref > 0);
int S_timbre = has_timbre ? S_ref + (m->use_timbre_cls ? 1 : 0) : 0;
// Graph context (generous fixed allocation)
size_t ctx_size = 4096 * ggml_tensor_overhead() + ggml_graph_overhead();
@@ -215,7 +232,7 @@ static void cond_ggml_forward(CondGGML * m,
struct ggml_tensor * timbre_slide_mask = NULL;
if (has_timbre) {
timbre_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S_ref);
timbre_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S_timbre);
ggml_set_name(timbre_pos, "timbre_pos");
ggml_set_input(timbre_pos);
@@ -223,12 +240,19 @@
ggml_set_name(t_timbre_in, "timbre_in");
ggml_set_input(t_timbre_in);
// Linear embed: [64, S_ref] -> [2048, S_ref]
// Linear embed: [64, S_ref] -> [H, S_ref]
struct ggml_tensor * timbre_h = qwen3_linear_bias(ctx, m->timbre_embed_w,
m->timbre_embed_b, t_timbre_in);
// XL: prepend learned CLS token -> [H, S_ref+1]
// The CLS output at position 0 aggregates timbre across all frames.
if (m->use_timbre_cls) {
struct ggml_tensor * cls = ggml_reshape_2d(ctx, m->timbre_cls, H, 1);
timbre_h = ggml_concat(ctx, cls, timbre_h, 1);
}
// Bidirectional sliding window mask for even layers
timbre_slide_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S_ref, S_ref);
timbre_slide_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S_timbre, S_timbre);
ggml_set_name(timbre_slide_mask, "timbre_slide_mask");
ggml_set_input(timbre_slide_mask);
@@ -236,11 +260,12 @@
for (int i = 0; i < m->timbre_cfg.n_layers; i++) {
struct ggml_tensor * layer_mask = (i % 2 == 0) ? timbre_slide_mask : NULL;
timbre_h = qwen3_build_layer(ctx, m->timbre_cfg, &m->timbre_layers[i],
timbre_h, timbre_pos, layer_mask, S_ref);
timbre_h, timbre_pos, layer_mask, S_timbre);
}
timbre_h = qwen3_rms_norm(ctx, timbre_h, m->timbre_norm, m->timbre_cfg.rms_norm_eps);
// Take first frame: [2048, S_ref] -> view [2048, 1]
// Take first position: [H, S_timbre] -> view [H, 1]
// 2B: first audio frame. XL: CLS token (aggregated timbre).
timbre_out = ggml_view_2d(ctx, timbre_h, H, 1,
timbre_h->nb[1], 0);
ggml_set_name(timbre_out, "timbre_out");
@@ -282,22 +307,27 @@
if (has_timbre) {
ggml_backend_tensor_set(t_timbre_in, timbre_feats, 0, 64 * S_ref * sizeof(float));
std::vector<int> pos(S_ref);
for (int i = 0; i < S_ref; i++) pos[i] = i;
ggml_backend_tensor_set(timbre_pos, pos.data(), 0, S_ref * sizeof(int));
std::vector<int> pos(S_timbre);
for (int i = 0; i < S_timbre; i++) {
pos[i] = i;
}
ggml_backend_tensor_set(timbre_pos, pos.data(), 0, S_timbre * sizeof(int));
// Timbre sliding window mask: bidirectional, |i-j| <= 128
const int W = 128;
std::vector<uint16_t> mask_data(S_ref * S_ref);
for (int i = 0; i < S_ref; i++) {
for (int j = 0; j < S_ref; j++) {
int d = i - j; if (d < 0) d = -d;
mask_data[i * S_ref + j] = ggml_fp32_to_fp16(d <= W ? 0.0f : -INFINITY);
const int W = 128;
std::vector<uint16_t> mask_data(S_timbre * S_timbre);
for (int i = 0; i < S_timbre; i++) {
for (int j = 0; j < S_timbre; j++) {
int d = i - j;
if (d < 0) {
d = -d;
}
mask_data[i * S_timbre + j] = ggml_fp32_to_fp16(d <= W ? 0.0f : -INFINITY);
}
}
ggml_backend_tensor_set(timbre_slide_mask, mask_data.data(), 0,
S_ref * S_ref * sizeof(uint16_t));
fprintf(stderr, "[CondEnc] Timbre sliding mask: %dx%d, window=%d\n", S_ref, S_ref, W);
ggml_backend_tensor_set(timbre_slide_mask, mask_data.data(), 0, S_timbre * S_timbre * sizeof(uint16_t));
fprintf(stderr, "[CondEnc] Timbre sliding mask: %dx%d, window=%d%s\n", S_timbre, S_timbre, W,
m->use_timbre_cls ? " (CLS)" : "");
}
// Compute

View file

@@ -98,9 +98,9 @@ struct DiTGGML {
struct ggml_tensor * proj_in_w; // [in_ch*P, H] pre-permuted F32
struct ggml_tensor * proj_in_b; // [hidden]
// condition_embedder: Linear(hidden, hidden)
struct ggml_tensor * cond_emb_w; // [hidden, hidden]
struct ggml_tensor * cond_emb_b; // [hidden]
// condition_embedder: Linear(encoder_H, decoder_H)
struct ggml_tensor * cond_emb_w; // [encoder_H, decoder_H] projects encoder to decoder space
struct ggml_tensor * cond_emb_b; // [decoder_H]
// Layers
DiTGGMLLayer layers[DIT_GGML_MAX_LAYERS];
@@ -851,8 +851,11 @@ static struct ggml_cgraph * dit_ggml_build_graph(
ggml_set_input(input);
*p_input = input;
// Encoder hidden states: [H, enc_S, N]
struct ggml_tensor * enc_hidden = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, enc_S, N);
// Encoder hidden states: [H_enc, enc_S, N]
// H_enc comes from the condition_embedder input dimension (2048 for both 2B and XL).
// The condition_embedder projects H_enc -> H (decoder) via cond_emb_w.
int H_enc = (int) m->cond_emb_w->ne[0];
struct ggml_tensor * enc_hidden = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H_enc, enc_S, N);
ggml_set_name(enc_hidden, "enc_hidden");
ggml_set_input(enc_hidden);
@@ -1069,7 +1072,7 @@ static void apg_forward(
//
// noise: [N * T * Oc] N contiguous [T, Oc] noise blocks
// context_latents: [N * T * ctx_ch] N contiguous context blocks
// enc_hidden: [enc_S * H] SINGLE encoder output (shared, will be broadcast to N)
// enc_hidden: [enc_S * H_enc * N] per-batch encoder outputs (caller-stacked)
// schedule: array of num_steps timestep values
// output: [N * T * Oc] generated latents (caller-allocated)
static void dit_ggml_generate(
@@ -1095,7 +1098,6 @@ static void dit_ggml_generate(
int S = T / c.patch_size;
int n_per = T * Oc; // elements per sample
int n_total = N * n_per; // total output elements
int H = c.hidden_size;
fprintf(stderr, "[DiT] Batch N=%d, T=%d, S=%d, enc_S=%d\n", N, T, S, enc_S);
@@ -1118,6 +1120,7 @@
fprintf(stderr, "[DiT] Graph: %d nodes\n", ggml_graph_n_nodes(gf));
struct ggml_tensor * t_enc = ggml_graph_get_tensor(gf, "enc_hidden");
int H_enc = (int) t_enc->ne[0]; // encoder hidden size (from condition_embedder)
// Allocate compute buffers.
// Critical: reset FIRST (clears old state), THEN force inputs to GPU, THEN alloc.
@@ -1196,17 +1199,19 @@
ggml_backend_tensor_get(model->null_condition_emb, null_emb.data(), 0, emb_n * sizeof(float));
}
// Broadcast [H] to [enc_S, H] then to N copies [H, enc_S, N]
std::vector<float> null_enc_single(H * enc_S);
for (int s = 0; s < enc_S; s++)
memcpy(&null_enc_single[s * H], null_emb.data(), H * sizeof(float));
null_enc_buf.resize(H * enc_S * N);
for (int b = 0; b < N; b++)
memcpy(null_enc_buf.data() + b * enc_S * H, null_enc_single.data(), enc_S * H * sizeof(float));
// Broadcast [H_enc] to [enc_S, H_enc] then to N copies [H_enc, enc_S, N]
std::vector<float> null_enc_single(H_enc * enc_S);
for (int s = 0; s < enc_S; s++) {
memcpy(&null_enc_single[s * H_enc], null_emb.data(), H_enc * sizeof(float));
}
null_enc_buf.resize(H_enc * enc_S * N);
for (int b = 0; b < N; b++) {
memcpy(null_enc_buf.data() + b * enc_S * H_enc, null_enc_single.data(), enc_S * H_enc * sizeof(float));
}
if (dbg && dbg->enabled) {
debug_dump_1d(dbg, "null_condition_emb", null_emb.data(), emb_n);
debug_dump_2d(dbg, "null_enc_hidden", null_enc_single.data(), enc_S, H);
debug_dump_2d(dbg, "null_enc_hidden", null_enc_single.data(), enc_S, H_enc);
}
apg_mbufs.resize(N);
@@ -1235,9 +1240,9 @@
ctx_ch * sizeof(float));
// Pre-allocate enc_buf once (avoids heap alloc per step)
std::vector<float> enc_buf(H * enc_S * N);
std::vector<float> enc_buf(H_enc * enc_S * N);
for (int b = 0; b < N; b++)
memcpy(enc_buf.data() + b * enc_S * H, enc_hidden_data, enc_S * H * sizeof(float));
memcpy(enc_buf.data() + b * enc_S * H_enc, enc_hidden_data, enc_S * H_enc * sizeof(float));
ggml_backend_tensor_set(t_enc, enc_buf.data(), 0, enc_buf.size() * sizeof(float));
struct ggml_tensor * t_t = ggml_graph_get_tensor(gf, "t");
@@ -1340,7 +1345,7 @@ static void dit_ggml_generate(
}
// Unconditional pass: re-upload all inputs (scheduler clobbers input buffers during compute)
ggml_backend_tensor_set(t_enc, null_enc_buf.data(), 0, H * enc_S * N * sizeof(float));
ggml_backend_tensor_set(t_enc, null_enc_buf.data(), 0, H_enc * enc_S * N * sizeof(float));
ggml_backend_tensor_set(t_input, input_buf.data(), 0, in_ch * T * N * sizeof(float));
if (t_t) ggml_backend_tensor_set(t_t, &t_curr, 0, sizeof(float));
if (t_tr) ggml_backend_tensor_set(t_tr, &t_curr, 0, sizeof(float));