koboldcpp/otherarch/acestep/dit.h

1464 lines
61 KiB
C++

#pragma once
// dit.h: ACE-Step DiT (Diffusion Transformer) via ggml compute graph
// Ported from Python ACE-Step-1.5 reference. Same weights, loaded from GGUF.
//
// Architecture: 24-layer transformer with AdaLN, GQA self-attn + cross-attn, SwiGLU MLP.
// Flow matching: 8 Euler steps (turbo schedule).
//
// ggml ops used: rms_norm, mul_mat, rope_ext, flash_attn_ext, swiglu_split,
// conv_transpose_1d, add, mul, scale, view, reshape, permute.
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-alloc.h"
#include "gguf_weights.h"
#include "backend.h"
#include "debug.h"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
// DiT hyper-parameters. Every field is filled from GGUF metadata by dit_ggml_load,
// which rejects the file if any of them is missing.
struct DiTGGMLConfig {
    // transformer geometry
    int hidden_size, intermediate_size;
    int n_heads, n_kv_heads, head_dim;
    int n_layers;
    // latent I/O and patching
    int in_channels, out_channels, patch_size;
    // attention / normalisation
    int sliding_window;
    float rope_theta, rms_norm_eps;
};
// Layer weights.
// Timestep-embedding weights: linear_1 -> SiLU -> linear_2 gives the embedding,
// and time_proj turns SiLU(embedding) into the 6*hidden AdaLN terms.
struct DiTGGMLTembWeights {
    struct ggml_tensor * linear_1_w, * linear_1_b;    // [256, hidden], [hidden]
    struct ggml_tensor * linear_2_w, * linear_2_b;    // [hidden, hidden], [hidden]
    struct ggml_tensor * time_proj_w, * time_proj_b;  // [hidden, 6*hidden], [6*hidden]
};
struct DiTGGMLLayer {
    // ---- Self-attention ----
    // The loader fills one projection layout: sa_qkv (Q|K|V fused), or
    // sa_qk (Q|K fused) + sa_v_proj, or separate sa_q_proj / sa_k_proj / sa_v_proj.
    struct ggml_tensor * self_attn_norm;                        // [hidden]
    struct ggml_tensor * sa_qkv;                                // [hidden, (Nh+2*Nkv)*D] or NULL
    struct ggml_tensor * sa_qk;                                 // [hidden, (Nh+Nkv)*D] or NULL
    struct ggml_tensor * sa_q_proj, * sa_k_proj, * sa_v_proj;   // unfused fallback
    struct ggml_tensor * sa_q_norm, * sa_k_norm;                // [head_dim] each
    struct ggml_tensor * sa_o_proj;                             // [n_heads*head_dim, hidden]

    // ---- Cross-attention ----
    // ca_qkv (Q|K|V fused), or ca_q_proj + ca_kv (K|V fused),
    // or separate ca_q_proj / ca_k_proj / ca_v_proj.
    struct ggml_tensor * cross_attn_norm;                       // [hidden]
    struct ggml_tensor * ca_qkv;                                // [hidden, (Nh+2*Nkv)*D] or NULL
    struct ggml_tensor * ca_q_proj;                             // used whenever ca_qkv is NULL
    struct ggml_tensor * ca_kv;                                 // [hidden, 2*Nkv*D] or NULL
    struct ggml_tensor * ca_k_proj, * ca_v_proj;                // unfused fallback
    struct ggml_tensor * ca_q_norm, * ca_k_norm;                // [head_dim] each
    struct ggml_tensor * ca_o_proj;                             // [n_heads*head_dim, hidden]

    // ---- SwiGLU MLP ----
    // gate_up (fused) or gate_proj + up_proj when the two could not be fused.
    struct ggml_tensor * mlp_norm;                              // [hidden]
    struct ggml_tensor * gate_up;                               // [hidden, 2*intermediate] or NULL
    struct ggml_tensor * gate_proj, * up_proj;                  // [hidden, intermediate] each
    struct ggml_tensor * down_proj;                             // [intermediate, hidden]

    // ---- AdaLN ----
    struct ggml_tensor * scale_shift_table;   // [hidden, 6] in ggml layout (6 rows of [hidden])
    int layer_type;                           // 0 = sliding-window attention, 1 = full attention
};
// Full model: config, all weights, backend handles and a shared scalar constant.
#define DIT_GGML_MAX_LAYERS 32
struct DiTGGML {
    DiTGGMLConfig cfg;

    // Timestep embeddings for t and t_r
    DiTGGMLTembWeights time_embed, time_embed_r;

    // proj_in: Conv1d(in_channels, hidden, kernel=2, stride=2).
    // The weight is pre-permuted at load time into 2D [in_ch*P, H] F32.
    struct ggml_tensor * proj_in_w, * proj_in_b;       // [in_ch*P, H], [hidden]

    // condition_embedder: Linear(encoder_H, decoder_H), projects encoder states to decoder space
    struct ggml_tensor * cond_emb_w, * cond_emb_b;     // [encoder_H, decoder_H], [decoder_H]

    DiTGGMLLayer layers[DIT_GGML_MAX_LAYERS];

    // Output head
    struct ggml_tensor * norm_out;          // [hidden]
    struct ggml_tensor * out_scale_shift;   // [hidden, 2] in ggml layout
    struct ggml_tensor * proj_out_w;        // [H, out_ch*P] pre-permuted+transposed F32
    struct ggml_tensor * proj_out_b;        // [out_channels]

    // Classifier-free guidance (used by base/sft models): [hidden], or NULL if not present
    struct ggml_tensor * null_condition_emb;

    // Backend handles
    ggml_backend_t backend, cpu_backend;
    ggml_backend_sched_t sched;
    bool use_flash_attn;

    // Weight storage
    WeightCtx wctx;

