mtmd: correct get_n_pos / get_decoder_pos (#22175)

Xuan-Son Nguyen 2026-04-20 23:29:19 +02:00 committed by GitHub
parent cf8b0dbda9
commit 86f8daacfe
2 changed files with 63 additions and 29 deletions
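
For orientation (not part of the commit): below is a minimal standalone sketch that mirrors the position arithmetic introduced in mtmd_image_tokens_get_decoder_pos / mtmd_image_tokens_get_n_pos further down. The local names (pos_type, decoder_pos, get_decoder_pos, get_n_pos) and the 4x3 token grid starting at position 10 are illustrative assumptions only.

// Standalone sketch of the two position-indexing modes (assumed example values).
#include <algorithm>
#include <cstdint>
#include <cstdio>

enum pos_type { POS_NORMAL, POS_MROPE };     // mirrors mtmd_pos_type
struct decoder_pos { int64_t t, x, y, z; };  // mirrors mtmd_decoder_pos

// position of the i-th image token when the image starts at pos_0
static decoder_pos get_decoder_pos(pos_type pt, uint32_t nx, int64_t pos_0, size_t i) {
    decoder_pos p{};
    if (pt == POS_MROPE) {
        p.t = pos_0;                          // whole image shares one temporal position
        p.x = pos_0 + (int64_t)(i % nx);      // column in the token grid
        p.y = pos_0 + (int64_t)(i / nx);      // row in the token grid
        p.z = 0;                              // unused
    } else {
        p.t = p.x = p.y = p.z = pos_0 + (int64_t) i;  // plain causal counting
    }
    return p;
}

// number of decoder positions the whole image occupies
static int64_t get_n_pos(pos_type pt, uint32_t nx, uint32_t ny) {
    return pt == POS_MROPE ? (int64_t) std::max(nx, ny) : (int64_t) nx * ny;
}

int main() {
    printf("M-RoPE n_pos = %lld\n", (long long) get_n_pos(POS_MROPE,  4, 3)); // 4
    printf("normal n_pos = %lld\n", (long long) get_n_pos(POS_NORMAL, 4, 3)); // 12
    decoder_pos p = get_decoder_pos(POS_MROPE, 4, 10, 6);                     // 7th token
    printf("t=%lld x=%lld y=%lld\n",
           (long long) p.t, (long long) p.x, (long long) p.y);                // t=10 x=12 y=11
    return 0;
}

Under M-RoPE the image consumes nx*ny tokens but only max(nx, ny) positions, which is exactly the distinction the two functions encode.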

mtmd.cpp

@@ -33,10 +33,16 @@ struct mtmd_bitmap {
bool is_audio = false; // true if the bitmap is audio
};
// position indexing for decoder model
enum mtmd_pos_type {
MTMD_POS_TYPE_NORMAL, // number of positions equals the number of tokens
MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes
};
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL;
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
@@ -45,7 +51,7 @@ struct mtmd_image_tokens {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
pos,
batch_f32.clone(),
id
};
@@ -131,7 +137,7 @@ struct mtmd_context {
int n_threads;
std::string media_marker;
const int n_embd_text;
llama_rope_type decoder_rope;
mtmd_pos_type pos_type;
// these are not tokens, but strings used to mark the beginning and end of image/audio embeddings
std::string img_beg;
@@ -168,8 +174,7 @@ struct mtmd_context {
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
media_marker (ctx_params.media_marker),
n_embd_text (llama_model_n_embd_inp(text_model)),
decoder_rope (llama_model_rope_type(text_model))
n_embd_text (llama_model_n_embd_inp(text_model))
{
if (ctx_params.image_marker != nullptr) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
@@ -179,6 +184,22 @@ struct mtmd_context {
throw std::runtime_error("media_marker must not be empty");
}
auto decoder_rope_type = llama_model_rope_type(text_model);
switch (decoder_rope_type) {
case LLAMA_ROPE_TYPE_NORM:
case LLAMA_ROPE_TYPE_NEOX:
{
pos_type = MTMD_POS_TYPE_NORMAL;
} break;
case LLAMA_ROPE_TYPE_MROPE:
case LLAMA_ROPE_TYPE_IMROPE:
{
pos_type = MTMD_POS_TYPE_MROPE;
} break;
default:
throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type));
}
clip_context_params ctx_clip_params {
/* use_gpu */ ctx_params.use_gpu,
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
@@ -779,12 +800,12 @@ struct mtmd_tokenizer {
// for Qwen2VL, we need this information for M-RoPE decoding positions
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
image_tokens->use_mrope_pos = true;
} else {
// for other models, we only need the total number of tokens
image_tokens->nx = n_tokens;
image_tokens->ny = 1;
}
image_tokens->pos = ctx->pos_type;
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmap->id; // optional
@@ -1016,7 +1037,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
return ctx->image_embd_v.data();
}
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk) {
auto proj_type = ctx->proj_type_v();
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
proj_type = ctx->proj_type_a();
@@ -1030,20 +1051,19 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chu
}
}
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
return ctx->decoder_rope == LLAMA_ROPE_TYPE_MROPE
|| ctx->decoder_rope == LLAMA_ROPE_TYPE_IMROPE;
bool mtmd_decode_use_mrope(const mtmd_context * ctx) {
return ctx->pos_type == MTMD_POS_TYPE_MROPE;
}
bool mtmd_support_vision(mtmd_context * ctx) {
bool mtmd_support_vision(const mtmd_context * ctx) {
return ctx->ctx_v != nullptr;
}
bool mtmd_support_audio(mtmd_context * ctx) {
bool mtmd_support_audio(const mtmd_context * ctx) {
return ctx->ctx_a != nullptr;
}
int mtmd_get_audio_sample_rate(mtmd_context * ctx) {
int mtmd_get_audio_sample_rate(const mtmd_context * ctx) {
if (!ctx->ctx_a) {
return -1;
}
@@ -1238,12 +1258,24 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) {
mtmd_decoder_pos pos;
// M-RoPE logic
// TODO: support other types of position encoding if needed
pos.t = pos_0;
pos.x = pos_0 + (i % image_tokens->nx);
pos.y = pos_0 + (i / image_tokens->nx);
pos.z = 0; // unused for now
switch (image_tokens->pos) {
case MTMD_POS_TYPE_MROPE:
{
pos.t = pos_0;
pos.x = pos_0 + (i % image_tokens->nx);
pos.y = pos_0 + (i / image_tokens->nx);
pos.z = 0; // unused for now
} break;
case MTMD_POS_TYPE_NORMAL:
{
pos.t = pos_0 + i;
pos.x = pos_0 + i;
pos.y = pos_0 + i;
pos.z = pos_0 + i;
} break;
default:
GGML_ABORT("invalid position type");
}
return pos;
}
@@ -1252,12 +1284,14 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
}
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
if (image_tokens->use_mrope_pos) {
// for M-RoPE, temporal dimension = max(t,h,w)
// t is omitted as we don't support video input
return std::max(image_tokens->nx, image_tokens->ny);
switch (image_tokens->pos) {
case MTMD_POS_TYPE_MROPE:
return std::max(image_tokens->nx, image_tokens->ny);
case MTMD_POS_TYPE_NORMAL:
return image_tokens->n_tokens();
default:
GGML_ABORT("invalid position type");
}
return image_tokens->n_tokens();
}
// test function
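
Not part of the diff above: a hedged sketch of how a caller might consume the corrected pair of functions when scheduling an image chunk. Only the mtmd_* calls and the mtmd_decoder_pos fields come from this file; the include path, the helper name, and the 'img' / 'n_tokens' inputs are assumptions.

#include <vector>
#include "mtmd.h"

// Sketch: collect per-token decoder positions for an image chunk and return
// the updated position counter for whatever follows the image.
static llama_pos collect_image_positions(const mtmd_image_tokens * img,
                                         size_t n_tokens,
                                         llama_pos n_past,
                                         std::vector<mtmd_decoder_pos> & out) {
    out.resize(n_tokens);
    for (size_t i = 0; i < n_tokens; i++) {
        // M-RoPE: t stays at n_past while x/y walk the token grid;
        // normal positions: t, x, y and z all equal n_past + i
        out[i] = mtmd_image_tokens_get_decoder_pos(img, n_past, i);
    }
    // the image advances the counter by max(nx, ny) under M-RoPE, by nx*ny otherwise
    return n_past + mtmd_image_tokens_get_n_pos(img);
}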

mtmd.h

@@ -112,20 +112,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
// whether we need to set non-causal mask before llama_decode
// if chunk is nullptr, we assume the default case where chunk is an image chunk
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
MTMD_API bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk);
// whether the current model use M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
MTMD_API bool mtmd_decode_use_mrope(const mtmd_context * ctx);
// whether the current model supports vision input
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
MTMD_API bool mtmd_support_vision(const mtmd_context * ctx);
// whether the current model supports audio input
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
MTMD_API bool mtmd_support_audio(const mtmd_context * ctx);
// get audio sample rate in Hz, for example 16000 for Whisper
// return -1 if audio is not supported
MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx);
MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx);
// mtmd_bitmap
//
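
A minimal sketch of how the now const-qualified query helpers read at a call site; the context is assumed to have been created elsewhere, and only the declarations above come from the diff.

#include <cstdio>
#include "mtmd.h"

// With the const-qualified declarations, capability queries can be made
// through a read-only context pointer.
static void print_capabilities(const mtmd_context * ctx) {
    printf("vision: %s\n", mtmd_support_vision(ctx) ? "yes" : "no");
    printf("M-RoPE positions: %s\n", mtmd_decode_use_mrope(ctx) ? "yes" : "no");
    if (mtmd_support_audio(ctx)) {
        printf("audio: yes, %d Hz\n", mtmd_get_audio_sample_rate(ctx));
    } else {
        printf("audio: no\n");
    }
}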