replace KEY_FULLATTN_BLK_IDX with KEY_WIN_ATTN_PATTERN

This commit is contained in:
HimariO 2025-04-26 01:00:00 +08:00
parent f69e9fa04d
commit 77b144a8e7
3 changed files with 27 additions and 12 deletions

View file

@ -171,7 +171,7 @@ struct clip_hparams {
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size;
std::vector<int32_t> full_attn_layers;
int32_t n_wa_pattern;
};
struct clip_layer {
@ -799,7 +799,8 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const float eps = hparams.eps;
const bool use_window_attn = hparams.full_attn_layers.size() > 0;
const int n_wa_pattern = hparams.n_wa_pattern;
const bool use_window_attn = hparams.n_wa_pattern > 0;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
const int batch_size = imgs.entries.size();
@ -926,8 +927,7 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
const bool full_attn = use_window_attn ? inlist : true;
const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
if (full_attn) {
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
} else {
@ -1721,8 +1721,8 @@ struct clip_model_loader {
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false);
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false);
{
std::string mm_patch_merge_type;
@ -3074,6 +3074,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
bool support_dynamic_size = ctx->has_minicpmv_projector
|| ctx->has_qwen2vl_merger
|| ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
const bool use_window_attn = hparams.n_wa_pattern > 0;
const int image_size = hparams.image_size;
int image_size_width = image_size;
@ -3335,7 +3336,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
if (hparams.attn_window_size > 0 && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
@ -3388,9 +3389,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
}
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);