mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
mtmd: correct get_n_pos / get_decoder_pos (#22175)
This commit is contained in:
parent
cf8b0dbda9
commit
86f8daacfe
2 changed files with 63 additions and 29 deletions
|
|
@ -33,10 +33,16 @@ struct mtmd_bitmap {
|
|||
bool is_audio = false; // true if the bitmap is audio
|
||||
};
|
||||
|
||||
// position indexing for decoder model
|
||||
enum mtmd_pos_type {
|
||||
MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens
|
||||
MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
|
||||
mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL;
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
|
@ -45,7 +51,7 @@ struct mtmd_image_tokens {
|
|||
return mtmd_image_tokens{
|
||||
nx,
|
||||
ny,
|
||||
use_mrope_pos,
|
||||
pos,
|
||||
batch_f32.clone(),
|
||||
id
|
||||
};
|
||||
|
|
@ -131,7 +137,7 @@ struct mtmd_context {
|
|||
int n_threads;
|
||||
std::string media_marker;
|
||||
const int n_embd_text;
|
||||
llama_rope_type decoder_rope;
|
||||
mtmd_pos_type pos_type;
|
||||
|
||||
// these are not token, but strings used to mark the beginning and end of image/audio embeddings
|
||||
std::string img_beg;
|
||||
|
|
@ -168,8 +174,7 @@ struct mtmd_context {
|
|||
print_timings(ctx_params.print_timings),
|
||||
n_threads (ctx_params.n_threads),
|
||||
media_marker (ctx_params.media_marker),
|
||||
n_embd_text (llama_model_n_embd_inp(text_model)),
|
||||
decoder_rope (llama_model_rope_type(text_model))
|
||||
n_embd_text (llama_model_n_embd_inp(text_model))
|
||||
{
|
||||
if (ctx_params.image_marker != nullptr) {
|
||||
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
|
||||
|
|
@ -179,6 +184,22 @@ struct mtmd_context {
|
|||
throw std::runtime_error("media_marker must not be empty");
|
||||
}
|
||||
|
||||
auto decoder_rope_type = llama_model_rope_type(text_model);
|
||||
switch (decoder_rope_type) {
|
||||
case LLAMA_ROPE_TYPE_NORM:
|
||||
case LLAMA_ROPE_TYPE_NEOX:
|
||||
{
|
||||
pos_type = MTMD_POS_TYPE_NORMAL;
|
||||
} break;
|
||||
case LLAMA_ROPE_TYPE_MROPE:
|
||||
case LLAMA_ROPE_TYPE_IMROPE:
|
||||
{
|
||||
pos_type = MTMD_POS_TYPE_MROPE;
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type));
|
||||
}
|
||||
|
||||
clip_context_params ctx_clip_params {
|
||||
/* use_gpu */ ctx_params.use_gpu,
|
||||
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
|
||||
|
|
@ -779,12 +800,12 @@ struct mtmd_tokenizer {
|
|||
// for Qwen2VL, we need this information for M-RoPE decoding positions
|
||||
image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
|
||||
image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
|
||||
image_tokens->use_mrope_pos = true;
|
||||
} else {
|
||||
// other models, we only need the total number of tokens
|
||||
image_tokens->nx = n_tokens;
|
||||
image_tokens->ny = 1;
|
||||
}
|
||||
image_tokens->pos = ctx->pos_type;
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
image_tokens->id = bitmap->id; // optional
|
||||
|
||||
|
|
@ -1016,7 +1037,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
|
|||
return ctx->image_embd_v.data();
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
auto proj_type = ctx->proj_type_v();
|
||||
if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
proj_type = ctx->proj_type_a();
|
||||
|
|
@ -1030,20 +1051,19 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chu
|
|||
}
|
||||
}
|
||||
|
||||
bool mtmd_decode_use_mrope(mtmd_context * ctx) {
|
||||
return ctx->decoder_rope == LLAMA_ROPE_TYPE_MROPE
|
||||
|| ctx->decoder_rope == LLAMA_ROPE_TYPE_IMROPE;
|
||||
bool mtmd_decode_use_mrope(const mtmd_context * ctx) {
|
||||
return ctx->pos_type;
|
||||
}
|
||||
|
||||
bool mtmd_support_vision(mtmd_context * ctx) {
|
||||
bool mtmd_support_vision(const mtmd_context * ctx) {
|
||||
return ctx->ctx_v != nullptr;
|
||||
}
|
||||
|
||||
bool mtmd_support_audio(mtmd_context * ctx) {
|
||||
bool mtmd_support_audio(const mtmd_context * ctx) {
|
||||
return ctx->ctx_a != nullptr;
|
||||
}
|
||||
|
||||
int mtmd_get_audio_sample_rate(mtmd_context * ctx) {
|
||||
int mtmd_get_audio_sample_rate(const mtmd_context * ctx) {
|
||||
if (!ctx->ctx_a) {
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -1238,12 +1258,24 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
|
|||
|
||||
mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) {
|
||||
mtmd_decoder_pos pos;
|
||||
// M-RoPE logic
|
||||
// TODO: support other types of position encoding if needed
|
||||
pos.t = pos_0;
|
||||
pos.x = pos_0 + (i % image_tokens->nx);
|
||||
pos.y = pos_0 + (i / image_tokens->nx);
|
||||
pos.z = 0; // unused for now
|
||||
switch (image_tokens->pos) {
|
||||
case MTMD_POS_TYPE_MROPE:
|
||||
{
|
||||
pos.t = pos_0;
|
||||
pos.x = pos_0 + (i % image_tokens->nx);
|
||||
pos.y = pos_0 + (i / image_tokens->nx);
|
||||
pos.z = 0; // unused for now
|
||||
} break;
|
||||
case MTMD_POS_TYPE_NORMAL:
|
||||
{
|
||||
pos.t = pos_0 + i;
|
||||
pos.x = pos_0 + i;
|
||||
pos.y = pos_0 + i;
|
||||
pos.z = pos_0 + i;
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("invalid position type");
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
|
|
@ -1252,12 +1284,14 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
|
|||
}
|
||||
|
||||
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
|
||||
if (image_tokens->use_mrope_pos) {
|
||||
// for M-RoPE, temporal dimension = max(t,h,w)
|
||||
// t is omitted as we don't support video input
|
||||
return std::max(image_tokens->nx, image_tokens->ny);
|
||||
switch (image_tokens->pos) {
|
||||
case MTMD_POS_TYPE_MROPE:
|
||||
return std::max(image_tokens->nx, image_tokens->ny);
|
||||
case MTMD_POS_TYPE_NORMAL:
|
||||
return image_tokens->n_tokens();
|
||||
default:
|
||||
GGML_ABORT("invalid position type");
|
||||
}
|
||||
return image_tokens->n_tokens();
|
||||
}
|
||||
|
||||
// test function
|
||||
|
|
|
|||
|
|
@ -112,20 +112,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
|
|||
|
||||
// whether we need to set non-causal mask before llama_decode
|
||||
// if chunk is nullptr, we assume the default case where chunk is an image chunk
|
||||
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk);
|
||||
MTMD_API bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk);
|
||||
|
||||
// whether the current model use M-RoPE for llama_decode
|
||||
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||
MTMD_API bool mtmd_decode_use_mrope(const mtmd_context * ctx);
|
||||
|
||||
// whether the current model supports vision input
|
||||
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
|
||||
MTMD_API bool mtmd_support_vision(const mtmd_context * ctx);
|
||||
|
||||
// whether the current model supports audio input
|
||||
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
|
||||
MTMD_API bool mtmd_support_audio(const mtmd_context * ctx);
|
||||
|
||||
// get audio sample rate in Hz, for example 16000 for Whisper
|
||||
// return -1 if audio is not supported
|
||||
MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx);
|
||||
MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx);
|
||||
|
||||
// mtmd_bitmap
|
||||
//
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue