replace KEY_FULLATTN_BLK_IDX with KEY_WIN_ATTN_PATTERN

Author: HimariO
Date:   2025-04-26 01:00:00 +08:00
Parent: f69e9fa04d
Commit: 77b144a8e7

3 changed files with 27 additions and 12 deletions

File 1 of 3

@@ -46,7 +46,7 @@
 #define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
-#define KEY_FULLATTN_BLK_IDX      "clip.vision.fullatt_block_indexes"
+#define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
 #define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"

File 2 of 3

@@ -171,7 +171,7 @@ struct clip_hparams {
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size;
-    std::vector<int32_t> full_attn_layers;
+    int32_t n_wa_pattern;
 };

 struct clip_layer {
@@ -799,7 +799,8 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     const float eps = hparams.eps;
-    const bool use_window_attn = hparams.full_attn_layers.size() > 0;
+    const int n_wa_pattern = hparams.n_wa_pattern;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

     const int batch_size = imgs.entries.size();
@@ -926,8 +927,7 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
     V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);

     struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-    const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
-    const bool full_attn = use_window_attn ? inlist : true;
+    const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
     if (full_attn) {
         KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
     } else {
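
The new modular test selects the same layers as the old membership lookup whenever the full-attention blocks are evenly spaced. A minimal Python sketch of the equivalence (illustrative only, not part of the commit; the [7, 15, 23, 31] indices and 32-block depth are assumed from Qwen2.5-VL's vision encoder config, other checkpoints may differ):

    # Old scheme: explicit list of full-attention block indices.
    fullatt_block_indexes = [7, 15, 23, 31]
    # New scheme: a single scalar, the period of the window/full pattern.
    n_wa_pattern = 8  # 7 window-attention layers, then 1 full-attention layer

    for il in range(32):  # il = layer index, as in the graph-build loop
        old_full_attn = il in fullatt_block_indexes
        new_full_attn = (il + 1) % n_wa_pattern == 0
        assert old_full_attn == new_full_attn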
@@ -1721,8 +1721,8 @@ struct clip_model_loader {
     get_u32(KEY_PATCH_SIZE, hparams.patch_size);
     get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
     get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false);
+    get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, false);
     get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
-    get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false);

     {
         std::string mm_patch_merge_type;
@@ -3074,6 +3074,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     bool support_dynamic_size = ctx->has_minicpmv_projector
         || ctx->has_qwen2vl_merger
         || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;

     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
@@ -3335,7 +3336,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        }
    }

-   if (hparams.attn_window_size > 0 && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
+   if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
        struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
        struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
        struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
@@ -3388,9 +3389,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
            }
        }

-       if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
-       if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
-       if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+       ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+       ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+       ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
    }

    ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);

File 3 of 3

@@ -1,5 +1,5 @@
 import argparse
-from typing import Dict
+from typing import Dict, List, Optional

 import torch
 import numpy as np
@@ -20,6 +20,20 @@ VISION = "clip.vision"
 def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)


+def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]) -> int:
+    if fullatt_block_indexes is None:
+        return 0
+
+    n_wa = fullatt_block_indexes[0]
+    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
+        if b - a - 1 != n_wa:
+            raise ValueError(
+                "window/full attention layers must follow a fixed pattern: "
+                f"every full-attention layer is preceded by {n_wa} window-attention layers"
+            )
+    return n_wa + 1
+
+
 class VL2:
     @staticmethod
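
A quick sanity check for get_n_wa_pattern (illustrative only, not part of the commit; [7, 15, 23, 31] mirrors Qwen2.5-VL's fullatt_block_indexes, and [7, 15, 30] is a deliberately irregular input):

    assert get_n_wa_pattern(None) == 0             # key absent: full attention everywhere
    assert get_n_wa_pattern([7, 15, 23, 31]) == 8  # 7 window layers before each full layer

    try:
        get_n_wa_pattern([7, 15, 30])              # gap 15 -> 30 breaks the fixed pattern
    except ValueError:
        pass                                       # rejected, as intended
    else:
        raise AssertionError("irregular spacing should be rejected")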
@@ -152,7 +166,7 @@ def main(args):
         raise ValueError()

     if args.model_type == "qwen2.5vl":
-        fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes)
+        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
         fout.add_uint32("clip.vision.window_size", vcfg.window_size)
         fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
         fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)