merged qwen2.5vl again

commit 88660dd59d
Concedo, 2025-04-08 21:32:25 +08:00
4 changed files with 439 additions and 117 deletions

View file

@@ -22,6 +22,8 @@
 #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
 #define KEY_USE_GELU "clip.use_gelu"
 #define KEY_USE_SILU "clip.use_silu"
+#define KEY_USE_GLU_MLP "clip.use_glu_mlp"
+#define KEY_USE_RMS_NORM "clip.use_rms_norm"
 #define KEY_N_EMBD "clip.%s.embedding_length"
 #define KEY_N_FF "clip.%s.feed_forward_length"
 #define KEY_N_BLOCK "clip.%s.block_count"

@@ -40,6 +42,8 @@
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes"
+#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
 //

@@ -58,6 +62,7 @@
 #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
 #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
 #define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
+#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
 #define TN_LN_1 "%s.blk.%d.ln1.%s"
 #define TN_LN_2 "%s.blk.%d.ln2.%s"
 #define TN_LN_PRE "%s.pre_ln.%s"
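
The tensor-name templates above expand printf-style: "%s" is the model prefix ("v" for the vision tower) and "%d" the block index. A minimal standalone Python sketch (not part of the diff) of the names the loader will request for the new gate tensors:

# Standalone sketch: expand the new TN_FFN_GATE template from the defines above.
TN_FFN_GATE = "%s.blk.%d.ffn_gate.%s"

for il in range(2):  # first two vision blocks
    print(TN_FFN_GATE % ("v", il, "weight"))
# -> v.blk.0.ffn_gate.weight
# -> v.blk.1.ffn_gate.weight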

View file

@@ -40,6 +40,7 @@
 #include <sstream>
 #include <cinttypes>
 #include <limits>
+#include <numeric>

 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};

@@ -196,6 +197,8 @@ struct clip_hparams {
     std::vector<int32_t> image_grid_pinpoints;
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
+    int32_t attn_window_size;
+    std::vector<int32_t> full_attn_layers;
 };

 struct clip_layer {

@@ -221,6 +224,9 @@ struct clip_layer {
     struct ggml_tensor * ff_o_w = nullptr;
     struct ggml_tensor * ff_o_b = nullptr;

+    struct ggml_tensor * ff_g_w = NULL;
+    struct ggml_tensor * ff_g_b = NULL;
+
     // layernorm 2
     struct ggml_tensor * ln_2_w = nullptr;
     struct ggml_tensor * ln_2_b = nullptr;

@@ -350,6 +356,8 @@ struct clip_ctx {
     float image_std[3];
     bool use_gelu = false;
     bool use_silu = false;
+    bool use_glu_mlp = false;
+    bool use_rms_norm = false;
     int32_t ftype = 1;

     struct gguf_context * ctx_gguf = nullptr;
@@ -599,6 +607,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     const float eps = hparams.eps;
+    const bool use_window_attn = hparams.full_attn_layers.size() > 0;
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

     const int batch_size = imgs->size;

@@ -628,6 +637,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
         inp = ggml_add(ctx0, inp, inp_1);
+
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
         inp = ggml_reshape_4d(
             ctx0, inp,

@@ -649,8 +659,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
         inp = ggml_add(ctx0, inp, model.patch_bias);
     }
     struct ggml_tensor * embeddings = inp;
     struct ggml_tensor * pos_embed = nullptr;
+    struct ggml_tensor * window_mask = nullptr;
+    struct ggml_tensor * window_idx = nullptr;
+    struct ggml_tensor * inv_window_idx = nullptr;

     if (ctx->has_llava_projector) {
         // concat class_embeddings and patch_embeddings

@@ -692,16 +705,40 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     // pre-layernorm
     if (model.pre_ln_w) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "pre_ln");
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+        if (ctx->use_rms_norm) {
+            embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+            ggml_set_name(embeddings, "pre_ln");
+            embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
+        } else {
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            ggml_set_name(embeddings, "pre_ln");
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+        }
     }

     std::vector<struct ggml_tensor *> embedding_stack;
     const auto & vision_feature_layer = hparams.vision_feature_layer;

     // loop over layers
+    if (use_window_attn) {
+        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+        ggml_set_name(inv_window_idx, "inv_window_idx");
+        ggml_set_input(inv_window_idx);
+        // mask for window attention
+        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
+        ggml_set_name(window_mask, "window_mask");
+        ggml_set_input(window_mask);
+
+        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+    }
+
     for (int il = 0; il < ctx->max_feature_layer; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@@ -714,9 +751,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         //const size_t nb_q_w = model.layers[il].q_w->nb[0];

         // layernorm1
-        {
+        if (ctx->use_rms_norm) {
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
+        }
+        else {
             cur = ggml_norm(ctx0, cur, eps);
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
                            model.layers[il].ln_1_b);
         }

@@ -756,7 +796,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);

         struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-        KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+        const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
+        const bool full_attn = use_window_attn ? inlist : true;
+        if (full_attn) {
+            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+        } else {
+            KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
+        }
+
         struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
         KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
         KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -773,25 +820,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         embeddings = cur; // embeddings = residual, cur = hidden_states

         // layernorm2
-        {
+        if (ctx->use_rms_norm) {
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
+        } else {
             cur = ggml_norm(ctx0, cur, eps);
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }

-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
-
-        if (ctx->use_gelu) {
-            cur = ggml_gelu_inplace(ctx0, cur);
-        } else if (ctx->use_silu) {
-            cur = ggml_silu_inplace(ctx0, cur);
-        } else {
-            cur = ggml_gelu_quick_inplace(ctx0, cur);
-        }
-
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        // mlp
+        if (ctx->use_glu_mlp) {
+            // ffn_up
+            auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+            cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
+
+            auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
+            cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+            if (ctx->use_gelu) {
+                cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
+            } else if (ctx->use_silu) {
+                cur_gate = ggml_silu_inplace(ctx0, cur_gate);
+            } else {
+                cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
+            }
+            cur = ggml_mul(ctx0, cur_gate, cur_up);
+
+            // ffn_down
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+            cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        }
+        else {
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+            cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+            if (ctx->use_gelu) {
+                cur = ggml_gelu_inplace(ctx0, cur);
+            } else if (ctx->use_silu) {
+                cur = ggml_silu_inplace(ctx0, cur);
+            } else {
+                cur = ggml_gelu_quick_inplace(ctx0, cur);
+            }
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+            cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        }

         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);

@@ -801,10 +873,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     // post-layernorm
     if (model.post_ln_w) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "post_ln");
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+        if (ctx->use_rms_norm) {
+            embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+            ggml_set_name(embeddings, "post_ln");
+            embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
+        } else {
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            ggml_set_name(embeddings, "post_ln");
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+        }
     }

     // final layer is a vision feature layer
@@ -1118,6 +1197,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
         }

+    if (use_window_attn) {
+        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+        ggml_set_name(window_idx, "window_idx");
+        ggml_set_input(window_idx);
+
+        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
+    }
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);

@@ -1228,6 +1319,8 @@ struct clip_model_loader {
         get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
         get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
+        get_bool(KEY_USE_GLU_MLP, ctx_clip.use_glu_mlp, false);
+        get_bool(KEY_USE_RMS_NORM, ctx_clip.use_rms_norm, false);

         auto & hparams = ctx_clip.vision_model.hparams;
         get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);

@@ -1239,7 +1332,9 @@ struct clip_model_loader {
         get_u32(KEY_IMAGE_SIZE, hparams.image_size);
         get_u32(KEY_PATCH_SIZE, hparams.patch_size);
         get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
+        get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false);
         get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
+        get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false);

         {
             std::string mm_patch_merge_type;

@@ -1355,14 +1450,16 @@ struct clip_model_loader {
             layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
             layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
             layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
+            layer.ff_g_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), ctx_clip.use_glu_mlp);
             layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
             layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
             layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
             layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
-            layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
-            layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
+            layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), !ctx_clip.use_rms_norm);
+            layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), !ctx_clip.use_rms_norm);
             layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
             layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
+            layer.ff_g_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), ctx_clip.use_glu_mlp);
         }

         switch (ctx_clip.proj_type) {

@@ -2631,6 +2728,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
         free(data);
     }
+
     if (ctx->has_minicpmv_projector) {
         {
             // inspired from siglip:
@@ -2694,23 +2792,86 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }

     if (ctx->has_qwen2vl_merger) {
-        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
-
-        const int pw = image_size_width / patch_size;
-        const int ph = image_size_height / patch_size;
+        /*
+            pw * ph = number of tokens output by the ViT after the patch merger is applied
+            ipw * iph = number of vision tokens processed inside the ViT
+        */
+        const int merge_ratio = 2;
+        const int pw = image_size_width / patch_size / merge_ratio;
+        const int ph = image_size_height / patch_size / merge_ratio;
+        const int ipw = image_size_width / patch_size;
+        const int iph = image_size_height / patch_size;
+
+        std::vector<int> idx(ph * pw);
+        std::vector<int> inv_idx(ph * pw);
+
+        if (hparams.attn_window_size > 0) {
+            struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
+            struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
+            struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
+
+            const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
+            int dst = 0;
+            // [num_vision_tokens, num_vision_tokens] attention mask tensor
+            std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
+            int mask_row = 0;
+
+            for (int y = 0; y < ph; y+=grid_window)
+            {
+                for (int x = 0; x < pw; x+=grid_window)
+                {
+                    const int win_h = std::min(grid_window, ph - y);
+                    const int win_w = std::min(grid_window, pw - x);
+                    const int dst_0 = dst;
+                    // group all tokens that belong to the same window together (into a contiguous range)
+                    for (int dy = 0; dy < win_h; dy++) {
+                        for (int dx = 0; dx < win_w; dx++) {
+                            const int src = (y + dy) * pw + (x + dx);
+                            assert(src < (int)idx.size());
+                            assert(dst < (int)inv_idx.size());
+                            idx[src] = dst;
+                            inv_idx[dst] = src;
+                            dst++;
+                        }
+                    }
+
+                    for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                        int row_offset = mask_row * (ipw * iph);
+                        std::fill(
+                            mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                            mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
+                            0.0);
+                        mask_row++;
+                    }
+                }
+            }
+
+            if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+            if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+            if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+        } else {
+            std::iota(idx.begin(), idx.end(), 0);
+            std::iota(inv_idx.begin(), inv_idx.end(), 0);
+        }
+
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+        const int mpow = (merge_ratio * merge_ratio);
         int* positions_data = (int*)malloc(ggml_nbytes(positions));

         int ptr = 0;
-        for (int y = 0; y < ph; y+=2)
+        for (int y = 0; y < iph; y+=merge_ratio)
         {
-            for (int x = 0; x < pw; x+=2)
+            for (int x = 0; x < ipw; x+=merge_ratio)
             {
                 for (int dy = 0; dy < 2; dy++) {
                     for (int dx = 0; dx < 2; dx++) {
-                        positions_data[ptr] = y + dy;
-                        positions_data[num_patches + ptr] = x + dx;
-                        positions_data[num_patches * 2 + ptr] = y + dy;
-                        positions_data[num_patches * 3 + ptr] = x + dx;
+                        auto remap = idx[ptr / mpow];
+                        remap = remap * mpow + (ptr % mpow);
+
+                        positions_data[remap] = y + dy;
+                        positions_data[num_patches + remap] = x + dx;
+                        positions_data[num_patches * 2 + remap] = y + dy;
+                        positions_data[num_patches * 3 + remap] = x + dx;
                         ptr++;
                     }
                 }
@@ -2749,6 +2910,64 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
     }

+    if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn?
+        struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
+        struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
+        struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
+
+        const int merge_ratio = 2;
+        const int pw = image_size_width / patch_size / merge_ratio;
+        const int ph = image_size_height / patch_size / merge_ratio;
+        const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
+        const int ipw = image_size_width / patch_size;
+        const int iph = image_size_height / patch_size;
+        /*
+            pw * ph = number of tokens output by the ViT after the patch merger is applied
+            ipw * iph = number of vision tokens processed inside the ViT
+        */
+
+        std::vector<int> idx(ph * pw);
+        std::vector<int> inv_idx(ph * pw);
+        int dst = 0;
+        // [num_vision_tokens, num_vision_tokens] attention mask tensor
+        std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
+        int mask_row = 0;
+
+        for (int y = 0; y < ph; y+=grid_window)
+        {
+            for (int x = 0; x < pw; x+=grid_window)
+            {
+                const int win_h = std::min(grid_window, ph - y);
+                const int win_w = std::min(grid_window, pw - x);
+                const int dst_0 = dst;
+                // group all tokens that belong to the same window together (into a contiguous range)
+                for (int dy = 0; dy < win_h; dy++) {
+                    for (int dx = 0; dx < win_w; dx++) {
+                        const int src = (y + dy) * pw + (x + dx);
+                        assert(src < (int)idx.size());
+                        assert(dst < (int)inv_idx.size());
+                        idx[src] = dst;
+                        inv_idx[dst] = src;
+                        dst++;
+                    }
+                }
+
+                for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                    int row_offset = mask_row * (ipw * iph);
+                    std::fill(
+                        mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                        mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
+                        0.0);
+                    mask_row++;
+                }
+            }
+        }
+
+        if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+        if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+        if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+    }
+
     if (ggml_backend_is_cpu(ctx->backend)) {
         ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
     }
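
The idx/inv_idx/window_mask construction appears twice above (once gated on hparams.attn_window_size inside the merger branch, once at the end of clip_image_batch_encode). It is easier to audit outside of ggml; here is a standalone Python sketch of the same logic, with illustrative sizes assumed (224x224 input, patch_size 14, merge_ratio 2, attn_window_size 112 — these numbers are not taken from the diff):

# Standalone sketch (not part of the commit) of the window-attention bookkeeping above.
image_size = 224
patch_size = 14
merge_ratio = 2
attn_window_size = 112

ipw = iph = image_size // patch_size                         # 16x16 raw vision tokens
pw = ph = ipw // merge_ratio                                 # 8x8 merged tokens
grid_window = attn_window_size // patch_size // merge_ratio  # 4x4 merged tokens per window

idx = [0] * (ph * pw)      # row-major position -> window-contiguous position
inv_idx = [0] * (ph * pw)  # window-contiguous position -> row-major position
NEG = float("-inf")
mask = [[NEG] * (ipw * iph) for _ in range(ipw * iph)]

m = merge_ratio * merge_ratio
dst = 0
mask_row = 0
for y in range(0, ph, grid_window):
    for x in range(0, pw, grid_window):
        win_h = min(grid_window, ph - y)
        win_w = min(grid_window, pw - x)
        dst_0 = dst
        # group all tokens of the same window into a contiguous index range
        for dy in range(win_h):
            for dx in range(win_w):
                src = (y + dy) * pw + (x + dx)
                idx[src] = dst
                inv_idx[dst] = src
                dst += 1
        # each raw-token row of this window may attend only to raw-token
        # columns of the same window; everything else stays at -inf
        for _ in range(win_h * win_w * m):
            for col in range(dst_0 * m, dst * m):
                mask[mask_row][col] = 0.0
            mask_row += 1

print(idx[:8])  # first merged-token row: [0, 1, 2, 3, 16, 17, 18, 19]

Each 4x4 block of merged tokens gets packed into a contiguous range, which is exactly what the ggml_get_rows scatter/gather in the graph relies on, and the mask zeroes only the window-local entries before the softmax.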

View file

@@ -5,10 +5,12 @@ import torch
 import numpy as np
 from gguf import *
 from transformers import (
-    Qwen2VLForConditionalGeneration,
-    Qwen2VLProcessor,
     AutoProcessor,
-    Qwen2VLConfig
+    Qwen2VLConfig,
+    Qwen2VLProcessor,
+    Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLConfig,  # type: ignore[reportAttributeAccessIssue]
+    Qwen2_5_VLForConditionalGeneration,  # type: ignore[reportAttributeAccessIssue]
 )
@@ -18,62 +20,80 @@ VISION = "clip.vision"
 def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)


-def to_gguf_name(name: str) -> str:
-    og = name
-    name = name.replace("text_model", "t").replace("vision_model", "v")
-    name = name.replace("blocks", "blk").replace("embeddings.", "")
-    name = name.replace("attn.", "attn_")
-    name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
-    # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
-    name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-    name = name.replace("merger.mlp", 'mm')
-    print(f"[to_gguf_name] {og} --> {name}")
-    return name
-
-
-def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
-    vision_model = qwen2vl.visual
-    tensor_map = {}
-    for name, ten in vision_model.state_dict().items():
-        ten = ten.numpy()
-        if 'qkv' in name:
-            if ten.ndim == 2:  # weight
-                c3, _ = ten.shape
-            else:  # bias
-                c3 = ten.shape[0]
-            assert c3 % 3 == 0
-            c = c3 // 3
-            wq = ten[:c]
-            wk = ten[c: c * 2]
-            wv = ten[c * 2:]
-            tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
-            tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
-            tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
-        elif 'merger' in name:
-            if name.endswith("ln_q.weight"):
-                tensor_map['v.post_ln.weight'] = ten
-            elif name.endswith("ln_q.bias"):
-                tensor_map['v.post_ln.bias'] = ten
-            else:
-                # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
-                tensor_map[to_gguf_name(name)] = ten
-        elif 'patch_embed.proj.weight' in name:
-            # NOTE: split Conv3D into Conv2Ds
-            c1, c2, kt, kh, kw = ten.shape
-            assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
-            tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
-            tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
-        else:
-            tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
-
-    for new_name, ten in tensor_map.items():
-        if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
-            tensor_map[new_name] = ten.astype(np.float32)
-        else:
-            tensor_map[new_name] = ten.astype(dtype)
-    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
-    return tensor_map
+class VL2:
+
+    @staticmethod
+    def to_gguf_name(name: str) -> str:
+        og = name
+        name = name.replace("text_model", "t").replace("vision_model", "v")
+        name = name.replace("blocks", "blk").replace("embeddings.", "")
+        name = name.replace("attn.", "attn_")
+        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
+        # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
+        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+        name = name.replace("merger.mlp", 'mm')
+        print(f"[to_gguf_name] {og} --> {name}")
+        return name
+
+    @classmethod
+    def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
+        vision_model = qwen2vl.visual
+        tensor_map = {}
+        for name, ten in vision_model.state_dict().items():
+            ten = ten.numpy()
+            if 'qkv' in name:
+                if ten.ndim == 2:  # weight
+                    c3, _ = ten.shape
+                else:  # bias
+                    c3 = ten.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = ten[:c]
+                wk = ten[c: c * 2]
+                wv = ten[c * 2:]
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
+            elif 'merger' in name:
+                if name.endswith("ln_q.weight"):
+                    tensor_map['v.post_ln.weight'] = ten
+                elif name.endswith("ln_q.bias"):
+                    tensor_map['v.post_ln.bias'] = ten
+                else:
+                    # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
+                    tensor_map[cls.to_gguf_name(name)] = ten
+            elif 'patch_embed.proj.weight' in name:
+                # NOTE: split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = ten.shape
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
+                tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
+            else:
+                tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
+
+        for new_name, ten in tensor_map.items():
+            if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
+                tensor_map[new_name] = ten.astype(np.float32)
+            else:
+                tensor_map[new_name] = ten.astype(dtype)
+        tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
+        return tensor_map
+
+
+class VL25(VL2):
+
+    @staticmethod
+    def to_gguf_name(name: str) -> str:
+        og = name
+        name = name.replace("text_model", "t").replace("vision_model", "v")
+        name = name.replace("blocks", "blk").replace("embeddings.", "")
+        name = name.replace("attn.", "attn_")
+        name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
+        name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
+        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
+        name = name.replace("merger.mlp", 'mm')
+        print(f"[vl25][to_gguf_name] {og} --> {name}")
+        return name


 def main(args):
@@ -82,7 +102,7 @@ def main(args):
         np_dtype = np.float32
         ftype = 0
     elif args.data_type == 'fp16':
-        dtype = torch.float32
+        dtype = torch.float16
         np_dtype = np.float16
         ftype = 1
     else:

@@ -92,11 +112,18 @@ def main(args):
     model_path = ""
     model_name = args.model_name
     print("model_name: ", model_name)
-    qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_name, torch_dtype=dtype, device_map="cpu"
-    )
-    cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-    vcfg = cfg.vision_config
+    if args.model_type == "qwen2vl":
+        qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="cpu"
+        )
+        cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
+        vcfg = cfg.vision_config
+    else:
+        qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="cpu"
+        )
+        cfg: Qwen2_5_VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
+        vcfg = cfg.vision_config

     if os.path.isdir(model_name):
         local_model = True

@@ -125,14 +152,26 @@ def main(args):
     else:
         raise ValueError()

-    tensor_map = find_vision_tensors(qwen2vl, np_dtype)
+    if args.model_type == "qwen2.5vl":
+        fout.add_bool("clip.use_glu_mlp", True)  # gated linear unit MLP layer in the vision model
+        fout.add_bool("clip.use_rms_norm", True)
+        fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes)
+        fout.add_uint32("clip.vision.window_size", vcfg.window_size)
+        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
+        fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
+    else:
+        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
+        fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
+
+    if args.model_type == "qwen2.5vl":
+        tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
+    else:
+        tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
     for name, data in tensor_map.items():
         fout.add_tensor(name, data)

     fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
     fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divisible by (14*2)
-    fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
-    fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
@@ -160,6 +199,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
+    parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
     parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
     args = parser.parse_args()
     main(args)
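
As a quick sanity check of the renaming chain in VL25.to_gguf_name above, a standalone Python sketch applying the same replace chain to a few sample Hugging Face tensor names (the sample names are illustrative):

# Standalone sketch of the VL25 renaming chain from the script above.
def vl25_to_gguf_name(name: str) -> str:
    name = name.replace("text_model", "t").replace("vision_model", "v")
    name = name.replace("blocks", "blk").replace("embeddings.", "")
    name = name.replace("attn.", "attn_")
    name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
    name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
    name = name.replace("norm1", "ln1").replace("norm2", "ln2")
    name = name.replace("merger.mlp", 'mm')
    return name

for hf_name in [
    "vision_model.blocks.0.mlp.gate_proj.weight",
    "vision_model.blocks.0.mlp.up_proj.bias",
    "vision_model.blocks.3.norm1.weight",
]:
    print(f"{hf_name} -> {vl25_to_gguf_name(hf_name)}")
# vision_model.blocks.0.mlp.gate_proj.weight -> v.blk.0.ffn_gate.weight
# vision_model.blocks.0.mlp.up_proj.bias -> v.blk.0.ffn_up.bias
# vision_model.blocks.3.norm1.weight -> v.blk.3.ln1.weight

The gate_proj/up_proj/down_proj names map onto exactly the TN_FFN_GATE/TN_FFN_UP/TN_FFN_DOWN templates added to the header in the first file of this commit.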

View file

@@ -23,6 +23,9 @@
 #include <algorithm>
 #include <iostream>
 #include <fstream>
+#include <limits>
+#include <cassert>
+#include <cmath>

 static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,

@@ -367,14 +370,14 @@ static void debug_test_mrope_2d() {
     // 1. Initialize backend
     ggml_backend_t backend = NULL;
     std::string backend_name = "";
-#ifdef GGML_USE_CUDA
-    fprintf(stderr, "%s: using CUDA backend\n", __func__);
-    backend = ggml_backend_cuda_init(0); // init device 0
-    backend_name = "cuda";
-    if (!backend) {
-        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-    }
-#endif
+// #ifdef GGML_USE_CUDA
+//     fprintf(stderr, "%s: using CUDA backend\n", __func__);
+//     backend = ggml_backend_cuda_init(0); // init device 0
+//     backend_name = "cuda";
+//     if (!backend) {
+//         fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+//     }
+// #endif
     // if there aren't GPU Backends fallback to CPU backend
     if (!backend) {
         backend = ggml_backend_cpu_init();
@@ -483,28 +486,82 @@ static void debug_test_mrope_2d() {
     ggml_backend_free(backend);
 }

-static void debug_dump_img_embed(struct llava_context * ctx_llava) {
-    int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
-    int ne = n_embd * 4;
-    float vals[56 * 56 * 3];
+enum model_output_type {
+    conv3d,
+    patch_embed,
+    patch_win_attn_scatter,
+    first_attn_layer,
+    last_attn_layer,
+    attn_softmax,
+    final_layer,
+};
+
+static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
+    constexpr int ih = 140;
+    constexpr int iw = 196;
+    // constexpr int ih = 56;
+    // constexpr int iw = 56;
+    // int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
+    int n_embd = 1280;
+    int merge = 1;
+    if (output_type == model_output_type::final_layer) {
+        n_embd = 2048;
+        merge = 2;
+    }
+    else if (output_type == model_output_type::attn_softmax) {
+        merge = 1;
+        n_embd = (ih/14/merge) * (iw/14/merge) * 16;
+    }
+
+    int ne = (ih/14/merge) * (iw/14/merge) * n_embd;
+    float vals[iw * ih * 3];
     // float embd[ne];
     std::vector<float> embd;
     embd.resize(ne);

-    for (int i = 0; i < 56*56; i++)
+    for (int i = 0; i < iw*ih; i++)
     {
         for (int c = 0; c < 3; c++)
-            vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+            vals[i * 3 + c] = (float)i / (iw*ih);
     }

-    clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data());
+    clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data());

-    std::ofstream outFile("img_embed.bin", std::ios::binary);
+    std::string file_postfix = "";
+    switch (output_type)
+    {
+    case model_output_type::conv3d:
+        file_postfix = "conv3d";
+        break;
+    case model_output_type::patch_embed:
+        file_postfix = "patch_embed";
+        break;
+    case model_output_type::patch_win_attn_scatter:
+        file_postfix = "scatter";
+        break;
+    case model_output_type::first_attn_layer:
+        file_postfix = "first_attn";
+        break;
+    case model_output_type::last_attn_layer:
+        file_postfix = "last_attn";
+        break;
+    case model_output_type::attn_softmax:
+        file_postfix = "attn_softmax";
+        break;
+    case model_output_type::final_layer:
+        file_postfix = "final";
+        break;
+    default:
+        break;
+    }
+    auto output_path = "img_embed_" + file_postfix + ".bin";
+
+    std::ofstream outFile(output_path, std::ios::binary);
     if (outFile.is_open()) {
         outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
         outFile.close();
-        std::cout << "Data successfully written to mrope.bin" << std::endl;
+        std::cout << "Data successfully written to " << output_path << std::endl;
     } else {
         std::cerr << "Error opening file!" << std::endl;
     }
@@ -551,8 +608,9 @@ int main(int argc, char ** argv) {
     } else if (params.image[0].empty()) {
         auto ctx_llava = llava_init_context(&params, model);

-        debug_test_mrope_2d();
-        debug_dump_img_embed(ctx_llava);
+        // debug_test_mrope_2d();
+        debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
+        // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);

         llama_perf_context_print(ctx_llava->ctx_llama);
         ctx_llava->model = NULL;
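
As a sanity check on the buffer sizes in debug_dump_img_embed above: for the final_layer dump with ih = 140, iw = 196, 14-pixel patches, and merge = 2, the arithmetic works out as in this standalone Python sketch:

# Sketch of the buffer-size arithmetic used by debug_dump_img_embed above.
ih, iw = 140, 196   # debug image size in pixels
patch = 14          # ViT patch size
merge = 2           # 2x2 patch merger applies for the final layer
n_embd = 2048       # output embedding width for the final layer

tokens = (ih // patch // merge) * (iw // patch // merge)  # 5 * 7 = 35 merged tokens
ne = tokens * n_embd                                      # 71680 floats in img_embed_final.bin
print(tokens, ne)  # 35 71680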