    // [1] = 1.0f, broadcast by ggml_add to form AdaLN's (1 + scale)
    struct ggml_tensor * scalar_one;
};
// Load the timestep-embedding weights stored under `prefix` (e.g. "decoder.time_embed").
// Weight matrices go through gf_load_tensor and biases through gf_load_tensor_f32.
// Load order is linear_1, linear_2, then time_proj (weight before bias each time).
static void dit_ggml_load_temb(DiTGGMLTembWeights * w, WeightCtx * wctx,
                               const GGUFModel & gf, const std::string & prefix) {
    auto matrix = [&](const char * leaf) { return gf_load_tensor(wctx, gf, prefix + leaf); };
    auto bias   = [&](const char * leaf) { return gf_load_tensor_f32(wctx, gf, prefix + leaf); };

    w->linear_1_w  = matrix(".linear_1.weight");
    w->linear_1_b  = bias(".linear_1.bias");
    w->linear_2_w  = matrix(".linear_2.weight");
    w->linear_2_b  = bias(".linear_2.bias");
    w->time_proj_w = matrix(".time_proj.weight");
    w->time_proj_b = bias(".time_proj.bias");
}
// Load proj_in weight: GGUF [H, in_ch, P] -> pre-permuted 2D [in_ch*P, H] F32.
// Eliminates runtime permute+cont in the compute graph.
//
// The source tensor may be BF16, F16 or F32. These cases are fatal (exit(1)):
//   - the tensor is missing;
//   - its type is something else;
//   - its element count differs from H*in_ch*P.
// The element-count check keeps the raw reads from the file mapping in bounds.
// The converted data lives in wctx->staging and is queued in wctx->pending
// for upload.
static struct ggml_tensor * dit_load_proj_in_w(
    WeightCtx * wctx, const GGUFModel & gf, const std::string & name,
    int H, int in_ch, int P) {
    int64_t idx = gguf_find_tensor(gf.gguf, name.c_str());
    if (idx < 0) {
        fprintf(stderr, "[GGUF] FATAL: tensor '%s' not found\n", name.c_str());
        exit(1);
    }
    struct ggml_tensor * src = ggml_get_tensor(gf.meta, name.c_str());
    if (!src) {
        fprintf(stderr, "[GGUF] FATAL: meta tensor '%s' not found\n", name.c_str());
        exit(1);
    }
    const size_t n = (size_t)in_ch * P * H;
    // The loops below read n elements straight from the mapping; refuse a tensor
    // whose size disagrees with the config instead of reading out of bounds.
    if (ggml_nelements(src) != (int64_t)n) {
        fprintf(stderr, "[GGUF] FATAL: '%s' has %lld elements, expected %zu (H=%d, in_ch=%d, P=%d)\n",
                name.c_str(), (long long)ggml_nelements(src), n, H, in_ch, P);
        exit(1);
    }
    size_t offset = gguf_get_tensor_offset(gf.gguf, idx);
    const void * raw = gf.mapping + gf.data_offset + offset;
    struct ggml_tensor * dst = ggml_new_tensor_2d(wctx->ctx, GGML_TYPE_F32, in_ch * P, H);
    ggml_set_name(dst, name.c_str());
    wctx->staging.emplace_back(n);
    auto & buf = wctx->staging.back();
    // src ggml [P, in_ch, H]: elem(p, ic, h) = raw[h*P*in_ch + ic*P + p]
    // dst ggml [in_ch*P, H]:  elem(j, h)     = buf[h*in_ch*P + j] where j = p*in_ch + ic
    // Index math is done in size_t so large dimensions cannot overflow int.
    auto cvt = [&](auto read_fn) {
        const size_t nh = (size_t)H, nic = (size_t)in_ch, np = (size_t)P;
        const size_t row = nic * np;
        for (size_t h = 0; h < nh; h++)
            for (size_t ic = 0; ic < nic; ic++)
                for (size_t p = 0; p < np; p++)
                    buf[h*row + p*nic + ic] = read_fn(h*row + ic*np + p);
    };
    if (src->type == GGML_TYPE_BF16) {
        // Read through ggml_bf16_t directly instead of type-punning a uint16_t*.
        const ggml_bf16_t * s = (const ggml_bf16_t *)raw;
        cvt([&](size_t i) { return ggml_bf16_to_fp32(s[i]); });
    } else if (src->type == GGML_TYPE_F16) {
        const ggml_fp16_t * s = (const ggml_fp16_t *)raw;
        cvt([&](size_t i) { return ggml_fp16_to_fp32(s[i]); });
    } else if (src->type == GGML_TYPE_F32) {
        const float * s = (const float *)raw;
        cvt([&](size_t i) { return s[i]; });
    } else {
        fprintf(stderr, "[GGUF] FATAL: unsupported type %d for '%s' in proj_in pre-permute\n",
                src->type, name.c_str());
        exit(1);
    }
    wctx->pending.push_back({dst, buf.data(), n * sizeof(float), 0});
    return dst;
}
// Load proj_out weight: GGUF [H, out_ch, P] -> pre-permuted+transposed 2D [H, out_ch*P] F32.
// Eliminates runtime permute+cont+transpose+cont in the compute graph.
//
// The source tensor may be BF16, F16 or F32. These cases are fatal (exit(1)):
//   - the tensor is missing;
//   - its type is something else;
//   - its element count differs from H*out_ch*P.
// The element-count check keeps the raw reads from the file mapping in bounds.
// The converted data lives in wctx->staging and is queued in wctx->pending
// for upload.
static struct ggml_tensor * dit_load_proj_out_w(
    WeightCtx * wctx, const GGUFModel & gf, const std::string & name,
    int H, int out_ch, int P) {
    int64_t idx = gguf_find_tensor(gf.gguf, name.c_str());
    if (idx < 0) {
        fprintf(stderr, "[GGUF] FATAL: tensor '%s' not found\n", name.c_str());
        exit(1);
    }
    struct ggml_tensor * src = ggml_get_tensor(gf.meta, name.c_str());
    if (!src) {
        fprintf(stderr, "[GGUF] FATAL: meta tensor '%s' not found\n", name.c_str());
        exit(1);
    }
    const size_t n = (size_t)out_ch * P * H;
    // The loops below read n elements straight from the mapping; refuse a tensor
    // whose size disagrees with the config instead of reading out of bounds.
    if (ggml_nelements(src) != (int64_t)n) {
        fprintf(stderr, "[GGUF] FATAL: '%s' has %lld elements, expected %zu (H=%d, out_ch=%d, P=%d)\n",
                name.c_str(), (long long)ggml_nelements(src), n, H, out_ch, P);
        exit(1);
    }
    size_t offset = gguf_get_tensor_offset(gf.gguf, idx);
    const void * raw = gf.mapping + gf.data_offset + offset;
    struct ggml_tensor * dst = ggml_new_tensor_2d(wctx->ctx, GGML_TYPE_F32, H, out_ch * P);
    ggml_set_name(dst, name.c_str());
    wctx->staging.emplace_back(n);
    auto & buf = wctx->staging.back();
    // src ggml [P, out_ch, H]: elem(p, oc, h) = raw[h*P*out_ch + oc*P + p]
    // dst ggml [H, out_ch*P]:  elem(h, j)     = buf[j*H + h] where j = p*out_ch + oc
    // Index math is done in size_t so large dimensions cannot overflow int.
    auto cvt = [&](auto read_fn) {
        const size_t nh = (size_t)H, noc = (size_t)out_ch, np = (size_t)P;
        for (size_t h = 0; h < nh; h++)
            for (size_t oc = 0; oc < noc; oc++)
                for (size_t p = 0; p < np; p++)
                    buf[(p*noc + oc)*nh + h] = read_fn(h*np*noc + oc*np + p);
    };
    if (src->type == GGML_TYPE_BF16) {
        // Read through ggml_bf16_t directly instead of type-punning a uint16_t*.
        const ggml_bf16_t * s = (const ggml_bf16_t *)raw;
        cvt([&](size_t i) { return ggml_bf16_to_fp32(s[i]); });
    } else if (src->type == GGML_TYPE_F16) {
        const ggml_fp16_t * s = (const ggml_fp16_t *)raw;
        cvt([&](size_t i) { return ggml_fp16_to_fp32(s[i]); });
    } else if (src->type == GGML_TYPE_F32) {
        const float * s = (const float *)raw;
        cvt([&](size_t i) { return s[i]; });
    } else {
        fprintf(stderr, "[GGUF] FATAL: unsupported type %d for '%s' in proj_out pre-permute\n",
                src->type, name.c_str());
        exit(1);
    }
    wctx->pending.push_back({dst, buf.data(), n * sizeof(float), 0});
    return dst;
}
// Load the full DiT model from a GGUF file.
//   m         : model to populate. wctx_alloc() below uploads into m->backend, so the
//               backend must already be set up (presumably via dit_ggml_init_backend;
//               confirm at the call site).
//   gguf_path : path of the GGUF file.
//   cfg       : [out] config read from GGUF metadata; also copied into m->cfg.
// Returns false in any of these cases:
//   - the file cannot be opened;
//   - the metadata is incomplete or out of range, including more than
//     DIT_GGML_MAX_LAYERS layers;
//   - the weight buffer cannot be allocated.
static bool dit_ggml_load(DiTGGML * m, const char * gguf_path, DiTGGMLConfig & cfg) {
    GGUFModel gf;
    if (!gf_load(&gf, gguf_path)) {
        fprintf(stderr, "[Load] FATAL: cannot load %s\n", gguf_path);
        return false;
    }
    // config from GGUF metadata (all keys required)
    cfg.n_layers          = (int) gf_get_u32(gf, "acestep-dit.block_count");
    cfg.hidden_size       = (int) gf_get_u32(gf, "acestep-dit.embedding_length");
    cfg.intermediate_size = (int) gf_get_u32(gf, "acestep-dit.feed_forward_length");
    cfg.n_heads           = (int) gf_get_u32(gf, "acestep-dit.attention.head_count");
    cfg.n_kv_heads        = (int) gf_get_u32(gf, "acestep-dit.attention.head_count_kv");
    cfg.head_dim          = (int) gf_get_u32(gf, "acestep-dit.attention.key_length");
    cfg.in_channels       = (int) gf_get_u32(gf, "acestep.in_channels");
    cfg.out_channels      = (int) gf_get_u32(gf, "acestep.audio_acoustic_hidden_dim");
    cfg.patch_size        = (int) gf_get_u32(gf, "acestep.patch_size");
    cfg.sliding_window    = (int) gf_get_u32(gf, "acestep.sliding_window");
    cfg.rope_theta        = gf_get_f32(gf, "acestep-dit.rope.freq_base");
    cfg.rms_norm_eps      = gf_get_f32(gf, "acestep-dit.attention.layer_norm_rms_epsilon");
    // The u32 values are cast to int, so reject <= 0 (missing key, or a value that
    // wrapped negative) rather than only zero. !(x > 0) also rejects NaN floats.
    if (cfg.n_layers <= 0 || cfg.hidden_size <= 0 || cfg.intermediate_size <= 0 ||
        cfg.n_heads <= 0 || cfg.n_kv_heads <= 0 || cfg.head_dim <= 0 ||
        cfg.in_channels <= 0 || cfg.out_channels <= 0 || cfg.patch_size <= 0 ||
        cfg.sliding_window <= 0 ||
        !(cfg.rope_theta > 0.0f) || !(cfg.rms_norm_eps > 0.0f)) {
        fprintf(stderr, "[Load] FATAL: incomplete DiT config in GGUF\n");
        gf_close(&gf);
        return false;
    }
    // m->layers is a fixed-size array; a larger block_count would write past its end.
    if (cfg.n_layers > DIT_GGML_MAX_LAYERS) {
        fprintf(stderr, "[Load] FATAL: DiT block_count %d exceeds DIT_GGML_MAX_LAYERS (%d)\n",
                cfg.n_layers, DIT_GGML_MAX_LAYERS);
        gf_close(&gf);
        return false;
    }
    m->cfg = cfg;
    // tensor count: temb(6*2) + proj_in(2) + cond_emb(2) + layers(19*N) + output(4) + null_cond(1) + scalar_one(1)
    // (19 per layer is the all-separate worst case; fused layouts use fewer)
    int n_tensors = 6 * 2 + 2 + 2 + 19 * cfg.n_layers + 4 + 1 + 1;
    wctx_init(&m->wctx, n_tensors);
    // Timestep embeddings
    dit_ggml_load_temb(&m->time_embed, &m->wctx, gf, "decoder.time_embed");
    dit_ggml_load_temb(&m->time_embed_r, &m->wctx, gf, "decoder.time_embed_r");
    // proj_in: Conv1d weight [hidden, in_ch, patch_size]
    // Pre-permuted to 2D [in_ch*P, H] F32 at load time
    m->proj_in_w = dit_load_proj_in_w(&m->wctx, gf, "decoder.proj_in.1.weight",
                                      cfg.hidden_size, cfg.in_channels, cfg.patch_size);
    m->proj_in_b = gf_load_tensor_f32(&m->wctx, gf, "decoder.proj_in.1.bias");
    // condition_embedder
    m->cond_emb_w = gf_load_tensor(&m->wctx, gf, "decoder.condition_embedder.weight");
    m->cond_emb_b = gf_load_tensor_f32(&m->wctx, gf, "decoder.condition_embedder.bias");
    // Layers
    for (int i = 0; i < cfg.n_layers; i++) {
        char prefix[128];
        snprintf(prefix, sizeof(prefix), "decoder.layers.%d", i);
        std::string p(prefix);
        DiTGGMLLayer & ly = m->layers[i];
        // Start from an all-NULL layer so the projection pointers of the layouts
        // not chosen below are NULL, as DiTGGMLLayer documents, regardless of
        // how *m was allocated.
        ly = DiTGGMLLayer();
        // Self-attention: try full QKV, partial QK, separate
        ly.self_attn_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".self_attn_norm.weight");
        ly.sa_qkv = gf_load_qkv_fused(&m->wctx, gf,
                                      p + ".self_attn.q_proj.weight",
                                      p + ".self_attn.k_proj.weight",
                                      p + ".self_attn.v_proj.weight");
        if (!ly.sa_qkv) {
            // Try Q+K fusion (same input, often same type in K-quants)
            ly.sa_qk = gf_load_pair_fused(&m->wctx, gf,
                                          p + ".self_attn.q_proj.weight",
                                          p + ".self_attn.k_proj.weight");
            if (ly.sa_qk) {
                ly.sa_v_proj = gf_load_tensor(&m->wctx, gf, p + ".self_attn.v_proj.weight");
                if (i == 0) fprintf(stderr, "[DiT] Self-attn: Q+K fused, V separate\n");
            } else {
                ly.sa_q_proj = gf_load_tensor(&m->wctx, gf, p + ".self_attn.q_proj.weight");
                ly.sa_k_proj = gf_load_tensor(&m->wctx, gf, p + ".self_attn.k_proj.weight");
                ly.sa_v_proj = gf_load_tensor(&m->wctx, gf, p + ".self_attn.v_proj.weight");
                if (i == 0) fprintf(stderr, "[DiT] Self-attn: all separate (3 types differ)\n");
            }
        } else {
            if (i == 0) fprintf(stderr, "[DiT] Self-attn: Q+K+V fused\n");
        }
        ly.sa_q_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".self_attn.q_norm.weight");
        ly.sa_k_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".self_attn.k_norm.weight");
        ly.sa_o_proj = gf_load_tensor(&m->wctx, gf, p + ".self_attn.o_proj.weight");
        // Cross-attention: try full QKV, K+V fused, separate
        ly.cross_attn_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".cross_attn_norm.weight");
        ly.ca_qkv = gf_load_qkv_fused(&m->wctx, gf,
                                      p + ".cross_attn.q_proj.weight",
                                      p + ".cross_attn.k_proj.weight",
                                      p + ".cross_attn.v_proj.weight");
        if (!ly.ca_qkv) {
            ly.ca_q_proj = gf_load_tensor(&m->wctx, gf, p + ".cross_attn.q_proj.weight");
            // Try K+V fusion (same input enc, may share type)
            ly.ca_kv = gf_load_pair_fused(&m->wctx, gf,
                                          p + ".cross_attn.k_proj.weight",
                                          p + ".cross_attn.v_proj.weight");
            if (ly.ca_kv) {
                if (i == 0) fprintf(stderr, "[DiT] Cross-attn: Q separate, K+V fused\n");
            } else {
                ly.ca_k_proj = gf_load_tensor(&m->wctx, gf, p + ".cross_attn.k_proj.weight");
                ly.ca_v_proj = gf_load_tensor(&m->wctx, gf, p + ".cross_attn.v_proj.weight");
                if (i == 0) fprintf(stderr, "[DiT] Cross-attn: all separate\n");
            }
        } else {
            if (i == 0) fprintf(stderr, "[DiT] Cross-attn: Q+K+V fused\n");
        }
        ly.ca_q_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".cross_attn.q_norm.weight");
        ly.ca_k_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".cross_attn.k_norm.weight");
        ly.ca_o_proj = gf_load_tensor(&m->wctx, gf, p + ".cross_attn.o_proj.weight");
        // MLP: try gate+up fusion (same input, same pattern as QKV)
        ly.mlp_norm = gf_load_tensor_f32(&m->wctx, gf, p + ".mlp_norm.weight");
        ly.gate_up = gf_load_pair_fused(&m->wctx, gf,
                                        p + ".mlp.gate_proj.weight",
                                        p + ".mlp.up_proj.weight");
        if (ly.gate_up) {
            if (i == 0) fprintf(stderr, "[DiT] MLP: gate+up fused\n");
        } else {
            ly.gate_proj = gf_load_tensor(&m->wctx, gf, p + ".mlp.gate_proj.weight");
            ly.up_proj   = gf_load_tensor(&m->wctx, gf, p + ".mlp.up_proj.weight");
            if (i == 0) fprintf(stderr, "[DiT] MLP: gate+up separate (types differ)\n");
        }
        ly.down_proj = gf_load_tensor(&m->wctx, gf, p + ".mlp.down_proj.weight");
        // AdaLN scale_shift_table [1, 6, hidden] in GGUF
        ly.scale_shift_table = gf_load_tensor_f32(&m->wctx, gf, p + ".scale_shift_table");
        ly.layer_type = (i % 2 == 0) ? 0 : 1; // 0=sliding, 1=full
    }
    // Output
    m->norm_out = gf_load_tensor_f32(&m->wctx, gf, "decoder.norm_out.weight");
    m->out_scale_shift = gf_load_tensor_f32(&m->wctx, gf, "decoder.scale_shift_table");
    m->proj_out_w = dit_load_proj_out_w(&m->wctx, gf, "decoder.proj_out.1.weight",
                                        cfg.hidden_size, cfg.out_channels, cfg.patch_size);
    m->proj_out_b = gf_load_tensor_f32(&m->wctx, gf, "decoder.proj_out.1.bias");
    // Null condition embedding for CFG (base/sft models; turbo has it but unused at inference)
    m->null_condition_emb = gf_try_load_tensor(&m->wctx, gf, "null_condition_emb");
    if (m->null_condition_emb) {
        fprintf(stderr, "[Load] null_condition_emb found (CFG available)\n");
    }
    // Scalar constant for AdaLN (1+scale) fusion. Static storage keeps the
    // source pointer valid until wctx_alloc copies it.
    static const float one_val = 1.0f;
    m->scalar_one = ggml_new_tensor_1d(m->wctx.ctx, GGML_TYPE_F32, 1);
    m->wctx.pending.push_back({m->scalar_one, &one_val, sizeof(float), 0});
    // Allocate backend buffer and copy weights
    if (!wctx_alloc(&m->wctx, m->backend)) {
        gf_close(&gf);
        return false;
    }
    gf_close(&gf);
    fprintf(stderr, "[Load] DiT: %d layers, H=%d, Nh=%d/%d, D=%d\n",
            cfg.n_layers, cfg.hidden_size, cfg.n_heads, cfg.n_kv_heads, cfg.head_dim);
    return true;
}
// Backend init: create the compute backend, its CPU companion, and a graph
// scheduler (sized for 8192 graph nodes).
static void dit_ggml_init_backend(DiTGGML * m) {
    BackendPair pair = backend_init("DiT");
    m->backend     = pair.backend;
    m->cpu_backend = pair.cpu_backend;
    m->sched       = backend_sched_new(pair, 8192);

    // Flash attention is disabled on every backend, so the graph builders use
    // the F32 manual attention path (dit_attn_f32):
    //  - on CPU, flash_attn_ext accumulates in F16, causing audible drift over
    //    24 layers x 8 steps;
    //  - kcpp: flash attn for music is unstable on Vulkan.
    // (It used to be GPU-only: use_flash_attn = (backend != cpu_backend).)
    m->use_flash_attn = false;
}
// Graph builder: single DiT layer (self-attention block)
// Incremental approach: build and validate one block at a time.
//
// ggml tensor layout reminder:
// [S, H] in math = ne[0]=H, ne[1]=S in ggml
// [Nh, S, D] in math = ne[0]=D, ne[1]=S, ne[2]=Nh in ggml
// Helper: return `t` as F32. An F32 input is returned unchanged (no graph node);
// any other type gets a ggml_cast node.
static struct ggml_tensor * dit_ggml_f32(
    struct ggml_context * ctx,
    struct ggml_tensor * t) {
    return t->type == GGML_TYPE_F32 ? t : ggml_cast(ctx, t, GGML_TYPE_F32);
}
// Helper: RMSNorm over ne[0], then multiply by a per-channel gain.
//   x:      [H, S] (extra dims are carried through)
//   weight: [H], cast to F32 if stored in another type
static struct ggml_tensor * dit_ggml_rms_norm_weighted(
    struct ggml_context * ctx,
    struct ggml_tensor * x,
    struct ggml_tensor * weight,
    float eps) {
    struct ggml_tensor * normed = ggml_rms_norm(ctx, x, eps);
    struct ggml_tensor * gain   = dit_ggml_f32(ctx, weight);
    return ggml_mul(ctx, normed, gain);
}
// Helper: bias-free linear layer.
//   weight: ggml [in, out]  (PyTorch [out, in])
//   input:  [in, S]
//   result: [out, S]
static struct ggml_tensor * dit_ggml_linear(
    struct ggml_context * ctx,
    struct ggml_tensor * weight,
    struct ggml_tensor * input) {
    struct ggml_tensor * projected = ggml_mul_mat(ctx, weight, input);
    return projected;
}
// Helper: Linear layer with bias
static struct ggml_tensor * dit_ggml_linear_bias(
struct ggml_context * ctx,
struct ggml_tensor * weight,
struct ggml_tensor * bias,
struct ggml_tensor * input) {
struct ggml_tensor * out = ggml_mul_mat(ctx, weight, input);
return ggml_add(ctx, out, dit_ggml_f32(ctx, bias));
}
// Helper: AdaLN modulation, out = norm * (1 + scale) + shift.
//   norm: [H, S, N]; scale, shift: [H]; one: [1] tensor holding 1.0f
// The "+1" is applied to the [H] scale vector (one broadcasts to [H]) rather
// than as an extra full-size [H, S, N] add.
static struct ggml_tensor * dit_ggml_adaln(
    struct ggml_context * ctx,
    struct ggml_tensor * norm,
    struct ggml_tensor * scale,
    struct ggml_tensor * shift,
    struct ggml_tensor * one) {
    struct ggml_tensor * gain      = ggml_add(ctx, scale, one);
    struct ggml_tensor * modulated = ggml_mul(ctx, norm, gain);
    return ggml_add(ctx, modulated, shift);
}
// Helper: gated residual, out = residual + x * gate.
//   residual, x: [H, S]; gate: [H], broadcast over S
// The gate is a raw scale factor with no sigmoid, matching the Python reference.
static struct ggml_tensor * dit_ggml_gated_add(
    struct ggml_context * ctx,
    struct ggml_tensor * residual,
    struct ggml_tensor * x,
    struct ggml_tensor * gate) {
    return ggml_add(ctx, residual, ggml_mul(ctx, x, gate));
}
// Build the timestep-embedding subgraph.
//   t_scalar:  [1] f32 timestep
//   out_tproj: [out] receives time_proj(SiLU(temb)), shape [6H]
//   suffix:    appended to debug tensor names (e.g. "_t" or "_r")
// Returns temb [H] (used for the output AdaLN).
// Debug outputs: "sinusoidal<suffix>" and "temb_lin1<suffix>".
static struct ggml_tensor * dit_ggml_build_temb(
    struct ggml_context * ctx,
    DiTGGMLTembWeights * w,
    struct ggml_tensor * t_scalar,
    struct ggml_tensor ** out_tproj,
    const char * suffix = "") {
    char label[64];

    // Scale by 1000 (diffusion convention, matches Python), then expand to
    // 256 sinusoidal features with max period 10000.
    struct ggml_tensor * t_x1000 = ggml_scale(ctx, t_scalar, 1000.0f);
    struct ggml_tensor * freqs = ggml_timestep_embedding(ctx, t_x1000, 256, 10000);
    snprintf(label, sizeof(label), "sinusoidal%s", suffix);
    ggml_set_name(freqs, label);
    ggml_set_output(freqs);

    // [256] -> [H]
    struct ggml_tensor * lin1 = dit_ggml_linear_bias(ctx, w->linear_1_w, w->linear_1_b, freqs);
    snprintf(label, sizeof(label), "temb_lin1%s", suffix);
    ggml_set_name(lin1, label);
    ggml_set_output(lin1);

    // SiLU -> linear_2: [H] -> [H]
    struct ggml_tensor * temb = dit_ggml_linear_bias(ctx, w->linear_2_w, w->linear_2_b,
                                                     ggml_silu(ctx, lin1));
    // SiLU -> time_proj: [H] -> [6H]
    *out_tproj = dit_ggml_linear_bias(ctx, w->time_proj_w, w->time_proj_b,
                                      ggml_silu(ctx, temb));
    return temb;
}
// F32 manual attention: softmax(scale * K^T Q + mask) applied to V.
// Used when flash_attn_ext is disabled or imprecise.
//   q: [D, S, Nh]; k, v: [D, S_kv, Nkv]
//   mask: [S_kv, S] F16 or NULL; scale: 1/sqrt(D)
// With Nh > Nkv, this relies on ggml_mul_mat broadcasting the K/V heads
// along dim 2.
// Returns [D, Nh, S], the same layout as flash_attn_ext output.
static struct ggml_tensor * dit_attn_f32(
    struct ggml_context * ctx,
    struct ggml_tensor * q,
    struct ggml_tensor * k,
    struct ggml_tensor * v,
    struct ggml_tensor * mask,
    float scale) {
    struct ggml_tensor * logits = ggml_mul_mat(ctx, k, q);                    // [S_kv, S, Nh]
    struct ggml_tensor * probs  = ggml_soft_max_ext(ctx, logits, mask, scale, 0.0f);
    // V^T must be contiguous to serve as the left mul_mat operand.
    struct ggml_tensor * v_t    = ggml_cont(ctx, ggml_transpose(ctx, v));     // [S_kv, D, Nkv]
    struct ggml_tensor * mixed  = ggml_mul_mat(ctx, v_t, probs);              // [D, S, Nh]
    struct ggml_tensor * heads_first = ggml_permute(ctx, mixed, 0, 2, 1, 3);  // [D, Nh, S]
    return ggml_cont(ctx, heads_first);
}
// Build self-attention sub-graph for a single layer.
// Steps:
//   1. QKV projection
//   2. split heads
//   3. per-head QK RMSNorm
//   4. NEOX RoPE
//   5. permute
//   6. attention (flash_attn_ext or dit_attn_f32)
//   7. O projection
// norm_sa:   [H, S, N] pre-normalized + AdaLN-modulated hidden state
// positions: [S*N] int32 RoPE positions (see step 4)
// mask:      forwarded unchanged to the attention op; NULL means unmasked
// layer_idx: when 0, q/k after RoPE and the attention output are named and
//            marked as graph outputs (debug taps)
// Returns: output [H, S, N] (self-attention output, NOT added to residual yet)
static struct ggml_tensor * dit_ggml_build_self_attn(
struct ggml_context * ctx,
DiTGGML * m,
DiTGGMLLayer * ly,
struct ggml_tensor * norm_sa, // [H, S, N] pre-normalized + AdaLN-modulated
struct ggml_tensor * positions, // [S*N] int32 position indices for RoPE
struct ggml_tensor * mask, // [S, S] or NULL (sliding window mask)
int S, int N, int layer_idx = -1) {
DiTGGMLConfig & c = m->cfg;
int D = c.head_dim;
int Nh = c.n_heads;
int Nkv = c.n_kv_heads;
// 1) QKV projections (full fused, QK partial, separate).
// Fused results are split with strided views along ne[0] and made contiguous
// with ggml_cont, because the reshapes in step 2 require contiguous tensors.
struct ggml_tensor * q, * k, * v;
int q_dim = Nh * D;
int kv_dim = Nkv * D;
if (ly->sa_qkv) {
struct ggml_tensor * qkv = dit_ggml_linear(ctx, ly->sa_qkv, norm_sa);
q = ggml_cont(ctx, ggml_view_3d(ctx, qkv, q_dim, S, N, qkv->nb[1], qkv->nb[2], 0));
k = ggml_cont(ctx, ggml_view_3d(ctx, qkv, kv_dim, S, N, qkv->nb[1], qkv->nb[2], (size_t)q_dim * qkv->nb[0]));
v = ggml_cont(ctx, ggml_view_3d(ctx, qkv, kv_dim, S, N, qkv->nb[1], qkv->nb[2], (size_t)(q_dim + kv_dim) * qkv->nb[0]));
} else if (ly->sa_qk) {
struct ggml_tensor * qk = dit_ggml_linear(ctx, ly->sa_qk, norm_sa);
q = ggml_cont(ctx, ggml_view_3d(ctx, qk, q_dim, S, N, qk->nb[1], qk->nb[2], 0));
k = ggml_cont(ctx, ggml_view_3d(ctx, qk, kv_dim, S, N, qk->nb[1], qk->nb[2], (size_t)q_dim * qk->nb[0]));
v = dit_ggml_linear(ctx, ly->sa_v_proj, norm_sa);
} else {
q = dit_ggml_linear(ctx, ly->sa_q_proj, norm_sa);
k = dit_ggml_linear(ctx, ly->sa_k_proj, norm_sa);
v = dit_ggml_linear(ctx, ly->sa_v_proj, norm_sa);
}
// 2) Reshape to heads: [Nh*D, S, N] -> [D, Nh, S, N]
// Rope merges S*N then restores 4D. Permute to flash_attn layout after rope.
q = ggml_reshape_4d(ctx, q, D, Nh, S, N);
k = ggml_reshape_4d(ctx, k, D, Nkv, S, N);
v = ggml_reshape_4d(ctx, v, D, Nkv, S, N);
// 3) QK-Norm: per-head RMSNorm on D dimension
// [D, Nh, S, N]: rms_norm operates on ne[0]=D while the tensor is still contiguous
q = ggml_rms_norm(ctx, q, c.rms_norm_eps);
q = ggml_mul(ctx, q, dit_ggml_f32(ctx, ly->sa_q_norm));
k = ggml_rms_norm(ctx, k, c.rms_norm_eps);
k = ggml_mul(ctx, k, dit_ggml_f32(ctx, ly->sa_k_norm));
// 4) RoPE (bidirectional, sequential positions; mode 2 = NEOX)
// ggml_rope_ext asserts ne[2] == positions.ne[0].
// With batch N>1, positions has S*N elements (repeated [0..S-1] per batch).
// Merge S and N before rope, then restore 4D after.
q = ggml_reshape_3d(ctx, q, D, Nh, S * N);
k = ggml_reshape_3d(ctx, k, D, Nkv, S * N);
q = ggml_rope_ext(ctx, q, positions, NULL,
D, 2 /*mode=NEOX*/, 0 /*n_ctx_orig*/,
c.rope_theta, 1.0f /*freq_scale*/,
0.0f, 1.0f, 0.0f, 0.0f);
k = ggml_rope_ext(ctx, k, positions, NULL,
D, 2, 0,
c.rope_theta, 1.0f,
0.0f, 1.0f, 0.0f, 0.0f);
q = ggml_reshape_4d(ctx, q, D, Nh, S, N);
k = ggml_reshape_4d(ctx, k, D, Nkv, S, N);
if (layer_idx == 0) {
ggml_set_name(q, "layer0_q_after_rope");
ggml_set_output(q);
ggml_set_name(k, "layer0_k_after_rope");
ggml_set_output(k);
}
// 5) Permute for attention: [D, Nh, S, N] -> [D, S, Nh, N] (non-contiguous views)
q = ggml_permute(ctx, q, 0, 2, 1, 3);
k = ggml_permute(ctx, k, 0, 2, 1, 3);
v = ggml_permute(ctx, v, 0, 2, 1, 3);
// 6) Attention: flash_attn_ext when m->use_flash_attn, otherwise the F32 manual
// path. dit_ggml_init_backend currently sets use_flash_attn = false everywhere.
// Q[D, S, Nh, N], K[D, S, Nkv, N], V[D, S, Nkv, N]
float scale = 1.0f / sqrtf((float)D);
struct ggml_tensor * attn = m->use_flash_attn
? ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f)
: dit_attn_f32(ctx, q, k, v, mask, scale);
if (m->use_flash_attn) {
ggml_flash_attn_ext_set_prec(attn, GGML_PREC_F32);
}
// Both return [D, Nh, S, N]
// Reshape: [D, Nh, S, N] -> [D*Nh, S, N] = [H, S, N]
attn = ggml_reshape_3d(ctx, attn, Nh * D, S, N);
if (layer_idx == 0) {
ggml_set_name(attn, "layer0_attn_out");
ggml_set_output(attn);
}
// 7) O projection: [Nh*D, S, N] -> [H, S, N]
struct ggml_tensor * out = dit_ggml_linear(ctx, ly->sa_o_proj, attn);
return out;
}
// Build MLP sub-graph: SwiGLU
// norm_ffn: [H, S, N] pre-normalized + AdaLN-modulated hidden state
// Returns: output [H, S, N]
static struct ggml_tensor * dit_ggml_build_mlp(
struct ggml_context * ctx,
DiTGGML * m,
DiTGGMLLayer * ly,
struct ggml_tensor * norm_ffn,
int S) {
struct ggml_tensor * ff;
if (ly->gate_up) {
// Fused: single matmul [H, 2*I] x [H, S, N] -> [2*I, S, N], then swiglu splits ne[0]
struct ggml_tensor * gu = dit_ggml_linear(ctx, ly->gate_up, norm_ffn);
ff = ggml_swiglu(ctx, gu);
} else {
// Separate: two matmuls + split swiglu
struct ggml_tensor * gate = dit_ggml_linear(ctx, ly->gate_proj, norm_ffn);
struct ggml_tensor * up = dit_ggml_linear(ctx, ly->up_proj, norm_ffn);
ff = ggml_swiglu_split(ctx, gate, up);
}
// Down projection: [I, S] -> [H, S]
return dit_ggml_linear(ctx, ly->down_proj, ff);
}
// Build cross-attention sub-graph for a single layer.
//   norm_ca:   [H, S, N] pre-normalized hidden state (Q source)
//   enc:       [H, enc_S, N] condition-embedded encoder states (K/V source)
//   positions: unused (cross-attention applies no RoPE); kept so the signature
//              mirrors dit_ggml_build_self_attn
// Projection layouts (whichever the loader populated):
//   ca_qkv : one fused [Q|K|V] weight; its Q rows multiply norm_ca and its
//            K|V rows multiply enc
//   ca_kv  : ca_q_proj plus a fused [K|V] weight
//   else   : separate ca_q_proj / ca_k_proj / ca_v_proj
// QK-norm runs while heads are still in contiguous [D, heads, seq, N] layout and
// the result is permuted afterwards, the same order dit_ggml_build_self_attn uses.
// rms_norm therefore never sees a permuted (non-contiguous) view, and because the
// norm acts on ne[0]=D only, the values are unchanged.
// No mask: every query attends to all encoder positions.
// Returns: output [H, S, N] (NOT added to residual yet)
static struct ggml_tensor * dit_ggml_build_cross_attn(
    struct ggml_context * ctx,
    DiTGGML * m,
    DiTGGMLLayer * ly,
    struct ggml_tensor * norm_ca,    // [H, S, N]
    struct ggml_tensor * enc,        // [H, enc_S, N]
    struct ggml_tensor * positions,  // unused, kept for consistency
    int S, int enc_S, int N) {
    DiTGGMLConfig & c = m->cfg;
    int D = c.head_dim;
    int Nh = c.n_heads;
    int Nkv = c.n_kv_heads;
    (void)positions; // cross-attn has no RoPE

    // 1) Projections: Q from hidden, K/V from encoder (full fused, Q+KV partial, separate)
    int q_dim = Nh * D;
    int kv_dim = Nkv * D;
    struct ggml_tensor * q, * k, * v;
    if (ly->ca_qkv) {
        // Full QKV fused: split Q from hidden, KV from enc via weight row views
        struct ggml_tensor * w = ly->ca_qkv;
        struct ggml_tensor * w_q  = ggml_view_2d(ctx, w, w->ne[0], q_dim, w->nb[1], 0);
        struct ggml_tensor * w_kv = ggml_view_2d(ctx, w, w->ne[0], 2 * kv_dim,
                                                 w->nb[1], (size_t)q_dim * w->nb[1]);
        q = ggml_mul_mat(ctx, w_q, norm_ca);
        struct ggml_tensor * kv = ggml_mul_mat(ctx, w_kv, enc);
        k = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], 0));
        v = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], (size_t)kv_dim * kv->nb[0]));
    } else if (ly->ca_kv) {
        // Q separate, K+V fused
        q = dit_ggml_linear(ctx, ly->ca_q_proj, norm_ca);
        struct ggml_tensor * kv = ggml_mul_mat(ctx, ly->ca_kv, enc);
        k = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], 0));
        v = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], (size_t)kv_dim * kv->nb[0]));
    } else {
        q = dit_ggml_linear(ctx, ly->ca_q_proj, norm_ca);
        k = dit_ggml_linear(ctx, ly->ca_k_proj, enc);
        v = dit_ggml_linear(ctx, ly->ca_v_proj, enc);
    }

    // 2) Split heads (reshape needs the contiguous projections produced above)
    q = ggml_reshape_4d(ctx, q, D, Nh, S, N);         // [D, Nh, S, N]
    k = ggml_reshape_4d(ctx, k, D, Nkv, enc_S, N);    // [D, Nkv, enc_S, N]
    v = ggml_reshape_4d(ctx, v, D, Nkv, enc_S, N);    // [D, Nkv, enc_S, N]

    // 3) QK-norm: per-head RMSNorm over ne[0]=D, on contiguous tensors
    q = ggml_rms_norm(ctx, q, c.rms_norm_eps);
    q = ggml_mul(ctx, q, dit_ggml_f32(ctx, ly->ca_q_norm));
    k = ggml_rms_norm(ctx, k, c.rms_norm_eps);
    k = ggml_mul(ctx, k, dit_ggml_f32(ctx, ly->ca_k_norm));

    // 4) [D, heads, seq, N] -> [D, seq, heads, N] for the attention op
    q = ggml_permute(ctx, q, 0, 2, 1, 3);   // [D, S, Nh, N]
    k = ggml_permute(ctx, k, 0, 2, 1, 3);   // [D, enc_S, Nkv, N]
    v = ggml_permute(ctx, v, 0, 2, 1, 3);   // [D, enc_S, Nkv, N]

    // 5) Attention: no RoPE, no mask (attend to all encoder positions)
    float scale = 1.0f / sqrtf((float)D);
    struct ggml_tensor * attn = m->use_flash_attn
        ? ggml_flash_attn_ext(ctx, q, k, v, NULL, scale, 0.0f, 0.0f)
        : dit_attn_f32(ctx, q, k, v, NULL, scale);
    if (m->use_flash_attn) {
        ggml_flash_attn_ext_set_prec(attn, GGML_PREC_F32);
    }
    // Attention output: [D, Nh, S, N], reshape to [H, S, N]
    attn = ggml_reshape_3d(ctx, attn, Nh * D, S, N);

    // 6) O projection
    return dit_ggml_linear(ctx, ly->ca_o_proj, attn);
}
// Build one full DiT layer:
//   - AdaLN-modulated self-attention with a gated residual;
//   - optional cross-attention with a plain residual;
//   - AdaLN-modulated SwiGLU FFN with a gated residual.
// hidden:    [H, S, N]
// tproj:     [6H] f32 combined timestep projection
// enc:       [H, enc_S, N] condition-embedded encoder states, or NULL to skip cross-attn
// positions: forwarded to dit_ggml_build_self_attn, which documents it as [S*N] int32
// sw_mask:   sliding-window mask, used only by layers with layer_type == 0 (may be NULL)
// Returns: updated hidden [H, S, N]
static struct ggml_tensor * dit_ggml_build_layer(
struct ggml_context * ctx,
DiTGGML * m,
int layer_idx,
struct ggml_tensor * hidden, // [H, S, N]
struct ggml_tensor * tproj, // [6H] f32 combined temb projection
struct ggml_tensor * enc, // [H, enc_S, N] or NULL
struct ggml_tensor * positions, // [S*N] int32, forwarded to self-attn RoPE
struct ggml_tensor * sw_mask, // [S, S] or NULL
int S, int enc_S, int N) {
DiTGGMLConfig & c = m->cfg;
DiTGGMLLayer * ly = &m->layers[layer_idx];
int H = c.hidden_size;
// AdaLN: scale_shift_table ([H, 6] in ggml layout) + tproj [6H] -> 6 vectors of [H]
// The table is loaded via gf_load_tensor_f32 in dit_ggml_load, so it is presumably
// already F32; the cast below only triggers if it is not.
struct ggml_tensor * ss = ly->scale_shift_table;
if (ss->type != GGML_TYPE_F32) {
ss = ggml_cast(ctx, ss, GGML_TYPE_F32);
}
// flatten [H, 6] -> [6H] (ggml ne[0]=H, ne[1]=6, contiguous = 6H floats)
struct ggml_tensor * ss_flat = ggml_reshape_1d(ctx, ss, 6 * H);
struct ggml_tensor * adaln = ggml_add(ctx, ss_flat, tproj); // [6H] f32
// extract 6 modulation vectors [H] each, in table order:
// shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn
size_t Hb = H * sizeof(float);
struct ggml_tensor * shift_sa = ggml_view_1d(ctx, adaln, H, 0 * Hb);
struct ggml_tensor * scale_sa = ggml_view_1d(ctx, adaln, H, 1 * Hb);
struct ggml_tensor * gate_sa = ggml_view_1d(ctx, adaln, H, 2 * Hb);
struct ggml_tensor * shift_ffn = ggml_view_1d(ctx, adaln, H, 3 * Hb);
struct ggml_tensor * scale_ffn = ggml_view_1d(ctx, adaln, H, 4 * Hb);
struct ggml_tensor * gate_ffn = ggml_view_1d(ctx, adaln, H, 5 * Hb);
// Self-attention with AdaLN + gated residual
struct ggml_tensor * residual = hidden;
struct ggml_tensor * norm_sa = dit_ggml_rms_norm_weighted(ctx, hidden, ly->self_attn_norm, c.rms_norm_eps);
norm_sa = dit_ggml_adaln(ctx, norm_sa, scale_sa, shift_sa, m->scalar_one);
if (layer_idx == 0) {
ggml_set_name(norm_sa, "layer0_sa_input");
ggml_set_output(norm_sa);
}
// select mask: layer_type 0 (even layers) uses sliding window, layer_type 1 (odd) full attention
struct ggml_tensor * mask = (ly->layer_type == 0) ? sw_mask : NULL;
struct ggml_tensor * sa_out = dit_ggml_build_self_attn(ctx, m, ly, norm_sa, positions, mask, S, N, layer_idx);
if (layer_idx == 0) {
ggml_set_name(sa_out, "layer0_sa_output");
ggml_set_output(sa_out);
}
hidden = dit_ggml_gated_add(ctx, residual, sa_out, gate_sa);
if (layer_idx == 0) {
ggml_set_name(hidden, "layer0_after_self_attn");
ggml_set_output(hidden);
}
// Cross-attention (no gate, simple residual add; no AdaLN modulation)
if (enc) {
struct ggml_tensor * norm_ca = dit_ggml_rms_norm_weighted(ctx, hidden, ly->cross_attn_norm, c.rms_norm_eps);
struct ggml_tensor * ca_out = dit_ggml_build_cross_attn(ctx, m, ly, norm_ca, enc, positions, S, enc_S, N);
hidden = ggml_add(ctx, hidden, ca_out);
}
// NOTE(review): when enc is NULL, `hidden` is still the tensor named
// "layer0_after_self_attn" above, so this renames it and that debug name is lost.
if (layer_idx == 0) {
ggml_set_name(hidden, "layer0_after_cross_attn");
ggml_set_output(hidden);
}
// FFN with AdaLN + gated residual
residual = hidden;
struct ggml_tensor * norm_ffn = dit_ggml_rms_norm_weighted(ctx, hidden, ly->mlp_norm, c.rms_norm_eps);
norm_ffn = dit_ggml_adaln(ctx, norm_ffn, scale_ffn, shift_ffn, m->scalar_one);
struct ggml_tensor * ffn_out = dit_ggml_build_mlp(ctx, m, ly, norm_ffn, S);
hidden = dit_ggml_gated_add(ctx, residual, ffn_out, gate_ffn);
return hidden;
}
// Assemble the whole DiT forward pass (every transformer layer) into one graph
// whose final node is the velocity prediction. N samples are processed side by
// side as the outermost tensor dimension.
//
// Inputs the caller must fill (ggml [ne0, ne1, ne2, ne3] order):
//   "input_latents" [in_channels, T, N]  per-sample concat(context_latents, xt).
//                    NOTE: this tensor is later renamed "proj_in_input" for debug
//                    dumps, so a by-name lookup of "input_latents" after this
//                    function returns will not find it; use *p_input instead.
//   "enc_hidden"    [H_enc, enc_S, N]    encoder states, H_enc = cond_emb_w->ne[0]
//   "t"             [1] f32              flow-matching timestep (shared)
//   "t_r"           [1] f32              reference timestep (shared)
//   "positions"     [S*N] i32            0..S-1 repeated once per sample
//   "sw_mask"       [S, S, 1, N] f16     sliding-window mask; only created when
//                                        sliding_window > 0 and S > sliding_window
//
// Output:
//   "velocity"      [out_channels, T, N] predicted flow velocity (also *p_output)
static struct ggml_cgraph * dit_ggml_build_graph(
        DiTGGML * m,
        struct ggml_context * ctx,
        int T,                            // temporal length (before patching)
        int enc_S,                        // encoder sequence length
        int N,                            // batch size
        struct ggml_tensor ** p_input,    // [out] input tensor to fill
        struct ggml_tensor ** p_output) { // [out] output tensor to read
    const DiTGGMLConfig & cfg = m->cfg;
    const int patch   = cfg.patch_size;
    const int seq_len = T / patch;        // sequence length after patching
    const int hid     = cfg.hidden_size;

    // Name a tensor and flag it as a graph input.
    auto mark_input = [](struct ggml_tensor * t, const char * name) {
        ggml_set_name(t, name);
        ggml_set_input(t);
        return t;
    };
    // Name a tensor and flag it as a graph output (kept readable after compute).
    auto mark_output = [](struct ggml_tensor * t, const char * name) {
        ggml_set_name(t, name);
        ggml_set_output(t);
        return t;
    };

    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 8192, false);

    // ---- graph inputs ----
    struct ggml_tensor * input = mark_input(
        ggml_new_tensor_3d(ctx, GGML_TYPE_F32, cfg.in_channels, T, N), "input_latents");
    *p_input = input;

    // Encoder width is the condition embedder's input dimension; cond_emb_w
    // projects it to the decoder hidden size.
    const int H_enc = (int) m->cond_emb_w->ne[0];
    struct ggml_tensor * enc_hidden = mark_input(
        ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H_enc, enc_S, N), "enc_hidden");

    struct ggml_tensor * t_val  = mark_input(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1), "t");
    struct ggml_tensor * tr_val = mark_input(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1), "t_r");

    // RoPE positions: one run of 0..S-1 per sample. The CUDA rope kernel
    // linearizes (ne2, ne3) = (S, N), so sample b reads pos[b*S + s] and the
    // sequence has to be repeated for every batch element.
    struct ggml_tensor * positions = mark_input(
        ggml_new_tensor_1d(ctx, GGML_TYPE_I32, seq_len * N), "positions");

    // flash_attn_ext reads the mask as fp16. It must be 4D [S, S, 1, N]: the
    // CUDA flash_attn_mask_to_KV_max kernel offsets the mask by sequence*nb[3]
    // per batch element, so a 2D mask would be read out of bounds for b >= 1.
    struct ggml_tensor * sw_mask = NULL;
    const bool want_sw_mask = cfg.sliding_window > 0 && seq_len > cfg.sliding_window;
    if (want_sw_mask) {
        sw_mask = mark_input(
            ggml_new_tensor_4d(ctx, GGML_TYPE_F16, seq_len, seq_len, 1, N), "sw_mask");
    }

    // ---- 1) timestep embeddings ----
    // The Python reference feeds (t - t_r) to time_embed_r, not t_r itself;
    // in turbo mode t == t_r, so that branch sees 0.
    struct ggml_tensor * tproj_t = NULL;
    struct ggml_tensor * temb_t = mark_output(
        dit_ggml_build_temb(ctx, &m->time_embed, t_val, &tproj_t, "_t"), "temb_t");
    struct ggml_tensor * tproj_r = NULL;
    struct ggml_tensor * t_diff = ggml_sub(ctx, t_val, tr_val);
    struct ggml_tensor * temb_r = mark_output(
        dit_ggml_build_temb(ctx, &m->time_embed_r, t_diff, &tproj_r, "_r"), "temb_r");
    // Both branches are summed: temb [H], tproj [6H].
    struct ggml_tensor * temb  = mark_output(ggml_add(ctx, temb_t, temb_r), "temb");
    struct ggml_tensor * tproj = mark_output(ggml_add(ctx, tproj_t, tproj_r), "tproj");

    // ---- 2) proj_in: patchify via reshape, then linear (weight pre-permuted at load) ----
    mark_output(input, "proj_in_input");
    struct ggml_tensor * patched = ggml_reshape_3d(ctx, input, cfg.in_channels * patch, seq_len, N);
    struct ggml_tensor * hidden = mark_output(
        dit_ggml_linear_bias(ctx, m->proj_in_w, m->proj_in_b, patched), "hidden_after_proj_in");

    // ---- 3) condition embedder: encoder states -> decoder width ----
    struct ggml_tensor * enc = mark_output(
        dit_ggml_linear_bias(ctx, m->cond_emb_w, m->cond_emb_b, enc_hidden), "enc_after_cond_emb");

    // ---- 4) transformer stack ----
    for (int layer = 0; layer < cfg.n_layers; ++layer) {
        hidden = dit_ggml_build_layer(ctx, m, layer, hidden, tproj, enc, positions, sw_mask,
                                      seq_len, enc_S, N);
        // Debug checkpoints: layers 0, 6, 12, 18 and the final layer.
        const bool checkpoint = (layer % 6 == 0 && layer <= 18) || layer == cfg.n_layers - 1;
        if (checkpoint) {
            char lname[64];
            snprintf(lname, sizeof(lname), "hidden_after_layer%d", layer);
            mark_output(hidden, lname);
        }
    }

    // ---- 5) output AdaLN + proj_out ----
    // out_scale_shift is [H, 2]: cast to f32 when needed and flatten to [2H];
    // the first H values are the shift, the next H the scale.
    struct ggml_tensor * oss = m->out_scale_shift;
    if (oss->type != GGML_TYPE_F32) {
        oss = ggml_cast(ctx, oss, GGML_TYPE_F32);
    }
    struct ggml_tensor * oss_flat = ggml_reshape_1d(ctx, oss, 2 * hid);
    const size_t row_bytes = hid * sizeof(float);
    struct ggml_tensor * out_shift = ggml_view_1d(ctx, oss_flat, hid, 0);
    struct ggml_tensor * out_scale = ggml_view_1d(ctx, oss_flat, hid, row_bytes);
    out_shift = ggml_add(ctx, out_shift, temb);
    out_scale = ggml_add(ctx, out_scale, temb);
    struct ggml_tensor * norm_out = dit_ggml_rms_norm_weighted(ctx, hidden, m->norm_out, cfg.rms_norm_eps);
    norm_out = dit_ggml_adaln(ctx, norm_out, out_scale, out_shift, m->scalar_one);
    // proj_out weight was pre-permuted and transposed at load time to [H, out_ch*P] F32.
    struct ggml_tensor * velocity = dit_ggml_linear_bias(ctx, m->proj_out_w, m->proj_out_b, norm_out);
    velocity = mark_output(ggml_reshape_3d(ctx, velocity, cfg.out_channels, T, N), "velocity");
    *p_output = velocity;
    ggml_build_forward_expand(gf, velocity);
    return gf;
}
// APG (Adaptive Projected Guidance) for DiT classifier-free guidance.
// Mirrors Python ACE-Step-1.5 acestep/models/base/apg_guidance.py.
//
// Momentum buffer for the guidance difference. The first update() seeds the
// average with the given values; every later update computes, element-wise,
//   running_average = values + momentum * running_average
struct APGMomentumBuffer {
    double momentum;                      // weight applied to the previous average
    std::vector<double> running_average;  // smoothed difference, sized on first update()
    bool initialized;                     // false until update() has been called once

