mtmd: fit_params now take into account mmproj (#21489)

* mtmd: fit_params now take into account mmproj

* rename alloc_compute_meta to reserve_compute_meta

* rm unused functions

* add ggml_backend_dev_t support

* add debug log
This commit is contained in:
Xuan-Son Nguyen 2026-05-20 11:27:44 +02:00 committed by GitHub
parent 7e50ef7d79
commit e2b129e1bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 183 additions and 80 deletions

View file

@ -162,8 +162,14 @@ struct clip_ctx {
bool debug_output_embeddings = false;
// for measuring memory usage
bool no_alloc = false;
std::map<ggml_backend_dev_t, size_t> mem_usage;
std::map<ggml_backend_dev_t, size_t> mem_compute;
clip_ctx(clip_context_params & ctx_params) {
flash_attn_type = ctx_params.flash_attn_type;
no_alloc = ctx_params.no_alloc;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
throw std::runtime_error("failed to initialize CPU backend");
@ -1688,6 +1694,8 @@ struct clip_model_loader {
ggml_set_name(data_tensor, cur->name);
loaded_tensor_names.insert(name);
cur = data_tensor;
// add to weight memory counter
ctx_clip.mem_usage[ggml_backend_get_device(ctx_clip.backend)] += ggml_nbytes(cur);
}
return cur;
};
@ -2602,7 +2610,7 @@ struct clip_model_loader {
}
// load data
{
if (!ctx_clip.no_alloc) {
std::vector<uint8_t> read_buf;
// alloc memory and offload data
@ -2676,7 +2684,7 @@ struct clip_model_loader {
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
info = alloc_compute_meta(ctx_clip, batch);
info = reserve_compute_meta(ctx_clip, batch);
if (!info.fattn && info.fattn_op) {
auto op = info.fattn_op;
LOG_WRN("%s: *****************************************************************\n", __func__);
@ -2695,10 +2703,10 @@ struct clip_model_loader {
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
alloc_compute_meta(ctx_clip, batch);
reserve_compute_meta(ctx_clip, batch);
}
} else {
info = alloc_compute_meta(ctx_clip, batch);
info = reserve_compute_meta(ctx_clip, batch);
if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
@ -2737,12 +2745,14 @@ struct clip_model_loader {
}
}
static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
// only initialize backend buffers, but do not allocate them yet
static support_info_graph reserve_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
ctx_clip.mem_compute.clear();
for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
ggml_backend_t backend = ctx_clip.backend_ptrs[i];
ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i];
@ -2752,6 +2762,7 @@ struct clip_model_loader {
ggml_backend_buft_name(buft),
size / 1024.0 / 1024.0);
}
ctx_clip.mem_compute[ggml_backend_get_device(backend)] += size;
}
const int n_splits = ggml_backend_sched_get_n_splits(ctx_clip.sched.get());
@ -4266,22 +4277,6 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
}
}
int clip_is_minicpmv(const struct clip_ctx * ctx) {
// TODO: remove this function
if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
return ctx->model.hparams.minicpmv_version;
}
if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV4_6) {
return 46;
}
return 0;
}
bool clip_is_glm(const struct clip_ctx * ctx) {
// TODO: remove this function
return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
}
bool clip_is_llava(const struct clip_ctx * ctx) {
return ctx->model.hparams.has_llava_projector;
}
@ -4330,6 +4325,14 @@ const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
return &ctx->model.hparams;
}
std::map<ggml_backend_dev_t, size_t> clip_get_mem_usage(const struct clip_ctx * ctx) {
std::map<ggml_backend_dev_t, size_t> result = ctx->mem_usage;
for (auto & [dev, size] : ctx->mem_compute) {
result[dev] += size;
}
return result;
}
//
// API for debugging
//

View file

