diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index 8550b485e..2ccf5f184 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -107,6 +107,7 @@ enum projector_type { PROJECTOR_TYPE_GEMMA3, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, + PROJECTOR_TYPE_QWEN2_5_VL, PROJECTOR_TYPE_UNKNOWN, }; @@ -117,6 +118,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_RESAMPLER, "resampler"}, { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, + { PROJECTOR_TYPE_QWEN2_5_VL,"qwen2.5vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 388696e0f..122aba633 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -780,6 +780,273 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i return gf; } +static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) { + const auto & model = ctx->vision_model; + const auto & hparams = model.hparams; + + const int image_size = hparams.image_size; + int image_size_width = image_size; + int image_size_height = image_size; + + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; + + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int patches_w = image_size_width / patch_size; + const int patches_h = image_size_height / patch_size; + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); + const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions; + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + const float eps = hparams.eps; + const bool use_window_attn = hparams.full_attn_layers.size() > 0; + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + const int batch_size = imgs.entries.size(); + GGML_ASSERT(batch_size == 1); + + struct ggml_init_params params = { + /*.mem_size =*/ ctx->buf_compute_meta.size(), + /*.mem_buffer =*/ ctx->buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + GGML_ASSERT(image_size_width % (patch_size * 2) == 0); + GGML_ASSERT(image_size_height % (patch_size * 2) == 0); + + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, patches_h, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); + inp = ggml_reshape_3d( + ctx0, inp, + hidden_size, patches_w * patches_h, batch_size); + + if (model.patch_bias) { + // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); + inp = ggml_add(ctx0, inp, model.patch_bias); + } + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * window_mask = nullptr; + struct ggml_tensor * window_idx = nullptr; + struct ggml_tensor * inv_window_idx = nullptr; + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); + } + } + + std::vector embedding_stack; + const auto & vision_feature_layer = hparams.vision_feature_layer; + + if (use_window_attn) { + // handle window attention inputs + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); + embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); + } + + // loop over layers + for (int il = 0; il < ctx->max_feature_layer; il++) { + struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + + // If this is an embedding feature layer, save the output. + // NOTE: 0 index here refers to the input to the encoder. + if (vision_feature_layer.find(il) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + + // rmsnorm1 + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w); + + // self-attention + { + + struct ggml_tensor * Q = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); + + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_rope_multi( + ctx0, Q, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_rope_multi( + ctx0, K, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end(); + const bool full_attn = use_window_attn ? inlist : true; + if (full_attn) { + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + } else { + KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f); + } + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + + // rms norm2 + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w); + + // mlp + // ffn_up + auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); + + auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); + cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); + if (ctx->use_gelu) { + cur_gate = ggml_gelu_inplace(ctx0, cur_gate); + } else if (ctx->use_silu) { + cur_gate = ggml_silu_inplace(ctx0, cur_gate); + } else { + cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate); + } + cur = ggml_mul(ctx0, cur_gate, cur_up); + + // ffn_down + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; + } + + // post-layernorm + if (model.post_ln_w) { + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + } + } + + // final layer is a vision feature layer + if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + + // If feature layers are explicitly set, stack them (if we have multiple) + if (!embedding_stack.empty()) { + embeddings = embedding_stack[0]; + for (size_t i = 1; i < embedding_stack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); + } + } + + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); + + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + // GELU activation + embeddings = ggml_gelu(ctx0, embeddings); + + // Second linear layer + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + + if (use_window_attn) { + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} + static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); @@ -1441,6 +1708,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = clip_image_build_graph_pixtral(ctx, imgs); } break; + case PROJECTOR_TYPE_QWEN2_5_VL: + { + res = clip_image_build_graph_qwen2_5_vl(ctx, imgs); + } break; default: { // TODO: we should have one build_* function per model @@ -1699,7 +1970,7 @@ struct clip_model_loader { // legacy naming (the in and out is reversed! don't ask me why) layer.ff_i_w = layer.ff_down_w; layer.ff_o_w = layer.ff_up_w; - layer.ff_g_w = layer.ff_gate_b; + layer.ff_g_w = layer.ff_gate_w; layer.ff_i_b = layer.ff_down_b; layer.ff_o_b = layer.ff_up_b; layer.ff_g_b = layer.ff_gate_b; @@ -1801,6 +2072,7 @@ struct clip_model_loader { vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); } break; case PROJECTOR_TYPE_MERGER: + case PROJECTOR_TYPE_QWEN2_5_VL: { vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); @@ -2754,7 +3026,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 4) { n_patches = 64; } - } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER || ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); @@ -3165,7 +3437,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } - if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn? + if (hparams.attn_window_size > 0 && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) { struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); @@ -3398,6 +3670,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_GLM_EDGE: return ctx->vision_model.mm_model_mlp_3_w->ne[1]; case PROJECTOR_TYPE_MERGER: + case PROJECTOR_TYPE_QWEN2_5_VL: return ctx->vision_model.mm_1_b->ne[0]; case PROJECTOR_TYPE_GEMMA3: return ctx->vision_model.mm_input_proj_w->ne[0]; diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index 0a47a719f..2345c10d1 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -140,7 +140,6 @@ def main(args): fout.add_bool("clip.has_text_encoder", False) fout.add_bool("clip.has_vision_encoder", True) fout.add_bool("clip.has_qwen2vl_merger", True) - fout.add_string("clip.projector_type", "qwen2vl_merger") print(cfg.vision_config) if 'silu' in cfg.vision_config.hidden_act.lower(): @@ -159,7 +158,9 @@ def main(args): fout.add_uint32("clip.vision.window_size", vcfg.window_size) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size) fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size) + fout.add_string("clip.projector_type", "qwen2.5vl_merger") else: + fout.add_string("clip.projector_type", "qwen2vl_merger") fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)