koboldcpp/otherarch/acestep/cond.h

376 lines
No EOL
15 KiB
C++

// cond.h: ACEStep Condition Encoder via ggml
//
// Produces encoder_hidden_states [S_total, 2048] from (all arrays H-contiguous per token):
// - text_hidden [S_text, 1024] : from Qwen3-Embedding text encoder
// - lyric_embed [S_lyric, 1024] : from CPU vocab lookup of lyric tokens
// - timbre_feats [S_ref, 64] : reference audio features (optional)
//
// Internal pipeline (ggml notation [ne0, ne1]):
// text_hidden -> Linear(1024->2048) -> [2048, S_text]
// lyric_embed -> Linear(1024->2048)+bias -> 8L -> norm -> [2048, S_lyric]
// timbre_feats -> Linear(64->2048)+bias -> 4L -> norm -> take frame[0] -> [2048, 1]
// Pack: cat(lyric, timbre[0:1], text_proj) -> [2048, S_total]
#pragma once
#include "qwen3.h"
// Lyric/Timbre encoder configs
static Qwen3Config qwen3_lyric_config() {
return {
/*hidden_size*/ 2048,
/*intermediate_size*/ 6144,
/*n_heads*/ 16,
/*n_kv_heads*/ 8,
/*head_dim*/ 128,
/*n_layers*/ 8,
/*rope_theta*/ 1000000.0f,
/*rms_norm_eps*/ 1e-6f,
/*is_causal*/ false,
};
}
static Qwen3Config qwen3_timbre_config() {
return {
/*hidden_size*/ 2048,
/*intermediate_size*/ 6144,
/*n_heads*/ 16,
/*n_kv_heads*/ 8,
/*head_dim*/ 128,
/*n_layers*/ 4,
/*rope_theta*/ 1000000.0f,
/*rms_norm_eps*/ 1e-6f,
/*is_causal*/ false,
};
}
// Struct
struct CondGGML {
// Lyric encoder (8L, H=2048)
Qwen3Config lyric_cfg;
Qwen3Layer lyric_layers[8];
struct ggml_tensor * lyric_embed_w; // [2048, 1024] ggml = Linear(1024->2048)
struct ggml_tensor * lyric_embed_b; // [2048]
struct ggml_tensor * lyric_norm; // [2048]
// Timbre encoder (4L, H=2048)
Qwen3Config timbre_cfg;
Qwen3Layer timbre_layers[4];
struct ggml_tensor * timbre_embed_w; // [2048, 64] ggml = Linear(64->2048)
struct ggml_tensor * timbre_embed_b; // [2048]
struct ggml_tensor * timbre_norm; // [2048]
// XL models prepend a learned CLS token to the timbre sequence.
// The CLS output (position 0) aggregates timbre info across all frames.
// 2B models skip this (the token exists in the GGUF but is unused).
struct ggml_tensor * timbre_cls; // [2048, 1, 1] or NULL
bool use_timbre_cls;
// Text projector: Linear(1024->2048) no bias
struct ggml_tensor * text_proj_w; // [2048, 1024] ggml
// Null condition embedding (for classifier-free guidance)
struct ggml_tensor * null_cond_emb; // [2048, 1, 1]
// Backend
ggml_backend_t backend;
ggml_backend_t cpu_backend;
ggml_backend_sched_t sched;
WeightCtx wctx;
};
// Init
static void cond_ggml_init_backend(CondGGML * m) {
BackendPair bp = backend_init("CondEncoder");
m->backend = bp.backend;
m->cpu_backend = bp.cpu_backend;
m->sched = backend_sched_new(bp, 8192);
}
// Load from ACEStep DiT GGUF
// gguf_path: path to the DiT .gguf file
// Tensors have prefix "encoder." for lyric/timbre, and "null_condition_emb"
static bool cond_ggml_load(CondGGML * m, const char * gguf_path) {
m->lyric_cfg = qwen3_lyric_config();
m->timbre_cfg = qwen3_timbre_config();
GGUFModel gf;
if (!gf_load(&gf, gguf_path)) {
fprintf(stderr, "[Load] FATAL: cannot load %s\n", gguf_path);
return false;
}
// XL models have encoder_hidden_size in GGUF metadata (2B models omit it).
// When present, the timbre encoder prepends a learned CLS token.
m->use_timbre_cls = (gf_get_u32(gf, "acestep.encoder_hidden_size") > 0);
// Count tensors:
// lyric: embed_w(1) + embed_b(1) + 8 layers x 11(88) + norm(1) = 91
// timbre: embed_w(1) + embed_b(1) + 4 layers x 11(44) + norm(1) + cls(0 or 1)
// text_proj(1) + null_cond(1) = 2
int n_tensors = 91 + 47 + 2 + (m->use_timbre_cls ? 1 : 0);
wctx_init(&m->wctx, n_tensors);
// Lyric encoder
m->lyric_embed_w = gf_load_tensor(&m->wctx, gf, "encoder.lyric_encoder.embed_tokens.weight");
m->lyric_embed_b = gf_load_tensor_f32(&m->wctx, gf, "encoder.lyric_encoder.embed_tokens.bias");
m->lyric_norm = gf_load_tensor_f32(&m->wctx, gf, "encoder.lyric_encoder.norm.weight");
fprintf(stderr, "[Load] LyricEncoder: %dL\n", m->lyric_cfg.n_layers);
for (int i = 0; i < m->lyric_cfg.n_layers; i++) {
char prefix[128];
snprintf(prefix, sizeof(prefix), "encoder.lyric_encoder.layers.%d", i);
qwen3_load_layer(&m->wctx, gf, &m->lyric_layers[i], prefix, i);
}
// Timbre encoder
m->timbre_embed_w = gf_load_tensor(&m->wctx, gf, "encoder.timbre_encoder.embed_tokens.weight");
m->timbre_embed_b = gf_load_tensor_f32(&m->wctx, gf, "encoder.timbre_encoder.embed_tokens.bias");
m->timbre_norm = gf_load_tensor_f32(&m->wctx, gf, "encoder.timbre_encoder.norm.weight");
fprintf(stderr, "[Load] TimbreEncoder: %dL\n", m->timbre_cfg.n_layers);
for (int i = 0; i < m->timbre_cfg.n_layers; i++) {
char prefix[128];
snprintf(prefix, sizeof(prefix), "encoder.timbre_encoder.layers.%d", i);
qwen3_load_layer(&m->wctx, gf, &m->timbre_layers[i], prefix, i);
}
// Timbre CLS token (XL only)
m->timbre_cls = NULL;
if (m->use_timbre_cls) {
m->timbre_cls = gf_load_tensor_f32(&m->wctx, gf, "encoder.timbre_encoder.special_token");
}
// Text projector + null condition
m->text_proj_w = gf_load_tensor(&m->wctx, gf, "encoder.text_projector.weight");
m->null_cond_emb = gf_load_tensor(&m->wctx, gf, "null_condition_emb");
if (!wctx_alloc(&m->wctx, m->backend)) {
gf_close(&gf);
return false;
}
gf_close(&gf);
fprintf(stderr, "[Load] CondEncoder: lyric(%dL), timbre(%dL%s), text_proj, null_cond\n", m->lyric_cfg.n_layers,
m->timbre_cfg.n_layers, m->use_timbre_cls ? ", CLS" : "");
return true;
}
// Forward
//
// Inputs (CPU float arrays):
// text_hidden: [1024 * S_text] from text encoder (Qwen3-Embedding)
// lyric_embed: [1024 * S_lyric] from CPU vocab lookup of lyric tokens
// timbre_feats: [64 * S_ref] reference audio features (NULL if none)
// S_text, S_lyric, S_ref sequence lengths
//
// Output:
// enc_hidden: resized to [2048 * S_total] float
// *out_enc_S: total sequence length
//
// Layout: all arrays in ggml order (ne[0]=dim contiguous, then sequence)
static void cond_ggml_forward(CondGGML * m,
const float * text_hidden, int S_text,
const float * lyric_embed, int S_lyric,
const float * timbre_feats, int S_ref,
std::vector<float> & enc_hidden,
int * out_enc_S) {
int H = 2048;
bool has_timbre = (timbre_feats != NULL && S_ref > 0);
int S_timbre = has_timbre ? S_ref + (m->use_timbre_cls ? 1 : 0) : 0;
// Graph context (generous fixed allocation)
size_t ctx_size = 4096 * ggml_tensor_overhead() + ggml_graph_overhead();
struct ggml_init_params gp = { ctx_size, NULL, true };
struct ggml_context * ctx = ggml_init(gp);
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 8192, false);
// Positions for lyric (bidirectional, 0..S_lyric-1)
struct ggml_tensor * lyric_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S_lyric);
ggml_set_name(lyric_pos, "lyric_pos");
ggml_set_input(lyric_pos);
// Path 1: Lyric encoder
struct ggml_tensor * t_lyric_in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, S_lyric);
ggml_set_name(t_lyric_in, "lyric_in");
ggml_set_input(t_lyric_in);
// Linear embed: [1024, S_lyric] -> [2048, S_lyric]
struct ggml_tensor * lyric_h = qwen3_linear_bias(ctx, m->lyric_embed_w,
m->lyric_embed_b, t_lyric_in);
// Bidirectional sliding window mask for even layers (|i-j| <= 128)
// Python: layer_types = [sliding, full, sliding, full, ...]
// sliding_window = 128
struct ggml_tensor * lyric_slide_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S_lyric, S_lyric);
ggml_set_name(lyric_slide_mask, "lyric_slide_mask");
ggml_set_input(lyric_slide_mask);
// 8 layers with alternating masks + final norm
for (int i = 0; i < m->lyric_cfg.n_layers; i++) {
struct ggml_tensor * layer_mask = (i % 2 == 0) ? lyric_slide_mask : NULL;
lyric_h = qwen3_build_layer(ctx, m->lyric_cfg, &m->lyric_layers[i],
lyric_h, lyric_pos, layer_mask, S_lyric);
}
lyric_h = qwen3_rms_norm(ctx, lyric_h, m->lyric_norm, m->lyric_cfg.rms_norm_eps);
ggml_set_name(lyric_h, "lyric_out");
ggml_set_output(lyric_h);
// Path 2: Text projection
struct ggml_tensor * t_text_in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, S_text);
ggml_set_name(t_text_in, "text_in");
ggml_set_input(t_text_in);
// Linear: [1024, S_text] -> [2048, S_text]
struct ggml_tensor * text_proj = qwen3_linear(ctx, m->text_proj_w, t_text_in);
ggml_set_name(text_proj, "text_proj_out");
ggml_set_output(text_proj);
// Path 3: Timbre encoder (optional)
struct ggml_tensor * timbre_out = NULL;
struct ggml_tensor * t_timbre_in = NULL;
struct ggml_tensor * timbre_pos = NULL;
struct ggml_tensor * timbre_slide_mask = NULL;
if (has_timbre) {
timbre_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, S_timbre);
ggml_set_name(timbre_pos, "timbre_pos");
ggml_set_input(timbre_pos);
t_timbre_in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, S_ref);
ggml_set_name(t_timbre_in, "timbre_in");
ggml_set_input(t_timbre_in);
// Linear embed: [64, S_ref] -> [H, S_ref]
struct ggml_tensor * timbre_h = qwen3_linear_bias(ctx, m->timbre_embed_w,
m->timbre_embed_b, t_timbre_in);
// XL: prepend learned CLS token -> [H, S_ref+1]
// The CLS output at position 0 aggregates timbre across all frames.
if (m->use_timbre_cls) {
struct ggml_tensor * cls = ggml_reshape_2d(ctx, m->timbre_cls, H, 1);
timbre_h = ggml_concat(ctx, cls, timbre_h, 1);
}
// Bidirectional sliding window mask for even layers
timbre_slide_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, S_timbre, S_timbre);
ggml_set_name(timbre_slide_mask, "timbre_slide_mask");
ggml_set_input(timbre_slide_mask);
// 4 layers with alternating masks + final norm
for (int i = 0; i < m->timbre_cfg.n_layers; i++) {
struct ggml_tensor * layer_mask = (i % 2 == 0) ? timbre_slide_mask : NULL;
timbre_h = qwen3_build_layer(ctx, m->timbre_cfg, &m->timbre_layers[i],
timbre_h, timbre_pos, layer_mask, S_timbre);
}
timbre_h = qwen3_rms_norm(ctx, timbre_h, m->timbre_norm, m->timbre_cfg.rms_norm_eps);
// Take first position: [H, S_timbre] -> view [H, 1]
// 2B: first audio frame. XL: CLS token (aggregated timbre).
timbre_out = ggml_view_2d(ctx, timbre_h, H, 1,
timbre_h->nb[1], 0);
ggml_set_name(timbre_out, "timbre_out");
ggml_set_output(timbre_out);
}
// Build forward
ggml_build_forward_expand(gf, lyric_h);
ggml_build_forward_expand(gf, text_proj);
if (timbre_out) ggml_build_forward_expand(gf, timbre_out);
// Allocate and set inputs
ggml_backend_sched_alloc_graph(m->sched, gf);
ggml_backend_tensor_set(t_lyric_in, lyric_embed, 0, 1024 * S_lyric * sizeof(float));
ggml_backend_tensor_set(t_text_in, text_hidden, 0, 1024 * S_text * sizeof(float));
// Lyric positions
{
std::vector<int> pos(S_lyric);
for (int i = 0; i < S_lyric; i++) pos[i] = i;
ggml_backend_tensor_set(lyric_pos, pos.data(), 0, S_lyric * sizeof(int));
}
// Lyric sliding window mask: bidirectional, |i-j| <= 128
{
const int W = 128;
std::vector<uint16_t> mask_data(S_lyric * S_lyric);
for (int i = 0; i < S_lyric; i++) {
for (int j = 0; j < S_lyric; j++) {
int d = i - j; if (d < 0) d = -d;
mask_data[i * S_lyric + j] = ggml_fp32_to_fp16(d <= W ? 0.0f : -INFINITY);
}
}
ggml_backend_tensor_set(lyric_slide_mask, mask_data.data(), 0,
S_lyric * S_lyric * sizeof(uint16_t));
fprintf(stderr, "[CondEnc] Lyric sliding mask: %dx%d, window=%d\n", S_lyric, S_lyric, W);
}
if (has_timbre) {
ggml_backend_tensor_set(t_timbre_in, timbre_feats, 0, 64 * S_ref * sizeof(float));
std::vector<int> pos(S_timbre);
for (int i = 0; i < S_timbre; i++) {
pos[i] = i;
}
ggml_backend_tensor_set(timbre_pos, pos.data(), 0, S_timbre * sizeof(int));
// Timbre sliding window mask: bidirectional, |i-j| <= 128
const int W = 128;
std::vector<uint16_t> mask_data(S_timbre * S_timbre);
for (int i = 0; i < S_timbre; i++) {
for (int j = 0; j < S_timbre; j++) {
int d = i - j;
if (d < 0) {
d = -d;
}
mask_data[i * S_timbre + j] = ggml_fp32_to_fp16(d <= W ? 0.0f : -INFINITY);
}
}
ggml_backend_tensor_set(timbre_slide_mask, mask_data.data(), 0, S_timbre * S_timbre * sizeof(uint16_t));
fprintf(stderr, "[CondEnc] Timbre sliding mask: %dx%d, window=%d%s\n", S_timbre, S_timbre, W,
m->use_timbre_cls ? " (CLS)" : "");
}
// Compute
ggml_backend_sched_graph_compute(m->sched, gf);
// Read outputs and pack on CPU
// Pack order: lyric, timbre[0:1], text_proj
int S_timbre_out = has_timbre ? 1 : 0;
int S_total = S_lyric + S_timbre_out + S_text;
enc_hidden.resize(H * S_total);
*out_enc_S = S_total;
int offset = 0;
// Lyric: [2048, S_lyric]
ggml_backend_tensor_get(lyric_h, enc_hidden.data() + offset * H,
0, H * S_lyric * sizeof(float));
offset += S_lyric;
// Timbre: [2048, 1]
if (timbre_out) {
ggml_backend_tensor_get(timbre_out, enc_hidden.data() + offset * H,
0, H * 1 * sizeof(float));
offset += 1;
}
// Text projection: [2048, S_text]
ggml_backend_tensor_get(text_proj, enc_hidden.data() + offset * H,
0, H * S_text * sizeof(float));
offset += S_text;
fprintf(stderr, "[Encode] Packed: lyric=%d + timbre=%d + text=%d = %d tokens\n",
S_lyric, S_timbre_out, S_text, S_total);
ggml_backend_sched_reset(m->sched);
ggml_free(ctx);
}
// Free
static void cond_ggml_free(CondGGML * m) {
if (m->sched) ggml_backend_sched_free(m->sched);
if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
wctx_free(&m->wctx);
*m = {};
}