Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build-linux-cross.yml
#	.github/workflows/build.yml
#	cmake/build-info.cmake
#	common/CMakeLists.txt
#	examples/llava/README.md
#	examples/server/README.md
#	ggml/CMakeLists.txt
#	ggml/src/ggml-cuda/CMakeLists.txt
#	ggml/src/ggml-rpc/ggml-rpc.cpp
#	ggml/src/ggml-vulkan/CMakeLists.txt
#	ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
#	tests/test-chat-template.cpp
Concedo 2025-05-02 16:54:15 +08:00
commit d8f1f73dd7
25 changed files with 522 additions and 126 deletions

View file

@ -31,6 +31,7 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
@ -68,9 +69,11 @@
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
// minicpmv

View file

@ -186,6 +186,7 @@ struct clip_hparams {
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
};
struct clip_layer {
@ -246,6 +247,7 @@ struct clip_vision_model {
struct ggml_tensor * projection;
// LLaVA projection
struct ggml_tensor * mm_input_norm_w = nullptr;
struct ggml_tensor * mm_0_w = nullptr;
struct ggml_tensor * mm_0_b = nullptr;
struct ggml_tensor * mm_2_w = nullptr;
@ -325,6 +327,7 @@ struct clip_vision_model {
// pixtral
struct ggml_tensor * token_embd_img_break = nullptr;
struct ggml_tensor * mm_patch_merger_w = nullptr;
};
bool enable_gpu_clip = true;
@ -662,6 +665,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
const int n_merge = hparams.spatial_merge_size;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
@ -746,7 +750,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
{
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
if (ctx->use_silu) {
gate_proj = ggml_silu(ctx0, gate_proj);
} else if (ctx->use_gelu) {
gate_proj = ggml_gelu(ctx0, gate_proj);
} else {
GGML_ABORT("Pixtral: Unsupported activation");
}
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
}
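
The block above is the gated FFN used by the Pixtral-style vision encoder: the gate projection goes through SiLU or GELU (selected by ctx->use_silu / ctx->use_gelu), is multiplied elementwise with the up projection, and the result is down-projected. A minimal standalone sketch of the same computation on a single token (toy weights and a tanh-approximated GELU; this is an illustration, not the ggml implementation):

#include <cmath>
#include <cstdio>
#include <vector>

// silu(x) = x * sigmoid(x); gelu uses the common tanh approximation
static float silu(float x) { return x / (1.0f + std::exp(-x)); }
static float gelu(float x) { return 0.5f * x * (1.0f + std::tanh(0.79788456f * (x + 0.044715f * x * x * x))); }

static std::vector<float> matvec(const std::vector<std::vector<float>> & w, const std::vector<float> & x) {
    std::vector<float> y(w.size(), 0.0f);
    for (size_t i = 0; i < w.size(); ++i)
        for (size_t j = 0; j < x.size(); ++j)
            y[i] += w[i][j] * x[j];
    return y;
}

int main() {
    std::vector<float> x = {0.5f, -1.0f};
    std::vector<std::vector<float>> w_gate = {{1.0f, 0.0f}, {0.0f, 1.0f}};
    std::vector<std::vector<float>> w_up   = {{2.0f, 0.0f}, {0.0f, 2.0f}};
    std::vector<std::vector<float>> w_down = {{1.0f, 1.0f}};
    bool use_silu = true; // mirrors the ctx->use_silu / ctx->use_gelu branch above

    std::vector<float> gate = matvec(w_gate, x), up = matvec(w_up, x);
    for (size_t i = 0; i < gate.size(); ++i) {
        gate[i] = use_silu ? silu(gate[i]) : gelu(gate[i]);
        gate[i] *= up[i]; // elementwise gating
    }
    std::vector<float> out = matvec(w_down, gate);
    printf("ffn out: %f\n", out[0]);
    return 0;
}
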
@ -757,14 +767,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
embeddings = cur;
}
// LlavaMultiModalProjector (with GELU activation)
// mistral small 3.1 patch merger
// ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
if (model.mm_patch_merger_w) {
GGML_ASSERT(hparams.spatial_merge_size > 0);
ggml_tensor * cur = embeddings;
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
// reshape image tokens to 2D grid
cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
cur = ggml_cont(ctx0, cur);
// torch.nn.functional.unfold is just an im2col under the hood
// we just need a dummy kernel to make it work
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
// project to hidden_size
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
embeddings = cur;
}
// LlavaMultiModalProjector (always using GELU activation)
{
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
if (model.mm_1_b) {
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
}
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
if (model.mm_2_b) {
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
}
// arrangement of the [IMG_BREAK] token
@ -774,11 +812,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
const int p_total = p_x * p_y;
const int n_embd_text = embeddings->ne[0];
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
cur = ggml_concat(ctx0, cur, tok, 1);
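
A worked example of the token-count bookkeeping above may help (the sizes below are illustrative assumptions, not values from this commit): the im2col/unfold step concatenates each n_merge x n_merge neighbourhood of patch embeddings before mm_patch_merger_w projects it back down, so an n_patches_x by n_patches_y grid shrinks to p_x by p_y merged tokens, and one [IMG_BREAK] embedding is appended after every row except the last.

#include <cstdio>

int main() {
    const int img_w = 896, img_h = 896;   // assumed input resolution
    const int patch_size = 16;            // assumed ViT patch size
    const int n_merge = 2;                // hparams.spatial_merge_size (0 = no merging)

    const int n_patches_x = img_w / patch_size;                          // 56
    const int n_patches_y = img_h / patch_size;                          // 56
    const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;   // 28
    const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;   // 28
    const int p_total = p_x * p_y;                                       // 784
    const int n_tokens_output = p_total + p_y - 1;                       // 784 + 27 = 811

    printf("%dx%d patches -> %dx%d merged tokens -> %d output tokens (incl. [IMG_BREAK])\n",
           n_patches_x, n_patches_y, p_x, p_y, n_tokens_output);
    return 0;
}
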
@ -1780,6 +1821,7 @@ struct clip_model_loader {
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
case PROJECTOR_TYPE_QWEN25VL:
{
@ -2007,11 +2049,14 @@ struct clip_model_loader {
case PROJECTOR_TYPE_PIXTRAL:
{
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
// [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
// for mistral small 3.1
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
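
The second argument to get_tensor above presumably marks a tensor as optional, so Pixtral checkpoints without projector biases (or without the Mistral Small 3.1 patch-merger weights) still load, and the graph code guards on nullptr as in the `if (model.mm_1_b)` branches earlier. A self-contained sketch of that required/optional pattern with stand-in types (not the real loader API):

#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>

struct fake_tensor { std::string name; };

// stand-in for get_tensor(name, required): fail hard on a missing required tensor,
// return nullptr for a missing optional one so callers can branch on it
static fake_tensor * get_tensor_sketch(std::map<std::string, fake_tensor> & tensors,
                                       const std::string & name, bool required = true) {
    auto it = tensors.find(name);
    if (it == tensors.end()) {
        if (required) { fprintf(stderr, "missing tensor: %s\n", name.c_str()); exit(1); }
        return nullptr;
    }
    return &it->second;
}

int main() {
    std::map<std::string, fake_tensor> tensors = { { "mm.1.weight", { "mm.1.weight" } } };
    fake_tensor * mm_1_w = get_tensor_sketch(tensors, "mm.1.weight");        // required and present
    fake_tensor * mm_1_b = get_tensor_sketch(tensors, "mm.1.bias", false);   // optional and absent
    // graph construction then mirrors `if (model.mm_1_b) { ... add bias ... }`
    printf("weight: %s, bias: %s\n", mm_1_w->name.c_str(), mm_1_b ? mm_1_b->name.c_str() : "(absent)");
    return 0;
}
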
@ -2653,7 +2698,7 @@ struct llava_uhd {
// no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
auto best_size = get_best_resize(original_size, slice_size, patch_size, has_slices);
auto best_size = get_best_resize(original_size, slice_size, patch_size, !has_slices);
res.overview_size = best_size;
if (!has_slices) {
@ -3067,8 +3112,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_patches_x = img->nx / params.patch_size;
int n_patches_y = img->ny / params.patch_size;
int n_merge = ctx->vision_model.hparams.spatial_merge_size;
int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
}
@ -3654,7 +3700,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0];
return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV:
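
Switching clip_n_mmproj_embd from mm_2_b->ne[0] to mm_2_w->ne[1] follows from the biases becoming optional above: the bias may now be nullptr, but the weight always carries the output dimension. In ggml, a matmul weight with ne = {n_in, n_out} produces outputs of size ne[1], so the two reads agree whenever the bias exists. A small sketch of that shape rule with a stand-in tensor type (sizes are illustrative):

#include <cassert>
#include <cstdint>
#include <cstdio>

// stand-in for the ggml shape rule: mul_mat(W with ne={k, n_out}, X with ne={k, n_tok})
// yields Y with ne={n_out, n_tok}; the projected embedding size is always W.ne[1]
struct shape2d { int64_t ne[2]; };

static shape2d mul_mat_shape(const shape2d & w, const shape2d & x) {
    assert(w.ne[0] == x.ne[0]);                 // shared inner dimension
    return shape2d{ { w.ne[1], x.ne[1] } };
}

int main() {
    shape2d mm_2_w = { { 4096, 5120 } };        // illustrative projector weight
    shape2d tokens = { { 4096,  811 } };        // illustrative batch of image tokens
    shape2d out    = mul_mat_shape(mm_2_w, tokens);
    printf("n_mmproj_embd = %lld (read from the weight, bias optional)\n", (long long) out.ne[0]);
    return 0;
}
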

View file

@ -72,6 +72,8 @@ struct mtmd_cli_context {
llama_batch batch;
int n_batch;
std::vector<mtmd_bitmap> bitmaps;
// note: we know that the gemma3 template is "linear", meaning each turn is completely separated from the others
// so here we don't need to keep track of chat history
common_chat_templates_ptr tmpls;
@ -94,6 +96,7 @@ struct mtmd_cli_context {
LOG_ERR("Model does not have chat template.\n");
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
exit(1);
}
@ -134,13 +137,22 @@ struct mtmd_cli_context {
antiprompt_tokens.begin()
);
}
bool load_image(const std::string & fname) {
mtmd_bitmap bitmap;
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
return false;
}
bitmaps.push_back(std::move(bitmap));
return true;
}
};
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
llama_tokens generated_tokens;
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating || g_is_interrupted) {
printf("\n");
LOG("\n");
break;
}
@ -149,15 +161,15 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
common_sampler_accept(smpl, token_id, true);
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
printf("\n");
LOG("\n");
break; // end of generation
}
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
fflush(stdout);
if (g_is_interrupted) {
printf("\n");
LOG("\n");
break;
}
@ -172,9 +184,7 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
return 0;
}
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
std::vector<mtmd_bitmap> bitmaps;
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
common_chat_templates_inputs tmpl_inputs;
tmpl_inputs.messages = {msg};
tmpl_inputs.add_generation_prompt = true;
@ -182,15 +192,6 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
for (auto & fname : images_fname) {
mtmd_bitmap bitmap;
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
LOG_ERR("Unable to load image %s\n", fname.c_str());
return 2; // image not found
}
bitmaps.push_back(std::move(bitmap));
}
mtmd_input_text text;
text.text = formatted_chat.prompt;
text.add_special = add_bos;
@ -199,12 +200,14 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
if (g_is_interrupted) return 0;
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
if (res != 0) {
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
return 1;
}
ctx.bitmaps.clear();
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
LOG_ERR("Unable to eval prompt\n");
return 1;
@ -212,6 +215,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
ctx.n_past += mtmd_helper_get_n_pos(chunks);
LOG("\n");
return 0;
}
@ -234,7 +239,7 @@ int main(int argc, char ** argv) {
}
mtmd_cli_context ctx(params);
printf("%s: %s\n", __func__, params.model.path.c_str());
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
@ -267,7 +272,12 @@ int main(int argc, char ** argv) {
common_chat_msg msg;
msg.role = "user";
msg.content = params.prompt;
if (eval_message(ctx, msg, params.image, true)) {
for (const auto & image : params.image) {
if (!ctx.load_image(image)) {
return 1; // error is already printed by libmtmd
}
}
if (eval_message(ctx, msg, true)) {
return 1;
}
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
@ -282,7 +292,6 @@ int main(int argc, char ** argv) {
LOG("\n");
bool is_first_msg = true;
std::vector<std::string> images_fname;
std::string content;
while (!g_is_interrupted) {
@ -307,10 +316,17 @@ int main(int argc, char ** argv) {
continue;
}
g_is_generating = true;
if (line.find("/image") == 0) {
if (line == "/image" || line.find("/image ") == 0) {
if (line.size() < 8) {
LOG_ERR("ERR: Missing image filename\n");
continue;
}
std::string image = line.substr(7);
images_fname.push_back(string_strip(image));
content += "<__image__>";
if (ctx.load_image(image)) {
LOG("Image %s loaded\n", image.c_str());
content += "<__image__>";
}
// else, error is already printed by libmtmd
continue;
} else {
content += line;
@ -318,21 +334,14 @@ int main(int argc, char ** argv) {
common_chat_msg msg;
msg.role = "user";
msg.content = content;
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
if (g_is_interrupted) break;
if (ret == 2) {
// non-fatal error
images_fname.clear();
content.clear();
continue;
}
int ret = eval_message(ctx, msg, is_first_msg);
if (ret) {
return 1;
}
if (g_is_interrupted) break;
if (generate_response(ctx, smpl, n_predict)) {
return 1;
}
images_fname.clear();
content.clear();
is_first_msg = false;
}
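
To summarize the mtmd-cli refactor above: bitmaps are now accumulated on the context via load_image() (both from the CLI arguments and from the interactive "/image" command), each loaded image adds a "<__image__>" marker to the pending text, and eval_message() tokenizes the text together with ctx.bitmaps and then clears them for the next turn. A toy standalone model of that per-turn protocol (all types and behaviors here are simplified stand-ins, not the real libmtmd API):

#include <cstdio>
#include <string>
#include <vector>

struct toy_ctx {
    std::vector<std::string> bitmaps;              // mirrors mtmd_cli_context::bitmaps
    bool load_image(const std::string & fname) {   // pretend every file decodes successfully
        bitmaps.push_back(fname);
        return true;
    }
};

static void eval_message(toy_ctx & ctx, const std::string & content) {
    printf("eval: \"%s\" with %zu image(s)\n", content.c_str(), ctx.bitmaps.size());
    ctx.bitmaps.clear();                           // consumed: the next turn starts empty
}

int main() {
    toy_ctx ctx;
    std::string content;
    const std::vector<std::string> input = { "/image cat.png", "what is in this picture?" };
    for (const std::string & line : input) {
        if (line == "/image" || line.rfind("/image ", 0) == 0) {
            if (line.size() < 8) { printf("ERR: Missing image filename\n"); continue; }
            std::string image = line.substr(7);    // filename after "/image "
            if (ctx.load_image(image)) {
                content += "<__image__>";          // marker later matched to the stored bitmap
            }
            continue;
        }
        content += line;
        eval_message(ctx, content);                // one turn: markers in text + accumulated bitmaps
        content.clear();
    }
    return 0;
}
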

View file

@ -590,7 +590,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
}
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
GGML_ASSERT(chunk.tokens_image != nullptr);
int64_t t0 = ggml_time_ms();
if (ctx->print_timings) {

View file

@ -59,6 +59,7 @@ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
# these models always give the wrong answer, not sure why
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"