@ -6,6 +6,8 @@
#include <stddef.h>
#include <stdint.h>
#include <map>
// !!! Internal header, to be used by mtmd only !!!
#define MTMD_INTERNAL_HEADER
@ -40,6 +42,7 @@ struct clip_context_params {
bool warmup;
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
bool no_alloc;
};
struct clip_init_result {
@ -102,8 +105,6 @@ struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
int clip_is_minicpmv(const struct clip_ctx * ctx);
bool clip_is_glm(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx);
// note for contributor: this clip_is_(model) pattern is deprecated
// do NOT add new functions like this
@ -116,6 +117,8 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
std::map<ggml_backend_dev_t, size_t> clip_get_mem_usage(const struct clip_ctx * ctx);
struct clip_cap {
bool has_vision;
bool has_audio;

View file

@ -21,6 +21,7 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <climits>
#include <vector>
// represents raw image data, layout is RGBRGBRGB...
@ -139,13 +140,13 @@ mtmd_context_params mtmd_context_params_default() {
struct mtmd_context {
struct clip_ctx * ctx_v; // vision
struct clip_ctx * ctx_a; // audio
const struct llama_model * text_model;
std::vector<float> image_embd_v; // image embedding vector
bool print_timings;
int n_threads;
std::string media_marker;
const int n_embd_text;
const int n_embd_text = -1; // -1 means llm context not provided, skip checking this
const llama_vocab * vocab = nullptr; // can be nullptr if text_model is not provided
mtmd_pos_type pos_type;
// these are not token, but strings used to mark the beginning and end of image/audio embeddings
@ -178,12 +179,13 @@ struct mtmd_context {
mtmd_context(const char * mmproj_fname,
const llama_model * text_model,
const mtmd_context_params & ctx_params) :
text_model (text_model),
const mtmd_context_params & ctx_params,
bool no_alloc = false) :
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
media_marker (ctx_params.media_marker),
n_embd_text (llama_model_n_embd_inp(text_model))
n_embd_text (text_model ? llama_model_n_embd_inp(text_model) : -1),
vocab (text_model ? llama_model_get_vocab(text_model) : nullptr)
{
if (ctx_params.image_marker != nullptr) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
@ -193,21 +195,23 @@ struct mtmd_context {
throw std::runtime_error("media_marker must not be empty");
}
auto decoder_rope_type = llama_model_rope_type(text_model);
switch (decoder_rope_type) {
case LLAMA_ROPE_TYPE_NONE:
case LLAMA_ROPE_TYPE_NORM:
case LLAMA_ROPE_TYPE_NEOX:
{
pos_type = MTMD_POS_TYPE_NORMAL;
} break;
case LLAMA_ROPE_TYPE_MROPE:
case LLAMA_ROPE_TYPE_IMROPE:
{
pos_type = MTMD_POS_TYPE_MROPE;
} break;
default:
throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type));
if (text_model) {
auto decoder_rope_type = llama_model_rope_type(text_model);
switch (decoder_rope_type) {
case LLAMA_ROPE_TYPE_NONE:
case LLAMA_ROPE_TYPE_NORM:
case LLAMA_ROPE_TYPE_NEOX:
{
pos_type = MTMD_POS_TYPE_NORMAL;
} break;
case LLAMA_ROPE_TYPE_MROPE:
case LLAMA_ROPE_TYPE_IMROPE:
{
pos_type = MTMD_POS_TYPE_MROPE;
} break;
default:
throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type));
}
}
clip_context_params ctx_clip_params {
@ -218,6 +222,7 @@ struct mtmd_context {
/* warmup */ ctx_params.warmup,
/* cb_eval */ ctx_params.cb_eval,
/* cb_eval_user_data */ ctx_params.cb_eval_user_data,
/* no_alloc */ no_alloc,
};
auto res = clip_init(mmproj_fname, ctx_clip_params);
@ -241,7 +246,7 @@ struct mtmd_context {
// since we already validate n_embd of vision and audio mmproj,
// we can safely assume that they are the same
int n_embd_clip = clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a);
if (n_embd_text != n_embd_clip) {
if (n_embd_text > 0 && n_embd_text != n_embd_clip) {
throw std::runtime_error(string_format(
"mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
"hint: you may be using wrong mmproj\n",
@ -279,7 +284,7 @@ struct mtmd_context {
} break;
case PROJECTOR_TYPE_MINICPMV:
{
int minicpmv_version = clip_is_minicpmv(ctx_v);
int minicpmv_version = clip_get_hparams(ctx_v)->minicpmv_version;
if (minicpmv_version == 2) {
// minicpmv 2.5 format:
// <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
@ -594,7 +599,11 @@ struct mtmd_context {
private:
llama_token lookup_token(const std::string & token_text) {
const llama_vocab * vocab = llama_model_get_vocab(text_model);
if (vocab == nullptr) {
// TODO @ngxson : this case is currently hit by mtmd_get_memory_usage
// but we should reconsider this if this case is needed in other places in the future
return LLAMA_TOKEN_NULL;
}
const int n_vocab = llama_vocab_n_tokens(vocab);
for (int i = 0; i < n_vocab; i++) {
if (token_to_piece(vocab, i, true) == token_text) {
@ -605,6 +614,9 @@ private:
}
std::string token_to_piece(const llama_vocab * vocab, llama_token token, bool special) {
if (vocab == nullptr) {
throw std::runtime_error("llama_vocab is not provided");
}
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
@ -653,7 +665,7 @@ struct mtmd_tokenizer {
add_special = text->add_special;
parse_special = text->parse_special;
input_text = text->text;
vocab = llama_model_get_vocab(ctx->text_model);
vocab = ctx->vocab;
}
int32_t tokenize(mtmd_input_chunks * output) {
@ -679,27 +691,29 @@ struct mtmd_tokenizer {
}
}
if (add_special && llama_vocab_get_add_bos(vocab)) {
// if first chunk is text, we add BOS token to first text chunk
// otherwise, create a new text chunk with BOS token
if (!cur.entries.empty() && cur.entries[0].type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
// add BOS token to the beginning of first text chunk
cur.entries[0].tokens_text.insert(cur.entries[0].tokens_text.begin(), llama_vocab_bos(vocab));
} else {
// create a new text chunk with BOS token at the beginning
mtmd_input_chunk bos_chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
{llama_vocab_bos(vocab)},
nullptr, // image tokens
nullptr, // audio tokens
};
cur.entries.insert(cur.entries.begin(), std::move(bos_chunk));
if (vocab != nullptr) {
if (add_special && llama_vocab_get_add_bos(vocab)) {
// if first chunk is text, we add BOS token to first text chunk
// otherwise, create a new text chunk with BOS token
if (!cur.entries.empty() && cur.entries[0].type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
// add BOS token to the beginning of first text chunk
cur.entries[0].tokens_text.insert(cur.entries[0].tokens_text.begin(), llama_vocab_bos(vocab));
} else {
// create a new text chunk with BOS token at the beginning
mtmd_input_chunk bos_chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
{llama_vocab_bos(vocab)},
nullptr, // image tokens
nullptr, // audio tokens
};
cur.entries.insert(cur.entries.begin(), std::move(bos_chunk));
}
}
}
if (add_special && llama_vocab_get_add_eos(vocab)) {
// if last chunk is text, we add EOS token to it
add_text({llama_vocab_eos(vocab)});
if (add_special && llama_vocab_get_add_eos(vocab)) {
// if last chunk is text, we add EOS token to it
add_text({llama_vocab_eos(vocab)});
}
}
if (i_bm != bitmaps.size()) {
@ -714,6 +728,9 @@ struct mtmd_tokenizer {
}
void add_text(const std::string & txt, bool parse_special) {
if (vocab == nullptr) {
throw std::runtime_error("llama_vocab is not provided");
}
LOG_DBG("%s: %s\n", __func__, txt.c_str());
auto tokens = mtmd_tokenize_text_internal(vocab, txt, /* add_special */ false, parse_special);
add_text(tokens);
@ -1002,10 +1019,16 @@ struct mtmd_tokenizer {
const std::string & text,
bool add_special,
bool parse_special) {
if (vocab == nullptr) {
throw std::runtime_error("llama_vocab is not provided");
}
// upper limit for the number of tokens
int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens == std::numeric_limits<int32_t>::min()) {
throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
}
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
@ -1067,8 +1090,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
bool ok = false;
if (clip_is_llava(ctx_clip)
|| clip_is_minicpmv(ctx_clip)
|| clip_is_glm(ctx_clip)
|| proj_type == PROJECTOR_TYPE_MINICPMV
|| proj_type == PROJECTOR_TYPE_GLM_EDGE
|| proj_type == PROJECTOR_TYPE_INTERNVL) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
@ -1542,3 +1565,36 @@ void mtmd_debug_preprocess_audio(mtmd_context * ctx, const std::vector<float> &
}
}
}
static void stub_log_callback(enum ggml_log_level, const char *, void *) {
// do nothing
}
std::map<ggml_backend_dev_t, size_t> mtmd_get_memory_usage(const char * mmproj_fname,
struct mtmd_context_params ctx_params) {
mtmd::context_ptr ctx;
auto saved_log_callback = g_logger_state.log_callback;
auto saved_log_user_data = g_logger_state.log_callback_user_data;
try {
mtmd_log_set(stub_log_callback, nullptr); // suppress logging
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params));
mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback
std::map<ggml_backend_dev_t, size_t> total_mem;
auto merge = [&](const struct clip_ctx * c) {
for (auto & [dev, size] : clip_get_mem_usage(c)) {
total_mem[dev] += size;
}
};
if (ctx->ctx_v) {
merge(ctx->ctx_v);
}
if (ctx->ctx_a) {
merge(ctx->ctx_a);
}
return total_mem;
} catch (const std::exception & e) {
mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback
LOG_ERR("%s: error: %s\n", __func__, e.what());
return {};
}
}

View file

@ -9,6 +9,7 @@
#include <stdbool.h>
#ifdef __cplusplus
#include <map>
#include <string>
#include <vector>
#include <cinttypes>
@ -261,6 +262,14 @@ MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void);
} // extern "C"
#endif
// Get memory usage of the current model in bytes, per backend device
// Note: this is an unstable API, used internally by fit_params; it WILL be removed or changed without deprecation
#ifdef __cplusplus
MTMD_API std::map<ggml_backend_dev_t, size_t> mtmd_get_memory_usage(
const char * mmproj_fname,
struct mtmd_context_params ctx_params);
#endif
//
// C++ wrappers
//

View file

@ -746,6 +746,46 @@ private:
params_base = params;
std::string & mmproj_path = params_base.mmproj.path;
bool has_mmproj = !mmproj_path.empty();
mtmd_context_params mparams = mtmd_context_params_default();
if (has_mmproj) {
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.flash_attn_type = params_base.flash_attn_type;
mparams.warmup = params_base.warmup;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mparams.media_marker = get_media_marker();
}
// optionally get the memory usage of mmproj
if (has_mmproj && params_base.fit_params) {
auto mmproj_mem = mtmd_get_memory_usage(mmproj_path.c_str(), mparams);
if (!mmproj_mem.empty()) {
size_t total = 0;
for (auto & [dev, size] : mmproj_mem) {
total += size;
}
SRV_INF("[mtmd] estimated memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0));
GGML_ASSERT(!params_base.fit_params_target.empty());
for (auto & [dev, size] : mmproj_mem) {
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
if (ggml_backend_dev_get(i) == dev) {
if (i < params_base.fit_params_target.size()) {
SRV_DBG("[mtmd] adding %.2f MiB to fit_params_target for device %s\n", size / (1024.0 * 1024.0), ggml_backend_dev_name(dev));
params_base.fit_params_target[i] += size;
}
break;
}
}
}
} else {
SRV_ERR("%s", "[mtmd] failed to get memory usage of mmproj\n");
}
}
llama_init = common_init_from_params(params_base);
model_tgt = llama_init->model();
@ -830,18 +870,10 @@ private:
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.flash_attn_type = params_base.flash_attn_type;
mparams.warmup = params_base.warmup;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mparams.media_marker = get_media_marker();
if (has_mmproj) {
if (!is_resume) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
}
mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams);
if (mctx == nullptr) {