add tentative support for qwen2.5vl vision from HimariO fork

Concedo 2025-03-29 22:52:43 +08:00
parent 396875e1c4
commit 911669087a


@@ -39,6 +39,7 @@
#include <sstream>
#include <cinttypes>
#include <limits>
#include <numeric>
#if defined(LLAVA_LOG_OFF)
# define LOG_INF(...)
@@ -102,6 +103,8 @@ static std::string format(const char * fmt, ...) {
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_USE_GLU_MLP "clip.use_glu_mlp"
#define KEY_USE_RMS_NORM "clip.use_rms_norm"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
@@ -120,6 +123,8 @@ static std::string format(const char * fmt, ...) {
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes"
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
//
@@ -138,6 +143,7 @@ static std::string format(const char * fmt, ...) {
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_PRE "%s.pre_ln.%s"
@@ -447,6 +453,8 @@ struct clip_hparams {
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size;
std::vector<int32_t> full_attn_layers;
};
struct clip_layer {
@@ -472,6 +480,9 @@ struct clip_layer {
struct ggml_tensor * ff_o_w;
struct ggml_tensor * ff_o_b;
struct ggml_tensor * ff_g_w = NULL;
struct ggml_tensor * ff_g_b = NULL;
// layernorm 2
struct ggml_tensor * ln_2_w;
struct ggml_tensor * ln_2_b;
@@ -601,6 +612,8 @@ struct clip_ctx {
float image_std[3];
bool use_gelu = false;
bool use_silu = false;
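// qwen2.5vl vision encoder: gated (GLU) MLP and RMS norm in place of the usual layernorm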
bool use_glu_mlp = false;
bool use_rms_norm = false;
int32_t ftype = 1;
bool has_class_embedding = true;
@@ -856,6 +869,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const float eps = hparams.eps;
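// qwen2.5vl: when full_attn_layers is set, only those blocks use full attention; the rest attend within windows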
const bool use_window_attn = hparams.full_attn_layers.size() > 0;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
const int batch_size = imgs->size;
@@ -906,8 +920,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
inp = ggml_add(ctx0, inp, model.patch_bias);
}
struct ggml_tensor * embeddings = inp;
struct ggml_tensor * pos_embed = nullptr;
struct ggml_tensor * window_mask = nullptr;
struct ggml_tensor * window_idx = nullptr;
struct ggml_tensor * inv_window_idx = nullptr;
if (ctx->has_llava_projector) {
// concat class_embeddings and patch_embeddings
@@ -949,16 +966,41 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
// pre-layernorm
if (ctx->has_pre_norm) {
if (ctx->use_rms_norm) {
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "pre_ln");
embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
} else {
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "pre_ln");
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
}
}
std::vector<struct ggml_tensor *> embedding_stack;
const auto & vision_feature_layer = hparams.vision_feature_layer;
// loop over layers
if (use_window_attn) {
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
ggml_set_name(inv_window_idx, "inv_window_idx");
ggml_set_input(inv_window_idx);
// mask for window attention
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
ggml_set_name(window_mask, "window_mask");
ggml_set_input(window_mask);
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
GGML_ASSERT(batch_size == 1);
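// gather rows so that all tokens belonging to the same attention window become contiguous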
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
}
for (int il = 0; il < ctx->max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@@ -971,9 +1013,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
//const size_t nb_q_w = model.layers[il].q_w->nb[0];
// layernorm1
{
if (ctx->use_rms_norm) {
cur = ggml_rms_norm(ctx0, cur, eps);
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
}
else {
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
model.layers[il].ln_1_b);
}
@@ -1014,7 +1059,15 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
const bool full_attn = use_window_attn ? inlist : true;
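// cross-window entries of window_mask hold the lowest float, so the softmax only attends within each window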
if (full_attn) {
KQ = ggml_soft_max_inplace(ctx0, KQ);
} else {
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f);
}
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -1031,25 +1084,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
embeddings = cur; // embeddings = residual, cur = hidden_states
// layernorm2
{
if (ctx->use_rms_norm) {
cur = ggml_rms_norm(ctx0, cur, eps);
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
} else {
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
}
// mlp
if (ctx->use_glu_mlp) {
// ffn_up
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
if (ctx->use_gelu) {
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
} else if (ctx->use_silu) {
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
} else {
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
}
cur = ggml_mul(ctx0, cur_gate, cur_up);
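// GLU gating: the activated gate projection scales the up projection elementwise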
// ffn_down
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
}
else {
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
if (ctx->use_gelu) {
cur = ggml_gelu_inplace(ctx0, cur);
} else if (ctx->use_silu) {
cur = ggml_silu_inplace(ctx0, cur);
} else {
cur = ggml_gelu_quick_inplace(ctx0, cur);
}
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
}
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
@@ -1059,10 +1137,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
// post-layernorm
if (ctx->has_post_norm) {
if (ctx->use_rms_norm) {
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
} else {
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
}
// final layer is a vision feature layer
@@ -1375,6 +1460,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
}
if (use_window_attn) {
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
ggml_set_name(window_idx, "window_idx");
ggml_set_input(window_idx);
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
GGML_ASSERT(batch_size == 1);
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
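// window_idx is the inverse permutation of inv_window_idx: it returns the merged tokens to their original spatial order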
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
@@ -1569,6 +1666,20 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
new_clip->use_silu = false;
}
try {
idx = get_key_idx(ctx, KEY_USE_GLU_MLP);
new_clip->use_glu_mlp = gguf_get_val_bool(ctx, idx);
} catch (std::runtime_error & /*e*/) {
new_clip->use_glu_mlp = false;
}
try {
idx = get_key_idx(ctx, KEY_USE_RMS_NORM);
new_clip->use_rms_norm = gguf_get_val_bool(ctx, idx);
} catch (std::runtime_error & /*e*/) {
new_clip->use_rms_norm = false;
}
if (verbosity >= 1) {
LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
@@ -1703,6 +1814,18 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean);
const float * std_data = (const float *)gguf_get_arr_data(ctx, idx_std);
try {
int idx_full_attn_layers = get_key_idx(ctx, KEY_FULLATTN_BLK_IDX);
auto n_full_attn_layers = gguf_get_arr_n(ctx, idx_full_attn_layers);
const int * full_attn_layers = (const int *)gguf_get_arr_data(ctx, idx_full_attn_layers);
hparams.full_attn_layers.assign(full_attn_layers, full_attn_layers + n_full_attn_layers);
int idx_window_size = get_key_idx(ctx, KEY_ATTN_WINDOW_SIZE);
hparams.attn_window_size = gguf_get_val_u32(ctx, idx_window_size);
} catch (std::runtime_error & /*e*/) {
hparams.attn_window_size = 0;
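// keys are absent for models without windowed attention; a window size of 0 disables the window path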
}
for (int i = 0; i < 3; ++i) {
new_clip->image_mean[i] = mean_data[i];
new_clip->image_std[i] = std_data[i];
@@ -1753,8 +1876,15 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
}
try {
vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
new_clip->has_post_norm = true;
} catch (std::exception & /*e*/) {
new_clip->has_post_norm = false;
}
try {
// in case of rms norm, there is only the ln weight
vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
new_clip->has_post_norm = true;
} catch (std::exception & /*e*/) {
new_clip->has_post_norm = false;
@@ -1914,10 +2044,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
layer.q_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "bias"));
layer.v_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "bias"));
layer.o_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias"));
layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "bias"));
layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "bias"));
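// rms norm layers carry no bias tensors, so the ln biases are only loaded for non-rms models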
if (!new_clip->use_rms_norm) {
layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias"));
layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias"));
}
if (new_clip->use_glu_mlp) {
layer.ff_g_w = get_tensor(new_clip->ctx_data, format(TN_FFN_GATE, "v", il, "weight"));
layer.ff_g_b = get_tensor(new_clip->ctx_data, format(TN_FFN_GATE, "v", il, "bias"));
}
}
}
@@ -3024,30 +3161,96 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
if (ctx->has_qwen2vl_merger) {
/*
pw * ph = number of tokens output by the ViT after applying the patch merger
ipw * iph = number of vision tokens processed inside the ViT
*/
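// e.g. a 448x448 image with patch_size 14: ipw = iph = 32 and pw = ph = 16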
const int merge_ratio = 2;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
const int ipw = image_size_width / patch_size;
const int iph = image_size_height / patch_size;
std::vector<int> idx(ph * pw);
std::vector<int> inv_idx(ph * pw);
if (hparams.attn_window_size > 0) {
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
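// initialize every entry to the lowest float (fully masked); pairs inside the same window are zeroed below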
int mask_row = 0;
for (int y = 0; y < ph; y+=grid_window)
{
for (int x = 0; x < pw; x+=grid_window)
{
const int win_h = std::min(grid_window, ph - y);
const int win_w = std::min(grid_window, pw - x);
const int dst_0 = dst;
// group all tokens belonging to the same window together (into a contiguous range)
for (int dy = 0; dy < win_h; dy++) {
for (int dx = 0; dx < win_w; dx++) {
const int src = (y + dy) * pw + (x + dx);
assert(src < (int)idx.size());
assert(dst < (int)inv_idx.size());
idx[src] = dst;
inv_idx[dst] = src;
dst++;
}
}
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
int row_offset = mask_row * (ipw * iph);
std::fill(
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
0.0);
mask_row++;
}
}
}
if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
} else {
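// no windowing: identity mapping keeps the original token order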
std::iota(idx.begin(), idx.end(), 0);
std::iota(inv_idx.begin(), inv_idx.end(), 0);
}
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
// const int pw = image_size_width / patch_size;
// const int ph = image_size_height / patch_size;
const int mpow = (merge_ratio * merge_ratio);
int* positions_data = (int*)malloc(ggml_nbytes(positions));
int ptr = 0;
for (int y = 0; y < iph; y+=merge_ratio)
{
for (int x = 0; x < ipw; x+=merge_ratio)
{
for (int dy = 0; dy < 2; dy++) {
for (int dx = 0; dx < 2; dx++) {
auto remap = idx[ptr / mpow];
remap = remap * mpow + (ptr % mpow);
positions_data[remap] = y + dy;
positions_data[num_patches + remap] = x + dx;
positions_data[num_patches * 2 + remap] = y + dy;
positions_data[num_patches * 3 + remap] = x + dx;
ptr++;
}
}
}
}
if (positions) ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
@@ -3079,6 +3282,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn?
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
const int merge_ratio = 2;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
const int ipw = image_size_width / patch_size;
const int iph = image_size_height / patch_size;
/*
pw * ph = number of tokens output by the ViT after applying the patch merger
ipw * iph = number of vision tokens processed inside the ViT
*/
std::vector<int> idx(ph * pw);
std::vector<int> inv_idx(ph * pw);
int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
int mask_row = 0;
for (int y = 0; y < ph; y+=grid_window)
{
for (int x = 0; x < pw; x+=grid_window)
{
const int win_h = std::min(grid_window, ph - y);
const int win_w = std::min(grid_window, pw - x);
const int dst_0 = dst;
// group all tokens belonging to the same window together (into a contiguous range)
for (int dy = 0; dy < win_h; dy++) {
for (int dx = 0; dx < win_w; dx++) {
const int src = (y + dy) * pw + (x + dx);
assert(src < (int)idx.size());
assert(dst < (int)inv_idx.size());
idx[src] = dst;
inv_idx[dst] = src;
dst++;
}
}
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
int row_offset = mask_row * (ipw * iph);
std::fill(
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
0.0);
mask_row++;
}
}
}
if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
}
if (ggml_backend_is_cpu(ctx->backend)) {
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
}