diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 9bc17af2a..01929a52a 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -39,6 +39,7 @@ #include #include #include +#include #if defined(LLAVA_LOG_OFF) # define LOG_INF(...) @@ -102,6 +103,8 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" #define KEY_USE_SILU "clip.use_silu" +#define KEY_USE_GLU_MLP "clip.use_glu_mlp" +#define KEY_USE_RMS_NORM "clip.use_rms_norm" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" #define KEY_N_BLOCK "clip.%s.block_count" @@ -120,6 +123,8 @@ static std::string format(const char * fmt, ...) { #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" // @@ -138,6 +143,7 @@ static std::string format(const char * fmt, ...) { #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" +#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_LN_1 "%s.blk.%d.ln1.%s" #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" @@ -447,6 +453,8 @@ struct clip_hparams { std::vector image_grid_pinpoints; int32_t image_crop_resolution; std::unordered_set vision_feature_layer; + int32_t attn_window_size; + std::vector full_attn_layers; }; struct clip_layer { @@ -472,6 +480,9 @@ struct clip_layer { struct ggml_tensor * ff_o_w; struct ggml_tensor * ff_o_b; + struct ggml_tensor * ff_g_w = NULL; + struct ggml_tensor * ff_g_b = NULL; + // layernorm 2 struct ggml_tensor * ln_2_w; struct ggml_tensor * ln_2_b; @@ -601,6 +612,8 @@ struct clip_ctx { float image_std[3]; bool use_gelu = false; bool use_silu = false; + bool use_glu_mlp = false; + bool use_rms_norm = false; int32_t ftype = 1; bool has_class_embedding = true; @@ -856,6 +869,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; const float eps = hparams.eps; + const bool use_window_attn = hparams.full_attn_layers.size() > 0; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; const int batch_size = imgs->size; @@ -906,8 +920,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); } - struct ggml_tensor * embeddings = inp; - struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * pos_embed = nullptr; + struct ggml_tensor * window_mask = nullptr; + struct ggml_tensor * window_idx = nullptr; + struct ggml_tensor * inv_window_idx = nullptr; if (ctx->has_llava_projector) { // concat class_embeddings and patch_embeddings @@ -949,16 +966,41 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // pre-layernorm if (ctx->has_pre_norm) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "pre_ln"); + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); + embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); + } } std::vector embedding_stack; const auto & vision_feature_layer = hparams.vision_feature_layer; // loop over layers + + if (use_window_attn) { + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); + embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); + + } + for (int il = 0; il < ctx->max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states @@ -971,9 +1013,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im //const size_t nb_q_w = model.layers[il].q_w->nb[0]; // layernorm1 - { + if (ctx->use_rms_norm) { + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w); + } + else { cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); } @@ -1014,7 +1059,15 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_inplace(ctx0, KQ); + const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end(); + const bool full_attn = use_window_attn ? inlist : true; + if (full_attn) { + KQ = ggml_soft_max_inplace(ctx0, KQ); + } else { + KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f); + + } + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -1031,25 +1084,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = cur; // embeddings = residual, cur = hidden_states // layernorm2 - { + if (ctx->use_rms_norm) { + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w); + } else { cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + // mlp + if (ctx->use_glu_mlp) { + // ffn_up + auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); - if (ctx->use_gelu) { - cur = ggml_gelu_inplace(ctx0, cur); - } else if (ctx->use_silu) { - cur = ggml_silu_inplace(ctx0, cur); - } else { - cur = ggml_gelu_quick_inplace(ctx0, cur); + auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); + cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); + if (ctx->use_gelu) { + cur_gate = ggml_gelu_inplace(ctx0, cur_gate); + } else if (ctx->use_silu) { + cur_gate = ggml_silu_inplace(ctx0, cur_gate); + } else { + cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate); + } + cur = ggml_mul(ctx0, cur_gate, cur_up); + + // ffn_down + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); } + else { + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + if (ctx->use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else if (ctx->use_silu) { + cur = ggml_silu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + } // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -1059,10 +1137,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // post-layernorm if (ctx->has_post_norm) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "post_ln"); + if (ctx->use_rms_norm) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); + } else { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); + } } // final layer is a vision feature layer @@ -1375,6 +1460,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); } + if (use_window_attn) { + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); + } + // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1569,6 +1666,20 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p new_clip->use_silu = false; } + try { + idx = get_key_idx(ctx, KEY_USE_GLU_MLP); + new_clip->use_glu_mlp = gguf_get_val_bool(ctx, idx); + } catch (std::runtime_error & /*e*/) { + new_clip->use_glu_mlp = false; + } + + try { + idx = get_key_idx(ctx, KEY_USE_RMS_NORM); + new_clip->use_rms_norm = gguf_get_val_bool(ctx, idx); + } catch (std::runtime_error & /*e*/) { + new_clip->use_rms_norm = false; + } + if (verbosity >= 1) { LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); @@ -1703,6 +1814,18 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean); const float * std_data = (const float *)gguf_get_arr_data(ctx, idx_std); + try { + int idx_full_attn_layers = get_key_idx(ctx, KEY_FULLATTN_BLK_IDX); + auto n_full_attn_layers = gguf_get_arr_n(ctx, idx_full_attn_layers); + const int * full_attn_layers = (const int *)gguf_get_arr_data(ctx, idx_full_attn_layers); + hparams.full_attn_layers.assign(full_attn_layers, full_attn_layers + n_full_attn_layers); + + int idx_window_size = get_key_idx(ctx, KEY_ATTN_WINDOW_SIZE); + hparams.attn_window_size = gguf_get_val_u32(ctx, idx_window_size); + } catch (std::runtime_error & /*e*/) { + hparams.attn_window_size = 0; + } + for (int i = 0; i < 3; ++i) { new_clip->image_mean[i] = mean_data[i]; new_clip->image_std[i] = std_data[i]; @@ -1753,8 +1876,15 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p } try { - vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias")); + vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); + new_clip->has_post_norm = true; + } catch (std::exception & /*e*/) { + new_clip->has_post_norm = false; + } + try { + // in case of rms norm, there will be only ln weight + vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); new_clip->has_post_norm = true; } catch (std::exception & /*e*/) { new_clip->has_post_norm = false; @@ -1914,10 +2044,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p layer.q_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "bias")); layer.v_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "bias")); layer.o_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias")); - layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias")); - layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias")); layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "bias")); layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "bias")); + + if (!new_clip->use_rms_norm) { + layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias")); + layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias")); + } + if (new_clip->use_glu_mlp) { + layer.ff_g_w = get_tensor(new_clip->ctx_data, format(TN_FFN_GATE, "v", il, "weight")); + layer.ff_g_b = get_tensor(new_clip->ctx_data, format(TN_FFN_GATE, "v", il, "bias")); + } } } @@ -3024,30 +3161,96 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } if (ctx->has_qwen2vl_merger) { + /* + pw * ph = number of tokens output by ViT after apply patch merger + ipw * ipw = number of vision token been processed inside ViT + */ + const int merge_ratio = 2; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + + std::vector idx(ph * pw); + std::vector inv_idx(ph * pw); + + if (hparams.attn_window_size > 0) { + struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); + struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); + struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); + + const int grid_window = hparams.attn_window_size / patch_size / merge_ratio; + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y+=grid_window) + { + for (int x = 0; x < pw; x+=grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx[src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + } else { + std::iota(idx.begin(), idx.end(), 0); + std::iota(inv_idx.begin(), inv_idx.end(), 0); + } + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - const int pw = image_size_width / patch_size; - const int ph = image_size_height / patch_size; + // const int pw = image_size_width / patch_size; + // const int ph = image_size_height / patch_size; + const int mpow = (merge_ratio * merge_ratio); int* positions_data = (int*)malloc(ggml_nbytes(positions)); int ptr = 0; - for (int y = 0; y < ph; y+=2) + for (int y = 0; y < iph; y+=merge_ratio) { - for (int x = 0; x < pw; x+=2) + for (int x = 0; x < ipw; x+=merge_ratio) { for (int dy = 0; dy < 2; dy++) { for (int dx = 0; dx < 2; dx++) { - positions_data[ptr] = y + dy; - positions_data[num_patches + ptr] = x + dx; - positions_data[num_patches * 2 + ptr] = y + dy; - positions_data[num_patches * 3 + ptr] = x + dx; + auto remap = idx[ptr / mpow]; + remap = remap * mpow + (ptr % mpow); + + positions_data[remap] = y + dy; + positions_data[num_patches + remap] = x + dx; + positions_data[num_patches * 2 + remap] = y + dy; + positions_data[num_patches * 3 + remap] = x + dx; ptr++; } } } } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + if (positions) ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { @@ -3079,6 +3282,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } + if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn? + struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); + struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); + struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); + + const int merge_ratio = 2; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int grid_window = hparams.attn_window_size / patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + /* + pw * ph = number of tokens output by ViT after apply patch merger + ipw * ipw = number of vision token been processed inside ViT + */ + + std::vector idx(ph * pw); + std::vector inv_idx(ph * pw); + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y+=grid_window) + { + for (int x = 0; x < pw; x+=grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx[src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + + if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + } + if (ggml_backend_is_cpu(ctx->backend)) { ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); }