mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
merged qwen2.5vl again
This commit is contained in:
commit
88660dd59d
4 changed files with 439 additions and 117 deletions
|
@ -40,6 +40,7 @@
|
|||
#include <sstream>
|
||||
#include <cinttypes>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
|
||||
|
@ -196,6 +197,8 @@ struct clip_hparams {
|
|||
std::vector<int32_t> image_grid_pinpoints;
|
||||
int32_t image_crop_resolution;
|
||||
std::unordered_set<int32_t> vision_feature_layer;
|
||||
int32_t attn_window_size;
|
||||
std::vector<int32_t> full_attn_layers;
|
||||
};
|
||||
|
||||
struct clip_layer {
|
||||
|
@ -221,6 +224,9 @@ struct clip_layer {
|
|||
struct ggml_tensor * ff_o_w = nullptr;
|
||||
struct ggml_tensor * ff_o_b = nullptr;
|
||||
|
||||
struct ggml_tensor * ff_g_w = NULL;
|
||||
struct ggml_tensor * ff_g_b = NULL;
|
||||
|
||||
// layernorm 2
|
||||
struct ggml_tensor * ln_2_w = nullptr;
|
||||
struct ggml_tensor * ln_2_b = nullptr;
|
||||
|
@ -350,6 +356,8 @@ struct clip_ctx {
|
|||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
bool use_silu = false;
|
||||
bool use_glu_mlp = false;
|
||||
bool use_rms_norm = false;
|
||||
int32_t ftype = 1;
|
||||
|
||||
struct gguf_context * ctx_gguf = nullptr;
|
||||
|
@ -599,6 +607,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
const float eps = hparams.eps;
|
||||
const bool use_window_attn = hparams.full_attn_layers.size() > 0;
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
|
@ -628,6 +637,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
|
@ -649,8 +659,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||
}
|
||||
struct ggml_tensor * embeddings = inp;
|
||||
struct ggml_tensor * pos_embed = nullptr;
|
||||
struct ggml_tensor * embeddings = inp;
|
||||
struct ggml_tensor * pos_embed = nullptr;
|
||||
struct ggml_tensor * window_mask = nullptr;
|
||||
struct ggml_tensor * window_idx = nullptr;
|
||||
struct ggml_tensor * inv_window_idx = nullptr;
|
||||
|
||||
if (ctx->has_llava_projector) {
|
||||
// concat class_embeddings and patch_embeddings
|
||||
|
@ -692,16 +705,40 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
|
||||
// pre-layernorm
|
||||
if (model.pre_ln_w) {
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "pre_ln");
|
||||
if (ctx->use_rms_norm) {
|
||||
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "pre_ln");
|
||||
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
|
||||
embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
|
||||
} else {
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "pre_ln");
|
||||
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<struct ggml_tensor *> embedding_stack;
|
||||
const auto & vision_feature_layer = hparams.vision_feature_layer;
|
||||
|
||||
// loop over layers
|
||||
|
||||
if (use_window_attn) {
|
||||
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
|
||||
ggml_set_name(inv_window_idx, "inv_window_idx");
|
||||
ggml_set_input(inv_window_idx);
|
||||
// mask for window attention
|
||||
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
|
||||
ggml_set_name(window_mask, "window_mask");
|
||||
ggml_set_input(window_mask);
|
||||
|
||||
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
|
||||
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
|
||||
}
|
||||
|
||||
for (int il = 0; il < ctx->max_feature_layer; il++) {
|
||||
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
||||
|
||||
|
@ -714,9 +751,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
//const size_t nb_q_w = model.layers[il].q_w->nb[0];
|
||||
|
||||
// layernorm1
|
||||
{
|
||||
if (ctx->use_rms_norm) {
|
||||
cur = ggml_rms_norm(ctx0, cur, eps);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
|
||||
}
|
||||
else {
|
||||
cur = ggml_norm(ctx0, cur, eps);
|
||||
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
|
||||
model.layers[il].ln_1_b);
|
||||
}
|
||||
|
@ -756,7 +796,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||
const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
|
||||
const bool full_attn = use_window_attn ? inlist : true;
|
||||
if (full_attn) {
|
||||
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||
} else {
|
||||
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||
}
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
||||
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
@ -773,25 +820,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
embeddings = cur; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// layernorm2
|
||||
{
|
||||
if (ctx->use_rms_norm) {
|
||||
cur = ggml_rms_norm(ctx0, cur, eps);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
|
||||
} else {
|
||||
cur = ggml_norm(ctx0, cur, eps);
|
||||
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||
// mlp
|
||||
if (ctx->use_glu_mlp) {
|
||||
// ffn_up
|
||||
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
|
||||
|
||||
if (ctx->use_gelu) {
|
||||
cur = ggml_gelu_inplace(ctx0, cur);
|
||||
} else if (ctx->use_silu) {
|
||||
cur = ggml_silu_inplace(ctx0, cur);
|
||||
} else {
|
||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
|
||||
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
|
||||
if (ctx->use_gelu) {
|
||||
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
|
||||
} else if (ctx->use_silu) {
|
||||
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
|
||||
} else {
|
||||
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
|
||||
}
|
||||
cur = ggml_mul(ctx0, cur_gate, cur_up);
|
||||
|
||||
// ffn_down
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||
}
|
||||
else {
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
||||
if (ctx->use_gelu) {
|
||||
cur = ggml_gelu_inplace(ctx0, cur);
|
||||
} else if (ctx->use_silu) {
|
||||
cur = ggml_silu_inplace(ctx0, cur);
|
||||
} else {
|
||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
||||
}
|
||||
|
||||
// residual 2
|
||||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
@ -801,10 +873,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
|
||||
// post-layernorm
|
||||
if (model.post_ln_w) {
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "post_ln");
|
||||
if (ctx->use_rms_norm) {
|
||||
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "post_ln");
|
||||
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
|
||||
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
|
||||
} else {
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "post_ln");
|
||||
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
|
||||
}
|
||||
}
|
||||
|
||||
// final layer is a vision feature layer
|
||||
|
@ -1118,6 +1197,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||
}
|
||||
|
||||
if (use_window_attn) {
|
||||
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
|
||||
ggml_set_name(window_idx, "window_idx");
|
||||
ggml_set_input(window_idx);
|
||||
|
||||
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
|
||||
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
|
@ -1228,6 +1319,8 @@ struct clip_model_loader {
|
|||
|
||||
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
|
||||
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
|
||||
get_bool(KEY_USE_GLU_MLP, ctx_clip.use_glu_mlp, false);
|
||||
get_bool(KEY_USE_RMS_NORM, ctx_clip.use_rms_norm, false);
|
||||
|
||||
auto & hparams = ctx_clip.vision_model.hparams;
|
||||
get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);
|
||||
|
@ -1239,7 +1332,9 @@ struct clip_model_loader {
|
|||
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
|
||||
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
|
||||
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
|
||||
get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false);
|
||||
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
|
||||
get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false);
|
||||
|
||||
{
|
||||
std::string mm_patch_merge_type;
|
||||
|
@ -1355,14 +1450,16 @@ struct clip_model_loader {
|
|||
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
|
||||
layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
|
||||
layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
|
||||
layer.ff_g_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), ctx_clip.use_glu_mlp);
|
||||
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
|
||||
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
|
||||
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
|
||||
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
|
||||
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
|
||||
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
|
||||
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), !ctx_clip.use_rms_norm);
|
||||
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), !ctx_clip.use_rms_norm);
|
||||
layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
|
||||
layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
|
||||
layer.ff_g_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), ctx_clip.use_glu_mlp);
|
||||
}
|
||||
|
||||
switch (ctx_clip.proj_type) {
|
||||
|
@ -2631,6 +2728,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
|
||||
free(data);
|
||||
}
|
||||
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
{
|
||||
// inspired from siglip:
|
||||
|
@ -2694,23 +2792,86 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
/*
|
||||
pw * ph = number of tokens output by ViT after apply patch merger
|
||||
ipw * ipw = number of vision token been processed inside ViT
|
||||
*/
|
||||
const int merge_ratio = 2;
|
||||
const int pw = image_size_width / patch_size / merge_ratio;
|
||||
const int ph = image_size_height / patch_size / merge_ratio;
|
||||
const int ipw = image_size_width / patch_size;
|
||||
const int iph = image_size_height / patch_size;
|
||||
|
||||
const int pw = image_size_width / patch_size;
|
||||
const int ph = image_size_height / patch_size;
|
||||
std::vector<int> idx(ph * pw);
|
||||
std::vector<int> inv_idx(ph * pw);
|
||||
|
||||
if (hparams.attn_window_size > 0) {
|
||||
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
||||
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
||||
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
||||
|
||||
const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
|
||||
int dst = 0;
|
||||
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||
int mask_row = 0;
|
||||
|
||||
for (int y = 0; y < ph; y+=grid_window)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=grid_window)
|
||||
{
|
||||
const int win_h = std::min(grid_window, ph - y);
|
||||
const int win_w = std::min(grid_window, pw - x);
|
||||
const int dst_0 = dst;
|
||||
// group all tokens belong to the same window togather (to a continue range)
|
||||
for (int dy = 0; dy < win_h; dy++) {
|
||||
for (int dx = 0; dx < win_w; dx++) {
|
||||
const int src = (y + dy) * pw + (x + dx);
|
||||
assert(src < (int)idx.size());
|
||||
assert(dst < (int)inv_idx.size());
|
||||
idx[src] = dst;
|
||||
inv_idx[dst] = src;
|
||||
dst++;
|
||||
}
|
||||
}
|
||||
|
||||
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||
int row_offset = mask_row * (ipw * iph);
|
||||
std::fill(
|
||||
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||
0.0);
|
||||
mask_row++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
||||
if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
||||
if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
||||
} else {
|
||||
std::iota(idx.begin(), idx.end(), 0);
|
||||
std::iota(inv_idx.begin(), inv_idx.end(), 0);
|
||||
}
|
||||
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
const int mpow = (merge_ratio * merge_ratio);
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
|
||||
int ptr = 0;
|
||||
for (int y = 0; y < ph; y+=2)
|
||||
for (int y = 0; y < iph; y+=merge_ratio)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=2)
|
||||
for (int x = 0; x < ipw; x+=merge_ratio)
|
||||
{
|
||||
for (int dy = 0; dy < 2; dy++) {
|
||||
for (int dx = 0; dx < 2; dx++) {
|
||||
positions_data[ptr] = y + dy;
|
||||
positions_data[num_patches + ptr] = x + dx;
|
||||
positions_data[num_patches * 2 + ptr] = y + dy;
|
||||
positions_data[num_patches * 3 + ptr] = x + dx;
|
||||
auto remap = idx[ptr / mpow];
|
||||
remap = remap * mpow + (ptr % mpow);
|
||||
|
||||
positions_data[remap] = y + dy;
|
||||
positions_data[num_patches + remap] = x + dx;
|
||||
positions_data[num_patches * 2 + remap] = y + dy;
|
||||
positions_data[num_patches * 3 + remap] = x + dx;
|
||||
ptr++;
|
||||
}
|
||||
}
|
||||
|
@ -2749,6 +2910,64 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
}
|
||||
|
||||
if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn?
|
||||
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
||||
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
||||
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
||||
|
||||
const int merge_ratio = 2;
|
||||
const int pw = image_size_width / patch_size / merge_ratio;
|
||||
const int ph = image_size_height / patch_size / merge_ratio;
|
||||
const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
|
||||
const int ipw = image_size_width / patch_size;
|
||||
const int iph = image_size_height / patch_size;
|
||||
/*
|
||||
pw * ph = number of tokens output by ViT after apply patch merger
|
||||
ipw * ipw = number of vision token been processed inside ViT
|
||||
*/
|
||||
|
||||
std::vector<int> idx(ph * pw);
|
||||
std::vector<int> inv_idx(ph * pw);
|
||||
int dst = 0;
|
||||
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||
int mask_row = 0;
|
||||
|
||||
for (int y = 0; y < ph; y+=grid_window)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=grid_window)
|
||||
{
|
||||
const int win_h = std::min(grid_window, ph - y);
|
||||
const int win_w = std::min(grid_window, pw - x);
|
||||
const int dst_0 = dst;
|
||||
// group all tokens belong to the same window togather (to a continue range)
|
||||
for (int dy = 0; dy < win_h; dy++) {
|
||||
for (int dx = 0; dx < win_w; dx++) {
|
||||
const int src = (y + dy) * pw + (x + dx);
|
||||
assert(src < (int)idx.size());
|
||||
assert(dst < (int)inv_idx.size());
|
||||
idx[src] = dst;
|
||||
inv_idx[dst] = src;
|
||||
dst++;
|
||||
}
|
||||
}
|
||||
|
||||
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||
int row_offset = mask_row * (ipw * iph);
|
||||
std::fill(
|
||||
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||
0.0);
|
||||
mask_row++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
||||
if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
||||
if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
||||
}
|
||||
|
||||
if (ggml_backend_is_cpu(ctx->backend)) {
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue