mtmd: add chunks and fix preproc for qwen3a (#23073)

* mtmd: add chunks and fix preproc for qwen3a

* add attn_mask

* limit mtmd_chunk size (avoid blowing up memory)

* correct audio tokens

* re-order the set_input case

* remove attn_mask
Xuan-Son Nguyen 2026-05-15 19:32:47 +02:00 committed by GitHub
parent 8be1786707
commit 72e60f500d
7 changed files with 200 additions and 71 deletions


@@ -11,6 +11,10 @@
 #define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)

+struct build_vit_opts {
+    ggml_tensor * attn_mask = nullptr;
+};
+
 struct clip_graph {
     const clip_model & model;
     const clip_hparams & hparams;
@@ -63,7 +67,8 @@ struct clip_graph {
             norm_type norm_t,
             ffn_op_type ffn_t,
             ggml_tensor * learned_pos_embd,
-            std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos);
+            std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
+            const build_vit_opts & opts = {});

     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]

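Because opts defaults to an empty aggregate, every existing build_vit call site compiles unchanged; only callers that need a mask have to spell it out. A minimal sketch of such a call (the mask tensor and surrounding variable names are illustrative, not from this commit):

    build_vit_opts opts;
    opts.attn_mask = attn_mask; // hypothetical F32 mask, e.g. [n_pos, n_pos] with -INFINITY at masked positions

    ggml_tensor * cur = build_vit(inp, n_pos,
            NORM_TYPE_NORMAL, hparams.ffn_op,
            learned_pos_embd, add_pos, opts);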

@@ -300,7 +300,8 @@ ggml_tensor * clip_graph::build_vit(
         norm_type norm_t,
         ffn_op_type ffn_t,
         ggml_tensor * learned_pos_embd,
-        std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
+        std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
+        const build_vit_opts & opts
         ) {
     if (learned_pos_embd) {
         inp = ggml_add(ctx0, inp, learned_pos_embd);
@@ -427,7 +428,7 @@ ggml_tensor * clip_graph::build_vit(
             }

             cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                    Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il);
             cb(cur, "attn_out", il);
         }
@@ -663,6 +664,9 @@ ggml_tensor * clip_graph::build_attn(
         k = ggml_cast(ctx0, k, GGML_TYPE_F16);
         v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        if (kq_mask) {
+            kq_mask = ggml_cast(ctx0, kq_mask, GGML_TYPE_F16);
+        }

         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f);
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -3244,12 +3248,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             } break;
         case PROJECTOR_TYPE_QWEN3A:
             {
-                // 3x stride-2 conv2d: each step is floor((n-1)/2)+1
-                int n = img->nx;
-                n = (n - 1) / 2 + 1;
-                n = (n - 1) / 2 + 1;
-                n = (n - 1) / 2 + 1;
-                n_patches = n;
+                // chunk_size=100 frames --> 3x stride-2 conv2d --> 13 tokens per chunk
+                const int chunk_size = 100;
+                const int tokens_per_chunk = 13;
+                n_patches = (img->nx / chunk_size) * tokens_per_chunk;
             } break;
         case PROJECTOR_TYPE_GLMA:
             {
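The constant 13 replaces the per-image arithmetic that the removed lines performed at runtime: each conv2d uses kernel 3, stride 2, padding 1, so the output width maps as n -> (n + 2*1 - 3)/2 + 1 = (n - 1)/2 + 1, and a 100-frame chunk shrinks 100 -> 50 -> 25 -> 13. A minimal standalone check of that equivalence (hypothetical helper, not part of the commit):

    #include <cassert>

    // one stride-2 conv2d with kernel 3, padding 1: out = (n - 1)/2 + 1
    static int conv_out_len(int n) {
        return (n - 1) / 2 + 1;
    }

    int main() {
        int n = 100;         // chunk_size frames
        n = conv_out_len(n); // 50
        n = conv_out_len(n); // 25
        n = conv_out_len(n); // 13
        assert(n == 13);     // matches tokens_per_chunk
        return 0;
    }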
@@ -4292,21 +4294,6 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
     return ctx->model.modality == CLIP_MODALITY_AUDIO;
 }

-bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
-    switch (ctx->proj_type()) {
-        case PROJECTOR_TYPE_ULTRAVOX:
-        case PROJECTOR_TYPE_QWEN2A:
-        case PROJECTOR_TYPE_QWEN3A:
-        case PROJECTOR_TYPE_GLMA:
-        case PROJECTOR_TYPE_VOXTRAL:
-        case PROJECTOR_TYPE_MERALION:
-        case PROJECTOR_TYPE_MUSIC_FLAMINGO:
-            return true;
-        default:
-            return false;
-    }
-}
-
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
     clip_image_f32 clip_img;
     clip_img.buf.resize(h * w * 3);


@@ -115,7 +115,6 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
 bool clip_has_vision_encoder(const struct clip_ctx * ctx);
 bool clip_has_audio_encoder(const struct clip_ctx * ctx);
-bool clip_has_whisper_encoder(const struct clip_ctx * ctx);

 struct clip_cap {
     bool has_vision;


@@ -1,68 +1,88 @@
 #include "models.h"

 ggml_cgraph * clip_graph_qwen3a::build() {
     // Ref implementation: https://github.com/QwenLM/Qwen3-ASR/blob/main/qwen_asr/core/transformers_backend/modeling_qwen3_asr.py
     // inp_raw: [n_frames, n_mel, 1] (nx=n_frames, ny=n_mel)
     ggml_tensor * inp = build_inp_raw(1);

-    // conv2d block
-    // TODO: do we need to split by chunks of n_window each like on transformers impl?
-    {
-        inp = ggml_conv_2d(ctx0, model.conv2d_1_w, inp, 2, 2, 1, 1, 1, 1);
-        inp = ggml_add(ctx0, inp, model.conv2d_1_b);
-        inp = ggml_gelu_erf(ctx0, inp);
-
-        inp = ggml_conv_2d(ctx0, model.conv2d_2_w, inp, 2, 2, 1, 1, 1, 1);
-        inp = ggml_add(ctx0, inp, model.conv2d_2_b);
-        inp = ggml_gelu_erf(ctx0, inp);
-
-        inp = ggml_conv_2d(ctx0, model.conv2d_3_w, inp, 2, 2, 1, 1, 1, 1);
-        inp = ggml_add(ctx0, inp, model.conv2d_3_b);
-        inp = ggml_gelu_erf(ctx0, inp);
-
-        // inp [n_pos, n_mels/8, channels, 1] (W, H, C, N)
-        const int64_t n_pos_after_conv = inp->ne[0];
-        const int64_t n_mel_after_conv = inp->ne[1]; // 128/8 = 16
-        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 3, 1));
-        inp = ggml_reshape_2d(ctx0, inp, n_pos_after_conv, n_mel_after_conv * inp->ne[3]); // [n_pos, 7680]
-        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [7680, n_pos]
-
-        // project to n_embd
-        inp = ggml_mul_mat(ctx0, model.conv_out_w, inp);
-        if (model.conv_out_b) {
-            inp = ggml_add(ctx0, inp, model.conv_out_b);
-        }
-        cb(inp, "after_conv_out", -1);
-    }
-
-    auto n_pos = inp->ne[1];
-
-    ggml_tensor * pos_embd_selected = ggml_view_2d(
-        ctx0, model.position_embeddings,
-        model.position_embeddings->ne[0], n_pos,
-        model.position_embeddings->nb[1], 0
-    );
-
-    ggml_tensor * cur = build_vit(
-        inp, n_pos,
-        NORM_TYPE_NORMAL,
-        hparams.ffn_op,
-        pos_embd_selected,
-        nullptr);
+    const int64_t n_frames   = inp->ne[0]; // total frames, padded to multiple of chunk_size
+    const int64_t n_mel      = inp->ne[1]; // 128
+    const int64_t chunk_size = 100;        // n_window * 2 (n_window=50 from model config)
+    const int64_t n_chunks   = n_frames / chunk_size;
+    GGML_ASSERT(n_frames % chunk_size == 0); // preprocessor should already pad the input
+    GGML_ASSERT(inp->type == GGML_TYPE_F32);
+
+    // View mel spectrogram as batched 100-frame chunks: [chunk_size, n_mel, 1, n_chunks]
+    inp = ggml_view_4d(ctx0, inp,
+            chunk_size, n_mel, 1, n_chunks,
+            n_frames   * (int64_t)sizeof(float), // nb[1]: stride over mel bins
+            chunk_size * (int64_t)sizeof(float), // nb[2]: stride for C=1 (unused)
+            chunk_size * (int64_t)sizeof(float), // nb[3]: stride over chunks
+            0);
+    inp = ggml_cont(ctx0, inp);
+    cb(inp, "inp_chunks", -1);
+
+    // 3 x conv2d + gelu
+    // conv output [OW, OH, C_out, n_chunks]
+    auto conv_block = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b) {
+        x = ggml_conv_2d(ctx0, w, x, 2, 2, 1, 1, 1, 1);
+        if (b) {
+            x = ggml_add(ctx0, x, ggml_reshape_4d(ctx0, b, 1, 1, x->ne[2], 1));
+        }
+        return ggml_gelu_erf(ctx0, x);
+    };
+    inp = conv_block(inp, model.conv2d_1_w, model.conv2d_1_b);
+    inp = conv_block(inp, model.conv2d_2_w, model.conv2d_2_b);
+    inp = conv_block(inp, model.conv2d_3_w, model.conv2d_3_b);
+    // inp: [OW=13, OH=16, OC=480, n_chunks]
+    cb(inp, "after_conv_blocks", -1);
+
+    // permute [OW=13, OH=16, OC=480, n_chunks] -> [OH=16, OC=480, OW=13, n_chunks]
+    // reshape to [OH*OC=7680, OW*n_chunks]
+    // feature index h+16*c = c*16+f (matches python code)
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 2, 0, 1, 3));
+    inp = ggml_reshape_2d(ctx0, inp, inp->ne[0] * inp->ne[1], inp->ne[2] * inp->ne[3]);
+
+    // Project to d_model: [d_model, 13*n_chunks]
+    inp = ggml_mul_mat(ctx0, model.conv_out_w, inp);
+    if (model.conv_out_b) {
+        inp = ggml_add(ctx0, inp, model.conv_out_b);
+    }
+    cb(inp, "after_conv_out", -1);
+
+    const int64_t n_pos = inp->ne[1]; // 13 * n_chunks
+
+    // Per-chunk positional embeddings: repeat pos[0:13] for each chunk
+    // (position indices reset 0..12 per chunk, not sequential across chunks)
+    {
+        const int64_t tokens_per_chunk = n_pos / n_chunks; // 13
+        ggml_tensor * pos_tmp = ggml_view_2d(ctx0, model.position_embeddings,
+                model.position_embeddings->ne[0], tokens_per_chunk,
+                model.position_embeddings->nb[1], 0);
+        ggml_tensor * tgt = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
+                model.position_embeddings->ne[0], n_pos);
+        inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, pos_tmp, tgt));
+    }
+
+    ggml_tensor * cur = build_vit(inp, n_pos,
+            NORM_TYPE_NORMAL, hparams.ffn_op,
+            nullptr, // pos embd already added above
+            nullptr);
     cb(cur, "after_transformer", -1);

-    // projector
+    // MLP projector
     cur = build_ffn(cur,
             model.mm_1_w, model.mm_1_b,
             nullptr, nullptr,
             model.mm_2_w, model.mm_2_b,
-            FFN_GELU_ERF,
-            -1);
+            FFN_GELU_ERF, -1);
     cb(cur, "projected", -1);

     ggml_build_forward_expand(gf, cur);
     return gf;
 }
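To summarize the new data flow, here is a shape trace for a hypothetical 10-second clip (1000 padded mel frames, 128 bins), written against ggml's [ne0, ne1, ne2, ne3] ordering; the numbers follow directly from the code above:

    // illustrative shape trace, not part of the commit:
    // build_inp_raw           [1000, 128, 1]      n_frames x n_mel
    // ggml_view_4d + cont     [100, 128, 1, 10]   10 chunks of 100 frames each
    // conv_block x 3          [13, 16, 480, 10]   each stride-2 conv roughly halves W and H
    // permute + reshape       [7680, 130]         OH*OC = 16*480 features per position
    // mul_mat conv_out_w      [d_model, 130]      13 tokens per chunk x 10 chunks
    // + repeated pos_embd[0:13], then build_vit runs over all 130 positions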


@@ -609,6 +609,110 @@ bool mtmd_audio_preprocessor_whisper::preprocess(const float * s
     return true;
 }

+//
+// mtmd_audio_preprocessor_qwen3a
+//
+// Matches the Python WhisperFeatureExtractor called with truncation=False:
+// - reflection padding of n_fft/2 samples at each end (center=True)
+// - Whisper-style normalization applied to the full audio: log10 mel values
+//   clamped to (max - 8), then mapped via (x + 4) / 4
+// - output split into windows of 800 mel frames (n_window_infer), each padded
+//   to a multiple of 100 frames (n_window * 2) for the cgraph batch view
+//
+void mtmd_audio_preprocessor_qwen3a::initialize() {
+    cache.fill_sin_cos_table(hparams.audio_n_fft);
+    cache.fill_hann_window(hparams.audio_window_len, true);
+    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
+}
+
+bool mtmd_audio_preprocessor_qwen3a::preprocess(const float * samples,
+                                                size_t n_samples,
+                                                std::vector<mtmd_audio_mel> & output) {
+    if (n_samples == 0) {
+        return false;
+    }
+
+    GGML_ASSERT(!cache.sin_vals.empty());
+    GGML_ASSERT(!cache.cos_vals.empty());
+    GGML_ASSERT(!cache.filters.data.empty());
+
+    // Reflection-pad n_fft/2 samples at each end, matching WhisperFeatureExtractor center=True
+    const int pad = hparams.audio_n_fft / 2; // = 200
+    std::vector<float> padded(n_samples + 2 * pad, 0.0f);
+    // Reflect start: padded[0..pad-1] = samples[pad..1] (reversed)
+    for (int i = 0; i < pad; i++) {
+        int src = pad - i; // samples[pad], samples[pad-1], ..., samples[1]
+        padded[i] = (src < (int)n_samples) ? samples[src] : 0.0f;
+    }
+    std::copy(samples, samples + n_samples, padded.begin() + pad);
+    // Reflect end: padded[n+pad..n+2*pad-1] = samples[n-2..n-pad-1] (reversed)
+    for (int i = 0; i < pad; i++) {
+        int src = (int)n_samples - 2 - i; // samples[n-2], samples[n-3], ...
+        padded[n_samples + pad + i] = (src >= 0) ? samples[src] : 0.0f;
+    }
+
+    filter_params params;
+    params.n_mel            = hparams.n_mel_bins;
+    params.n_fft_bins       = 1 + (hparams.audio_n_fft / 2);
+    params.hann_window_size = hparams.audio_window_len;
+    params.hop_length       = hparams.audio_hop_len;
+    params.sample_rate      = hparams.audio_sample_rate;
+    params.no_padding       = true;  // reflection padding already applied above
+    params.use_natural_log  = false; // log10
+
+    mtmd_audio_mel mel_full;
+    bool ok = log_mel_spectrogram(padded.data(), (int)padded.size(), 4, params, cache, mel_full);
+    if (!ok) {
+        return false;
+    }
+
+    // Whisper-style normalization: clamp to (max - 8), scale to [-1, 1]
+    {
+        double mmax = -1e20;
+        for (float v : mel_full.data) {
+            if (v > mmax) mmax = v;
+        }
+        mmax -= 8.0;
+        for (float & v : mel_full.data) {
+            v = (std::max((double)v, mmax) + 4.0) / 4.0;
+        }
+    }
+
+    // The effective frame count: center-padded STFT gives ~n_samples/hop_length frames.
+    // We take min(mel_full.n_len, n_samples/hop + 1) to avoid including excess frames.
+    const int n_eff = std::min(mel_full.n_len,
+                               (int)(n_samples / hparams.audio_hop_len) + 1);
+
+    // Split into inference windows matching n_window_infer=800 from model config.
+    // Each window is padded to the next multiple of chunk_size for the cgraph.
+    // The mtmd caller loops over output entries, so long audio is handled automatically.
+    const int chunk_size  = 100; // conv sub-chunk size (n_window * 2, n_window=50)
+    const int window_size = 800; // mel frames per forward pass (n_window_infer=800)
+    for (int off = 0; off < n_eff; off += window_size) {
+        const int win_eff  = std::min(window_size, n_eff - off);
+        const int n_chunks = (win_eff + chunk_size - 1) / chunk_size;
+        const int n_padded = n_chunks * chunk_size;
+
+        mtmd_audio_mel out;
+        out.n_mel     = mel_full.n_mel;
+        out.n_len     = n_padded;
+        out.n_len_org = win_eff;
+        out.data.assign(out.n_mel * out.n_len, 0.0f);
+
+        for (int m = 0; m < out.n_mel; m++) {
+            const int copy_len = std::min(win_eff, mel_full.n_len - off);
+            if (copy_len > 0) {
+                std::copy(mel_full.data.begin() + (size_t)m * mel_full.n_len + off,
+                          mel_full.data.begin() + (size_t)m * mel_full.n_len + off + copy_len,
+                          out.data.begin() + (size_t)m * out.n_len);
+            }
+        }
+        output.push_back(std::move(out));
+    }
+
+    return true;
+}
+
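The reflection indexing above mirrors the signal around its endpoints without duplicating the edge sample, the same behavior as numpy's pad with mode='reflect'. A toy sketch with pad=2 on a 5-sample signal (illustrative only, not part of the commit):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> s = {0, 1, 2, 3, 4};
        const int pad = 2;
        const int n = (int)s.size();
        std::vector<float> p(n + 2 * pad);
        for (int i = 0; i < pad; i++) {
            p[i] = s[pad - i];             // front: s[2], s[1]
            p[n + pad + i] = s[n - 2 - i]; // back:  s[3], s[2]
        }
        std::copy(s.begin(), s.end(), p.begin() + pad);
        for (float v : p) printf("%g ", v); // prints: 2 1 0 1 2 3 4 3 2
        return 0;
    }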
 //
 // mtmd_audio_preprocessor_conformer
 //


@@ -96,6 +96,15 @@ struct mtmd_audio_preprocessor_gemma4a : mtmd_audio_preprocessor {
     mtmd_audio_cache cache;
 };

+struct mtmd_audio_preprocessor_qwen3a : mtmd_audio_preprocessor {
+    mtmd_audio_preprocessor_qwen3a(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
+
+    void initialize() override;
+    bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
+
+private:
+    mtmd_audio_cache cache;
+};

 //
 // streaming ISTFT - converts spectrogram frames back to audio one frame at a time
 //


@@ -515,7 +515,6 @@ struct mtmd_context {
     // set preprocessor
     switch (proj) {
         case PROJECTOR_TYPE_QWEN2A:
-        case PROJECTOR_TYPE_QWEN3A:
         case PROJECTOR_TYPE_QWEN25O:
             {
                 // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
@@ -523,6 +522,12 @@ struct mtmd_context {
                 aud_end = "<|audio_eos|>";
                 audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
             } break;
+        case PROJECTOR_TYPE_QWEN3A:
+            {
+                aud_beg = "<|audio_start|>";
+                aud_end = "<|audio_end|>";
+                audio_preproc = std::make_unique<mtmd_audio_preprocessor_qwen3a>(ctx_a);
+            } break;
         case PROJECTOR_TYPE_VOXTRAL:
             {
                 // [BEGIN_AUDIO] ... (embeddings) ...
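For qwen3a, the serialized prompt around an audio clip therefore changes as follows; assuming the markers are passed through verbatim by the chat template, an audio segment renders as (illustrative):

    // before: <|audio_bos|>(audio embeddings)<|audio_eos|>
    // after:  <|audio_start|>(audio embeddings)<|audio_end|>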