diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h
index 2ccf5f184..312fe181c 100644
--- a/examples/llava/clip-impl.h
+++ b/examples/llava/clip-impl.h
@@ -46,7 +46,7 @@
 #define KEY_MM_PATCH_MERGE_TYPE    "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS   "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION  "clip.vision.image_crop_resolution"
-#define KEY_FULLATTN_BLK_IDX       "clip.vision.fullatt_block_indexes"
+#define KEY_WIN_ATTN_PATTERN       "clip.vision.n_wa_pattern"
 #define KEY_ATTN_WINDOW_SIZE       "clip.vision.window_size"
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 79ddb1b33..859e1da77 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -171,7 +171,7 @@ struct clip_hparams {
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size;
-    std::vector<int32_t> full_attn_layers;
+    int32_t n_wa_pattern;
 };
 
 struct clip_layer {
@@ -799,7 +799,8 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     const float eps = hparams.eps;
-    const bool use_window_attn = hparams.full_attn_layers.size() > 0;
+    const int n_wa_pattern = hparams.n_wa_pattern;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;
 
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
     const int batch_size = imgs.entries.size();
@@ -926,8 +927,7 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
         V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
 
         struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-        const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
-        const bool full_attn = use_window_attn ? inlist : true;
+        const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
         if (full_attn) {
             KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
         } else {
@@ -1721,8 +1721,8 @@ struct clip_model_loader {
            get_u32(KEY_PATCH_SIZE, hparams.patch_size);
            get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
            get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false);
+           get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, false);
            get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
-           get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false);
 
            {
                std::string mm_patch_merge_type;
@@ -3074,6 +3074,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     bool support_dynamic_size = ctx->has_minicpmv_projector
                                 || ctx->has_qwen2vl_merger
                                 || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;
 
     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
@@ -3335,7 +3336,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
     }
 
-    if (hparams.attn_window_size > 0 && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
+    if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
         struct ggml_tensor * window_idx     = ggml_graph_get_tensor(gf, "window_idx");
         struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
         struct ggml_tensor * window_mask    = ggml_graph_get_tensor(gf, "window_mask");
@@ -3388,9 +3389,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             }
         }
 
-        if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
-        if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
-        if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+        ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+        ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+        ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
     }
 
     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py
index f80f25f92..dbc342c82 100644
--- a/examples/llava/qwen2_vl_surgery.py
+++ b/examples/llava/qwen2_vl_surgery.py
@@ -1,5 +1,5 @@
 import argparse
-from typing import Dict
+from typing import Dict, List, Optional
 
 import torch
 import numpy as np
@@ -20,6 +20,20 @@ VISION = "clip.vision"
 def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)
 
+
+def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]) -> int:
+    if fullatt_block_indexes is None:
+        return 0
+    n_wa = fullatt_block_indexes[0]
+    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
+        if b - a - 1 != n_wa:
+            raise ValueError(
+                f"window/full attention layers must follow a fixed pattern: "
+                f"every full-attention layer is preceded by {n_wa} window-attention layers"
+            )
+    return n_wa + 1
+
+
 class VL2:
 
     @staticmethod
@@ -152,7 +166,7 @@ def main(args):
             raise ValueError()
 
     if args.model_type == "qwen2.5vl":
-        fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes)
+        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
         fout.add_uint32("clip.vision.window_size", vcfg.window_size)
     fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
     fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
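
A quick sanity check of the new convention: get_n_wa_pattern collapses fullatt_block_indexes into a single integer, and the graph builder in clip.cpp then treats layer il as full attention when (il + 1) % n_wa_pattern == 0. The sketch below is not part of the patch; the [7, 15, 23, 31] indexes and the depth of 32 are assumed Qwen2.5-VL vision-config defaults. It checks that the old list and the new pattern select the same layers.

    # Sketch only: mirrors get_n_wa_pattern from qwen2_vl_surgery.py and the
    # modulo test from clip.cpp; the config values below are assumptions.
    from typing import List, Optional

    def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]) -> int:
        if fullatt_block_indexes is None:
            return 0
        n_wa = fullatt_block_indexes[0]
        for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
            if b - a - 1 != n_wa:
                raise ValueError("full-attention layers are not evenly spaced")
        return n_wa + 1

    fullatt_block_indexes = [7, 15, 23, 31]   # assumed Qwen2.5-VL default
    n_layers = 32                             # assumed vision tower depth

    n_wa_pattern = get_n_wa_pattern(fullatt_block_indexes)  # -> 8
    # clip.cpp marks layer il as full attention when (il + 1) % n_wa_pattern == 0,
    # i.e. every n_wa_pattern-th layer; all other layers use window attention.
    full_attn = [il for il in range(n_layers) if (il + 1) % n_wa_pattern == 0]
    assert full_attn == fullatt_block_indexes

Storing a single integer is enough because the index list is fully determined by its spacing, and when n_wa_pattern is 0 the graph builder simply falls back to full attention in every layer.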