    APGMomentumBuffer(double m = -0.75) : momentum(m), initialized(false) {}

    void update(const double * values, int n) {
        if (initialized) {
            double * avg = running_average.data();
            for (int k = 0; k < n; ++k) {
                avg[k] = values[k] + momentum * avg[k];
            }
            return;
        }
        running_average = std::vector<double>(values, values + n);
        initialized = true;
    }
};
// project(v0, v1, dims=[1]): split v0 into components parallel and orthogonal
// to v1, independently for each channel.
// All math is double precision, matching the Python .double() calls.
// Layout: memory [T, Oc] time-major (ggml ne=[Oc, T]), element (t, c) at t*Oc + c.
// Python dims=[1] on [B, T, C] normalizes/projects per channel over the T axis,
// so for each channel c the reduction runs over all T time frames.
//
// v1 is normalized exactly like torch.nn.functional.normalize: v1 / max(||v1||, eps)
// with eps = 1e-12. A near-zero v1 therefore yields a near-zero unit vector, and
// v0 is almost entirely orthogonal, instead of being blown up to unit length.
//
// out_par[i]  = <v0, v1_hat> * v1_hat[i]
// out_orth[i] = v0[i] - out_par[i]
static void apg_project(
        const double * v0, const double * v1,
        double * out_par, double * out_orth,
        int Oc, int T) {
    const double eps = 1e-12;  // torch.nn.functional.normalize default
    for (int c = 0; c < Oc; c++) {
        double norm2 = 0.0;
        for (int t = 0; t < T; t++) {
            const double x = v1[t * Oc + c];
            norm2 += x * x;
        }
        const double norm = sqrt(norm2);
        const double denom = (norm > eps) ? norm : eps;
        double dot = 0.0;
        for (int t = 0; t < T; t++) {
            const int idx = t * Oc + c;
            dot += v0[idx] * (v1[idx] / denom);
        }
        for (int t = 0; t < T; t++) {
            const int idx = t * Oc + c;
            out_par[idx] = dot * (v1[idx] / denom);
            out_orth[idx] = v0[idx] - out_par[idx];
        }
    }
}
// APG forward, following Python apg_forward():
//   1. diff = cond - uncond
//   2. momentum.update(diff); diff = running_average
//   3. norm clip: per-channel L2 over T (dims=[1]); diff *= min(1, norm_threshold / norm)
//      (skipped entirely when norm_threshold <= 0)
//   4. project(diff, pred_COND) -> (parallel, orthogonal)
//   5. result = pred_cond + (scale - 1) * (orthogonal + eta * parallel)
// Internal math is double precision (Python uses .double()); result is stored as float.
// Layout: memory [T, Oc] time-major, element (t, c) at t*Oc + c.
//
// eta weights the parallel component like the reference's `eta` argument.
// With the default eta = 0 the result is pred_cond + (scale - 1) * orthogonal.
static void apg_forward(
        const float * pred_cond, const float * pred_uncond,
        float guidance_scale, APGMomentumBuffer & mbuf,
        float * result, int Oc, int T,
        float norm_threshold = 2.5f,
        float eta = 0.0f) {
    const int n = Oc * T;

    // 1. diff = cond - uncond, promoted to double
    std::vector<double> diff(n);
    for (int i = 0; i < n; i++)
        diff[i] = (double) pred_cond[i] - (double) pred_uncond[i];

    // 2. momentum update, then continue with the smoothed diff
    mbuf.update(diff.data(), n);
    memcpy(diff.data(), mbuf.running_average.data(), n * sizeof(double));

    // 3. per-channel norm clipping. min(1, thr / norm) < 1 exactly when norm > thr,
    //    so only channels whose norm exceeds the threshold are rescaled
    //    (a zero norm is never rescaled, matching Python's min(1, inf) = 1).
    if (norm_threshold > 0.0f) {
        const double thr = (double) norm_threshold;
        for (int c = 0; c < Oc; c++) {
            double norm2 = 0.0;
            for (int t = 0; t < T; t++)
                norm2 += diff[t * Oc + c] * diff[t * Oc + c];
            const double norm = sqrt(norm2);
            if (norm > thr) {
                const double s = thr / norm;
                for (int t = 0; t < T; t++)
                    diff[t * Oc + c] *= s;
            }
        }
    }

    // 4. project(diff, pred_cond) in double precision
    std::vector<double> pred_cond_d(pred_cond, pred_cond + n);
    std::vector<double> par(n), orth(n);
    apg_project(diff.data(), pred_cond_d.data(), par.data(), orth.data(), Oc, T);

    // 5. guided prediction, back to float
    const double w = (double) guidance_scale - 1.0;
    const double eta_d = (double) eta;
    for (int i = 0; i < n; i++) {
        double update = orth[i];
        if (eta_d != 0.0) update += eta_d * par[i];  // skip when eta == 0: keeps the default path bit-identical
        result[i] = (float) ((double) pred_cond[i] + w * update);
    }
}
// Flow-matching generation loop (batched).
// Runs num_steps Euler steps that denoise N latent samples in parallel on one graph.
//
// noise:           [N * T * Oc]      N contiguous [T, Oc] noise blocks
// context_latents: [N * T * ctx_ch]  N contiguous [T, ctx_ch] context blocks
// enc_hidden_data: [enc_S * H_enc]   encoder output; only this first block is read,
//                                    and it is replicated to all N samples.
//                                    NOTE(review): earlier docs described a caller-stacked
//                                    [enc_S * H_enc * N] per-batch layout, which this code
//                                    does not honor; confirm against callers.
// schedule:        num_steps timestep values, one per step
// output:          [N * T * Oc] generated latents (caller-allocated); left untouched
//                  if the graph context or compute buffers cannot be set up
// guidance_scale:  > 1 enables CFG (extra unconditional pass + APG per sample)
// dbg:             optional debug dumper (dumps cover sample 0 only)
// context_switch:  optional replacement context [N * T * ctx_ch] swapped in at step cover_steps
// cover_steps:     step at which context_switch takes effect; < 0 disables the switch
static void dit_ggml_generate(
        DiTGGML * model,
        const float * noise,
        const float * context_latents,
        const float * enc_hidden_data,
        int enc_S,
        int T,
        int N,
        int num_steps,
        const float * schedule,
        float * output,
        float guidance_scale = 1.0f,
        const DebugDumper * dbg = nullptr,
        const float * context_switch = nullptr,
        int cover_steps = -1) {
    DiTGGMLConfig & c = model->cfg;
    int Oc = c.out_channels;          // 64
    int ctx_ch = c.in_channels - Oc;  // 128
    int in_ch = c.in_channels;        // 192
    int S = T / c.patch_size;
    int n_per = T * Oc;               // elements per sample
    int n_total = N * n_per;          // total output elements
    fprintf(stderr, "[DiT] Batch N=%d, T=%d, S=%d, enc_S=%d\n", N, T, S, enc_S);

    // Graph context holds tensor/graph metadata only (no_alloc). 8192 matches the
    // node budget dit_ggml_build_graph passes to ggml_new_graph_custom.
    size_t ctx_size = ggml_tensor_overhead() * 8192 + ggml_graph_overhead_custom(8192, false);
    std::vector<uint8_t> ctx_buf(ctx_size);
    struct ggml_init_params gparams = {
        /*.mem_size   =*/ ctx_size,
        /*.mem_buffer =*/ ctx_buf.data(),
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(gparams);
    if (!ctx) {
        fprintf(stderr, "FATAL: failed to init graph context\n");
        return;
    }
    struct ggml_tensor * t_input = NULL;
    struct ggml_tensor * t_output = NULL;
    struct ggml_cgraph * gf = dit_ggml_build_graph(model, ctx, T, enc_S, N,
                                                   &t_input, &t_output);
    fprintf(stderr, "[DiT] Graph: %d nodes\n", ggml_graph_n_nodes(gf));
    struct ggml_tensor * t_enc = ggml_graph_get_tensor(gf, "enc_hidden");
    int H_enc = (int) t_enc->ne[0];  // encoder hidden size (from condition_embedder)

    // Allocate compute buffers.
    // Critical: reset FIRST (clears old state), THEN force inputs to GPU, THEN alloc.
    // Without GPU forcing, inputs default to CPU where the scheduler aliases their
    // buffers with intermediates. enc_hidden is read at every cross-attn layer,
    // so CPU aliasing corrupts it mid-graph. With N>1 the larger buffers trigger
    // more aggressive aliasing, causing batch sample 1+ to produce noise.
    ggml_backend_sched_reset(model->sched);
    if (model->backend != model->cpu_backend) {
        // The latent input is forced through the pointer returned by the graph
        // builder: the builder renames it from "input_latents" to "proj_in_input"
        // for debug dumps, so a by-name lookup of "input_latents" returns NULL
        // and the tensor would silently stay on the default backend.
        if (t_input) ggml_backend_sched_set_tensor_backend(model->sched, t_input, model->backend);
        const char * input_names[] = {"enc_hidden", "t", "t_r", "positions", "sw_mask"};
        for (const char * iname : input_names) {
            struct ggml_tensor * t = ggml_graph_get_tensor(gf, iname);
            if (t) ggml_backend_sched_set_tensor_backend(model->sched, t, model->backend);
        }
    }
    if (!ggml_backend_sched_alloc_graph(model->sched, gf)) {
        fprintf(stderr, "FATAL: failed to allocate graph\n");
        ggml_free(ctx);
        return;
    }

    // Remaining graph inputs (sw_mask is absent when no sliding window applies).
    // t and t_r are set per step in the loop (t_r = t_curr, same as the Python reference).
    struct ggml_tensor * t_t    = ggml_graph_get_tensor(gf, "t");
    struct ggml_tensor * t_tr   = ggml_graph_get_tensor(gf, "t_r");
    struct ggml_tensor * t_pos  = ggml_graph_get_tensor(gf, "positions");
    struct ggml_tensor * t_mask = ggml_graph_get_tensor(gf, "sw_mask");

    // Positions: [0, 1, ..., S-1] repeated N times for batched rope indexing
    std::vector<int32_t> pos_data(S * N);
    for (int b = 0; b < N; b++)
        for (int i = 0; i < S; i++)
            pos_data[b * S + i] = i;
    ggml_backend_tensor_set(t_pos, pos_data.data(), 0, S * N * sizeof(int32_t));

    // Sliding window mask: [S, S, 1, N] fp16, N identical copies
    std::vector<uint16_t> mask_data;
    if (t_mask) {
        int win = c.sliding_window;
        mask_data.resize(S * S * N);
        // fill first copy
        for (int qi = 0; qi < S; qi++)
            for (int ki = 0; ki < S; ki++) {
                int dist = (qi > ki) ? (qi - ki) : (ki - qi);
                float v = (dist <= win) ? 0.0f : -INFINITY;
                mask_data[ki * S + qi] = ggml_fp32_to_fp16(v);
            }
        // replicate for batch elements 1..N-1
        for (int b = 1; b < N; b++)
            memcpy(mask_data.data() + b * S * S, mask_data.data(), S * S * sizeof(uint16_t));
        ggml_backend_tensor_set(t_mask, mask_data.data(), 0, S * S * N * sizeof(uint16_t));
    }

    // CFG setup: build the null-condition encoder input and one APG buffer per sample
    bool do_cfg = guidance_scale > 1.0f;
    std::vector<float> null_enc_buf;
    std::vector<APGMomentumBuffer> apg_mbufs;
    if (do_cfg && !model->null_condition_emb) {
        fprintf(stderr, "[DiT] WARNING: guidance_scale=%.1f but null_condition_emb not found. Disabling CFG.\n", guidance_scale);
        do_cfg = false;
    }
    if (do_cfg) {
        struct ggml_tensor * nce = model->null_condition_emb;
        int emb_n = (int) ggml_nelements(nce);
        std::vector<float> null_emb(emb_n);
        bool emb_ok = true;
        if (emb_n < H_enc) {
            // The broadcast below copies H_enc floats per row; a shorter embedding would over-read.
            fprintf(stderr, "[DiT] WARNING: null_condition_emb has %d elements, need at least %d. Disabling CFG.\n",
                    emb_n, H_enc);
            emb_ok = false;
        } else if (nce->type == GGML_TYPE_BF16) {
            std::vector<uint16_t> bf16_buf(emb_n);
            ggml_backend_tensor_get(nce, bf16_buf.data(), 0, emb_n * sizeof(uint16_t));
            for (int i = 0; i < emb_n; i++) {
                uint32_t w = (uint32_t) bf16_buf[i] << 16;
                memcpy(&null_emb[i], &w, 4);
            }
        } else if (nce->type == GGML_TYPE_F16) {
            std::vector<ggml_fp16_t> f16_buf(emb_n);
            ggml_backend_tensor_get(nce, f16_buf.data(), 0, emb_n * sizeof(ggml_fp16_t));
            for (int i = 0; i < emb_n; i++)
                null_emb[i] = ggml_fp16_to_fp32(f16_buf[i]);
        } else if (nce->type == GGML_TYPE_F32) {
            ggml_backend_tensor_get(nce, null_emb.data(), 0, emb_n * sizeof(float));
        } else {
            // Reading emb_n floats from any other type would mis-size the copy.
            fprintf(stderr, "[DiT] WARNING: unsupported null_condition_emb type %s. Disabling CFG.\n",
                    ggml_type_name(nce->type));
            emb_ok = false;
        }
        if (!emb_ok) {
            do_cfg = false;
        } else {
            // Broadcast [H_enc] to [enc_S, H_enc], then to N copies [H_enc, enc_S, N]
            std::vector<float> null_enc_single(H_enc * enc_S);
            for (int s = 0; s < enc_S; s++) {
                memcpy(&null_enc_single[s * H_enc], null_emb.data(), H_enc * sizeof(float));
            }
            null_enc_buf.resize(H_enc * enc_S * N);
            for (int b = 0; b < N; b++) {
                memcpy(null_enc_buf.data() + b * enc_S * H_enc, null_enc_single.data(), enc_S * H_enc * sizeof(float));
            }
            if (dbg && dbg->enabled) {
                debug_dump_1d(dbg, "null_condition_emb", null_emb.data(), emb_n);
                debug_dump_2d(dbg, "null_enc_hidden", null_enc_single.data(), enc_S, H_enc);
            }
            apg_mbufs.resize(N);
            fprintf(stderr, "[DiT] CFG enabled: guidance_scale=%.1f, 2x forward per step, N=%d\n", guidance_scale, N);
        }
    }

    // Host buffers (all N samples contiguous)
    std::vector<float> xt(noise, noise + n_total);
    std::vector<float> vt(n_total);
    std::vector<float> vt_cond;
    std::vector<float> vt_uncond;
    if (do_cfg) {
        vt_cond.resize(n_total);
        vt_uncond.resize(n_total);
    }
    // input_buf: [in_ch, T, N]; context channels pre-filled (constant unless cover mode switches them)
    std::vector<float> input_buf(in_ch * T * N);
    for (int b = 0; b < N; b++)
        for (int t = 0; t < T; t++)
            memcpy(&input_buf[b * T * in_ch + t * in_ch],
                   &context_latents[b * T * ctx_ch + t * ctx_ch],
                   ctx_ch * sizeof(float));
    // Encoder buffer built once: the first enc_hidden_data block replicated N times
    std::vector<float> enc_buf(H_enc * enc_S * N);
    for (int b = 0; b < N; b++)
        memcpy(enc_buf.data() + b * enc_S * H_enc, enc_hidden_data, enc_S * H_enc * sizeof(float));
    ggml_backend_tensor_set(t_enc, enc_buf.data(), 0, enc_buf.size() * sizeof(float));

    // Name of the last transformer layer's debug output, as tagged by the graph builder
    char last_layer_name[64];
    snprintf(last_layer_name, sizeof(last_layer_name), "hidden_after_layer%d", c.n_layers - 1);

    // Flow matching loop
    bool switched_cover = false;
    for (int step = 0; step < num_steps; step++) {
        float t_curr = schedule[step];
        // Cover mode: switch context from cover to non-cover at cover_steps
        if (context_switch && cover_steps >= 0 && step >= cover_steps && !switched_cover) {
            switched_cover = true;
            for (int b = 0; b < N; b++)
                for (int t = 0; t < T; t++)
                    memcpy(&input_buf[b * T * in_ch + t * in_ch],
                           &context_switch[b * T * ctx_ch + t * ctx_ch],
                           ctx_ch * sizeof(float));
            fprintf(stderr, "[DiT] Cover: switched to non-cover context at step %d/%d\n",
                    step, num_steps);
        }
        // Set timestep (changes each step)
        if (t_t) {
            ggml_backend_tensor_set(t_t, &t_curr, 0, sizeof(float));
        }
        if (t_tr) {
            ggml_backend_tensor_set(t_tr, &t_curr, 0, sizeof(float));
        }
        // Re-upload constants (scheduler may reuse input buffers as scratch between computes)
        ggml_backend_tensor_set(t_enc, enc_buf.data(), 0, enc_buf.size() * sizeof(float));
        ggml_backend_tensor_set(t_pos, pos_data.data(), 0, S * N * sizeof(int32_t));
        if (t_mask) ggml_backend_tensor_set(t_mask, mask_data.data(), 0, S * S * N * sizeof(uint16_t));
        // Update xt portion of input: [in_ch, T, N] (context channels pre-filled)
        for (int b = 0; b < N; b++)
            for (int t = 0; t < T; t++)
                memcpy(&input_buf[b * T * in_ch + t * in_ch + ctx_ch],
                       &xt[b * n_per + t * Oc],
                       Oc * sizeof(float));
        ggml_backend_tensor_set(t_input, input_buf.data(), 0, in_ch * T * N * sizeof(float));
        // Forward pass (conditional)
        ggml_backend_sched_graph_compute(model->sched, gf);
        // Dump intermediate tensors on step 0 (sample 0 only for batch)
        if (step == 0 && dbg && dbg->enabled) {
            auto dump_named = [&](const char * name) {
                struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
                if (t) {
                    // For batched tensors, dump only sample 0 (first [ne0, ne1] slice)
                    int64_t n0 = t->ne[0];
                    int64_t n1 = t->ne[1];
                    int64_t sample_elems = n0 * n1;
                    std::vector<float> buf(sample_elems);
                    ggml_backend_tensor_get(t, buf.data(), 0, sample_elems * sizeof(float));
                    if (n1 <= 1) {
                        debug_dump_1d(dbg, name, buf.data(), (int) n0);
                    } else {
                        debug_dump_2d(dbg, name, buf.data(), (int) n0, (int) n1);
                    }
                }
            };
            dump_named("tproj");
            dump_named("temb");
            dump_named("temb_t");
            dump_named("temb_r");
            dump_named("sinusoidal_t");
            dump_named("sinusoidal_r");
            dump_named("temb_lin1_t");
            dump_named("temb_lin1_r");
            dump_named("hidden_after_proj_in");
            dump_named("proj_in_input");
            dump_named("enc_after_cond_emb");
            dump_named("layer0_sa_input");
            dump_named("layer0_q_after_rope");
            dump_named("layer0_k_after_rope");
            dump_named("layer0_sa_output");
            dump_named("layer0_attn_out");
            dump_named("layer0_after_self_attn");
            dump_named("layer0_after_cross_attn");
            dump_named("hidden_after_layer0");
            dump_named("hidden_after_layer6");
            dump_named("hidden_after_layer12");
            dump_named("hidden_after_layer18");
            dump_named(last_layer_name);
        }
        // Read velocity output: [Oc, T, N]
        ggml_backend_tensor_get(t_output, vt.data(), 0, n_total * sizeof(float));
        // CFG: unconditional pass + APG per sample
        if (do_cfg) {
            memcpy(vt_cond.data(), vt.data(), n_total * sizeof(float));
            if (dbg && dbg->enabled) {
                char name[64];
                snprintf(name, sizeof(name), "dit_step%d_vt_cond", step);
                debug_dump_2d(dbg, name, vt_cond.data(), T, Oc);
            }
            // Unconditional pass: re-upload all inputs (scheduler clobbers input buffers during compute)
            ggml_backend_tensor_set(t_enc, null_enc_buf.data(), 0, H_enc * enc_S * N * sizeof(float));
            ggml_backend_tensor_set(t_input, input_buf.data(), 0, in_ch * T * N * sizeof(float));
            if (t_t) ggml_backend_tensor_set(t_t, &t_curr, 0, sizeof(float));
            if (t_tr) ggml_backend_tensor_set(t_tr, &t_curr, 0, sizeof(float));
            ggml_backend_tensor_set(t_pos, pos_data.data(), 0, S * N * sizeof(int32_t));
            if (t_mask) ggml_backend_tensor_set(t_mask, mask_data.data(), 0, S * S * N * sizeof(uint16_t));
            ggml_backend_sched_graph_compute(model->sched, gf);
            ggml_backend_tensor_get(t_output, vt_uncond.data(), 0, n_total * sizeof(float));
            if (dbg && dbg->enabled) {
                char name[64];
                snprintf(name, sizeof(name), "dit_step%d_vt_uncond", step);
                debug_dump_2d(dbg, name, vt_uncond.data(), T, Oc);
            }
            // APG per sample (each sample keeps its own momentum buffer)
            for (int b = 0; b < N; b++) {
                apg_forward(vt_cond.data() + b * n_per,
                            vt_uncond.data() + b * n_per,
                            guidance_scale, apg_mbufs[b],
                            vt.data() + b * n_per, Oc, T);
            }
        }
        if (dbg && dbg->enabled) {
            char name[64];
            snprintf(name, sizeof(name), "dit_step%d_vt", step);
            debug_dump_2d(dbg, name, vt.data(), T, Oc);
        }
        // Euler step (all N samples); the last step writes x0 = xt - vt * t directly
        if (step == num_steps - 1) {
            for (int i = 0; i < n_total; i++)
                output[i] = xt[i] - vt[i] * t_curr;
        } else {
            float dt = t_curr - schedule[step + 1];
            for (int i = 0; i < n_total; i++)
                xt[i] -= vt[i] * dt;
        }
        // Debug dump (sample 0 only)
        if (dbg && dbg->enabled) {
            char name[64];
            if (step == num_steps - 1) {
                snprintf(name, sizeof(name), "dit_x0");
                debug_dump_2d(dbg, name, output, T, Oc);
            } else {
                snprintf(name, sizeof(name), "dit_step%d_xt", step);
                debug_dump_2d(dbg, name, xt.data(), T, Oc);
            }
        }
        fprintf(stderr, "[DiT] step %d/%d t=%.3f\n", step + 1, num_steps, t_curr);
    }
    // Batch diagnostic: report per-sample stats to catch corruption.
    // min/max start at +/-inf so a NaN first element cannot poison them, and the
    // mean is taken over non-NaN values only (accumulated in double).
    if (N >= 2) {
        for (int b = 0; b < N; b++) {
            const float * s = output + b * n_per;
            float mn = INFINITY, mx = -INFINITY;
            double sum = 0.0;
            int n_nan = 0;
            for (int i = 0; i < n_per; i++) {
                float v = s[i];
                if (v != v) { n_nan++; continue; }
                if (v < mn) mn = v;
                if (v > mx) mx = v;
                sum += v;
            }
            int n_valid = n_per - n_nan;
            if (n_valid == 0) { mn = 0.0f; mx = 0.0f; }
            double mean = (n_valid > 0) ? sum / (double) n_valid : 0.0;
            fprintf(stderr, "[DiT] Batch%d output: min=%.4f max=%.4f mean=%.6f nan=%d\n",
                    b, mn, mx, mean, n_nan);
        }
    }
    ggml_free(ctx);
}
// Release everything a loaded DiT model owns and reset it to an empty state.
// Order: scheduler, then the compute backend (skipped when it is the same object
// as the CPU backend, so that backend is freed only once), then the CPU backend,
// then the weight context.
static void dit_ggml_free(DiTGGML * m) {
    if (m->sched != NULL) {
        ggml_backend_sched_free(m->sched);
    }
    const bool owns_separate_backend = (m->backend != NULL) && (m->backend != m->cpu_backend);
    if (owns_separate_backend) {
        ggml_backend_free(m->backend);
    }
    if (m->cpu_backend != NULL) {
        ggml_backend_free(m->cpu_backend);
    }
    wctx_free(&m->wctx);
    *m = {};
}