diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c
index 02e762e6a..7a0ce593c 100644
--- a/tests/test-mtmd-c-api.c
+++ b/tests/test-mtmd-c-api.c
@@ -41,8 +41,10 @@ int main(void) {
         } else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
             const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
             size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-            size_t nx = mtmd_image_tokens_get_nx(image_tokens);
-            size_t ny = mtmd_image_tokens_get_ny(image_tokens);
+            // get position of the last token, which should be (nx - 1, ny - 1)
+            struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, n_tokens - 1);
+            size_t nx = pos.x + 1;
+            size_t ny = pos.y + 1;
             const char * id = mtmd_image_tokens_get_id(image_tokens);
             assert(n_tokens > 0);
             assert(nx > 0);
diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp
index 2f45dab44..145b88cea 100644
--- a/tools/mtmd/mtmd-helper.cpp
+++ b/tools/mtmd/mtmd-helper.cpp
@@ -114,6 +114,13 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
     return n_pos;
 }
 
+void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) {
+    size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks);
+    for (size_t i = 0; i < n_tokens; i++) {
+        out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i);
+    }
+}
+
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
 struct decode_embd_batch {
@@ -156,18 +163,15 @@ struct decode_embd_batch {
     }
 
     // M-RoPE for image
-    void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
+    void set_position_mrope_2d(llama_pos pos_0, const std::vector<mtmd_decoder_pos> & rel_pos, llama_seq_id seq_id) {
         GGML_ASSERT(n_pos_per_embd == 4);
-        GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
+        GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens);
         seq_id_0[0] = seq_id;
-        for (int y = 0; y < ny; y++) {
-            for (int x = 0; x < nx; x++) {
-                int i = y * nx + x;
-                pos[i ] = pos_0;
-                pos[i + batch.n_tokens ] = pos_0 + y;
-                pos[i + batch.n_tokens * 2] = pos_0 + x;
-                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
-            }
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            pos[i ] = pos_0 + rel_pos[i].t;
+            pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y;
+            pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x;
+            pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
         }
         for (int i = 0; i < batch.n_tokens; i++) {
             batch.n_seq_id[i] = 1;
@@ -262,9 +266,10 @@ int32_t mtmd_helper_decode_image_chunk(
             LOG_ERR("failed to decode chunk: image tokens are null\n");
             return -1;
         }
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-        batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id);
+        const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+        std::vector<mtmd_decoder_pos> rel_pos(n_tokens);
+        mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data());
+        batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id);
     } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
         batch_embd.set_position_mrope_1d(n_past, seq_id);
     } else {
diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h
index 5036b9244..8cadf42b4 100644
--- a/tools/mtmd/mtmd-helper.h
+++ b/tools/mtmd/mtmd-helper.h
@@ -47,6 +47,10 @@ MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
 // normally, n_pos is equal to n_tokens, but for M-RoPE it is different
 MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
 
+// helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE
+// out_pos must have length == mtmd_helper_get_n_tokens(image)
+MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, mtmd_decoder_pos * out_pos);
+
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
 // 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index dc2bde194..a56d3b35b 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -1249,6 +1249,14 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
     return image_tokens->ny;
 }
 
+mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i) {
+    mtmd_decoder_pos pos;
+    pos.t = 0;
+    pos.x = i % image_tokens->nx;
+    pos.y = i / image_tokens->nx;
+    return pos;
+}
+
 const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
     return image_tokens->id.c_str();
 }
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index 2ecf95694..c91bc0810 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -186,12 +186,25 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
 // the instance will be constructed via mtmd_tokenize()
 // it will be freed along with mtmd_input_chunk
 MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
-MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
-MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
 MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
 // number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
 MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
 
+DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens),
+    "use mtmd_image_tokens_get_decoder_pos() instead");
+DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens),
+    "use mtmd_image_tokens_get_decoder_pos() instead");
+
+struct mtmd_decoder_pos {
+    uint32_t t;
+    uint32_t x;
+    uint32_t y;
+};
+// get position for decoder attention, to be used by M-RoPE models
+// i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1
+// return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position)
+MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i);
+
 // tokenize an input text prompt and a list of bitmaps (images/audio)
 // the prompt must have the input image marker (default: "<__media__>") in it
 // the default marker is defined by mtmd_default_marker()