mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
wip qwen25vl merge
This commit is contained in:
commit
1b0481f4b1
5 changed files with 418 additions and 24 deletions
|
@ -2554,11 +2554,12 @@ class Qwen2VLModel(TextModel):
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
self._set_vocab_gpt2()
|
self._set_vocab_gpt2()
|
||||||
|
|
||||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
for name, data in super().get_tensors():
|
del bid # unused
|
||||||
if name.startswith("visual."):
|
if name.startswith("visual."):
|
||||||
continue
|
# skip visual tensors
|
||||||
yield name, data
|
return []
|
||||||
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
@ModelBase.register("WavTokenizerDec")
|
@ModelBase.register("WavTokenizerDec")
|
||||||
|
|
|
@ -34,9 +34,14 @@
|
||||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||||
|
|
||||||
|
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
|
||||||
|
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
|
||||||
|
|
||||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||||
|
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
||||||
|
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -55,6 +60,7 @@
|
||||||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||||
|
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||||
|
@ -95,6 +101,7 @@ enum projector_type {
|
||||||
PROJECTOR_TYPE_GEMMA3,
|
PROJECTOR_TYPE_GEMMA3,
|
||||||
PROJECTOR_TYPE_IDEFICS3,
|
PROJECTOR_TYPE_IDEFICS3,
|
||||||
PROJECTOR_TYPE_PIXTRAL,
|
PROJECTOR_TYPE_PIXTRAL,
|
||||||
|
PROJECTOR_TYPE_QWEN25VL,
|
||||||
PROJECTOR_TYPE_UNKNOWN,
|
PROJECTOR_TYPE_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -105,6 +112,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||||
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
|
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
|
||||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||||
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
|
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
|
||||||
|
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
|
||||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||||
|
|
|
@ -42,6 +42,7 @@
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||||
|
|
||||||
|
@ -183,6 +184,8 @@ struct clip_hparams {
|
||||||
std::vector<int32_t> image_grid_pinpoints;
|
std::vector<int32_t> image_grid_pinpoints;
|
||||||
int32_t image_crop_resolution;
|
int32_t image_crop_resolution;
|
||||||
std::unordered_set<int32_t> vision_feature_layer;
|
std::unordered_set<int32_t> vision_feature_layer;
|
||||||
|
int32_t attn_window_size;
|
||||||
|
int32_t n_wa_pattern;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct clip_layer {
|
struct clip_layer {
|
||||||
|
@ -214,6 +217,9 @@ struct clip_layer {
|
||||||
struct ggml_tensor * ff_down_w = nullptr;
|
struct ggml_tensor * ff_down_w = nullptr;
|
||||||
struct ggml_tensor * ff_down_b = nullptr;
|
struct ggml_tensor * ff_down_b = nullptr;
|
||||||
|
|
||||||
|
struct ggml_tensor * ff_g_w = NULL;
|
||||||
|
struct ggml_tensor * ff_g_b = NULL;
|
||||||
|
|
||||||
// layernorm 2
|
// layernorm 2
|
||||||
struct ggml_tensor * ln_2_w = nullptr;
|
struct ggml_tensor * ln_2_w = nullptr;
|
||||||
struct ggml_tensor * ln_2_b = nullptr;
|
struct ggml_tensor * ln_2_b = nullptr;
|
||||||
|
@ -339,6 +345,7 @@ struct clip_ctx {
|
||||||
float image_std[3];
|
float image_std[3];
|
||||||
bool use_gelu = false;
|
bool use_gelu = false;
|
||||||
bool use_silu = false;
|
bool use_silu = false;
|
||||||
|
int32_t ftype = 1;
|
||||||
|
|
||||||
gguf_context_ptr ctx_gguf;
|
gguf_context_ptr ctx_gguf;
|
||||||
ggml_context_ptr ctx_data;
|
ggml_context_ptr ctx_data;
|
||||||
|
@ -787,6 +794,236 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
||||||
|
const auto & model = ctx->vision_model;
|
||||||
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
const int image_size_width = imgs.entries[0]->nx;
|
||||||
|
const int image_size_height = imgs.entries[0]->ny;
|
||||||
|
|
||||||
|
const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
|
||||||
|
const bool use_window_attn = hparams.n_wa_pattern > 0;
|
||||||
|
|
||||||
|
const int n_wa_pattern = hparams.n_wa_pattern;
|
||||||
|
const int patch_size = hparams.patch_size;
|
||||||
|
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||||
|
const int patches_w = image_size_width / patch_size;
|
||||||
|
const int patches_h = image_size_height / patch_size;
|
||||||
|
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||||
|
const int num_position_ids = use_mrope ? num_positions * 4 : num_positions;
|
||||||
|
const int hidden_size = hparams.hidden_size;
|
||||||
|
const int n_head = hparams.n_head;
|
||||||
|
const int d_head = hidden_size / n_head;
|
||||||
|
const float eps = hparams.eps;
|
||||||
|
|
||||||
|
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||||
|
|
||||||
|
const int batch_size = imgs.entries.size();
|
||||||
|
GGML_ASSERT(batch_size == 1);
|
||||||
|
|
||||||
|
struct ggml_init_params params = {
|
||||||
|
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
||||||
|
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
|
||||||
|
/*.no_alloc =*/ true,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_context_ptr ctx0_ptr(ggml_init(params));
|
||||||
|
auto ctx0 = ctx0_ptr.get();
|
||||||
|
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
|
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
|
||||||
|
ggml_set_name(inp_raw, "inp_raw");
|
||||||
|
ggml_set_input(inp_raw);
|
||||||
|
|
||||||
|
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||||
|
|
||||||
|
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
|
||||||
|
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
|
||||||
|
|
||||||
|
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||||
|
inp = ggml_add(ctx0, inp, inp_1);
|
||||||
|
|
||||||
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||||
|
inp = ggml_reshape_4d(
|
||||||
|
ctx0, inp,
|
||||||
|
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
||||||
|
inp = ggml_reshape_4d(
|
||||||
|
ctx0, inp,
|
||||||
|
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||||
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||||
|
inp = ggml_reshape_3d(
|
||||||
|
ctx0, inp,
|
||||||
|
hidden_size, patches_w * patches_h, batch_size);
|
||||||
|
|
||||||
|
if (model.patch_bias) {
|
||||||
|
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||||
|
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||||
|
}
|
||||||
|
struct ggml_tensor * embeddings = inp;
|
||||||
|
struct ggml_tensor * window_mask = nullptr;
|
||||||
|
struct ggml_tensor * window_idx = nullptr;
|
||||||
|
struct ggml_tensor * inv_window_idx = nullptr;
|
||||||
|
|
||||||
|
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
|
||||||
|
ggml_set_name(positions, "positions");
|
||||||
|
ggml_set_input(positions);
|
||||||
|
|
||||||
|
// pre-layernorm
|
||||||
|
if (model.pre_ln_w) {
|
||||||
|
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||||
|
ggml_set_name(embeddings, "pre_ln");
|
||||||
|
|
||||||
|
embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_window_attn) {
|
||||||
|
// handle window attention inputs
|
||||||
|
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
|
||||||
|
ggml_set_name(inv_window_idx, "inv_window_idx");
|
||||||
|
ggml_set_input(inv_window_idx);
|
||||||
|
// mask for window attention
|
||||||
|
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
|
||||||
|
ggml_set_name(window_mask, "window_mask");
|
||||||
|
ggml_set_input(window_mask);
|
||||||
|
|
||||||
|
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
||||||
|
GGML_ASSERT(batch_size == 1);
|
||||||
|
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
|
||||||
|
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
|
||||||
|
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over layers
|
||||||
|
for (int il = 0; il < ctx->max_feature_layer; il++) {
|
||||||
|
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
||||||
|
|
||||||
|
// rmsnorm1
|
||||||
|
cur = ggml_rms_norm(ctx0, cur, eps);
|
||||||
|
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
|
||||||
|
struct ggml_tensor * Q =
|
||||||
|
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||||
|
|
||||||
|
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||||
|
Q = ggml_rope_multi(
|
||||||
|
ctx0, Q, positions, nullptr,
|
||||||
|
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||||
|
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||||
|
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
||||||
|
|
||||||
|
struct ggml_tensor * K =
|
||||||
|
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||||
|
|
||||||
|
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||||
|
K = ggml_rope_multi(
|
||||||
|
ctx0, K, positions, nullptr,
|
||||||
|
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||||
|
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||||
|
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
||||||
|
|
||||||
|
struct ggml_tensor * V =
|
||||||
|
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
|
||||||
|
|
||||||
|
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
|
||||||
|
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
|
||||||
|
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
|
||||||
|
|
||||||
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||||
|
const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
|
||||||
|
if (full_attn) {
|
||||||
|
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||||
|
} else {
|
||||||
|
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
||||||
|
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
||||||
|
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
|
cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// attention output
|
||||||
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
|
||||||
|
|
||||||
|
// re-add the layer input, e.g., residual
|
||||||
|
cur = ggml_add(ctx0, cur, embeddings);
|
||||||
|
|
||||||
|
embeddings = cur; // embeddings = residual, cur = hidden_states
|
||||||
|
|
||||||
|
// rms norm2
|
||||||
|
cur = ggml_rms_norm(ctx0, cur, eps);
|
||||||
|
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
|
||||||
|
|
||||||
|
// mlp
|
||||||
|
// ffn_up
|
||||||
|
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||||
|
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
|
||||||
|
|
||||||
|
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
|
||||||
|
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
|
||||||
|
// TODO : only 2 of these 3 are actually used, should we remove one of them?
|
||||||
|
if (ctx->use_gelu) {
|
||||||
|
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
|
||||||
|
} else if (ctx->use_silu) {
|
||||||
|
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
|
||||||
|
} else {
|
||||||
|
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
|
||||||
|
}
|
||||||
|
cur = ggml_mul(ctx0, cur_gate, cur_up);
|
||||||
|
|
||||||
|
// ffn_down
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||||
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||||
|
|
||||||
|
// residual 2
|
||||||
|
cur = ggml_add(ctx0, embeddings, cur);
|
||||||
|
|
||||||
|
embeddings = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
// post-layernorm
|
||||||
|
if (model.post_ln_w) {
|
||||||
|
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||||
|
ggml_set_name(embeddings, "post_ln");
|
||||||
|
|
||||||
|
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||||
|
|
||||||
|
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||||
|
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||||
|
|
||||||
|
// GELU activation
|
||||||
|
embeddings = ggml_gelu(ctx0, embeddings);
|
||||||
|
|
||||||
|
// Second linear layer
|
||||||
|
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||||
|
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||||
|
|
||||||
|
if (use_window_attn) {
|
||||||
|
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
|
||||||
|
ggml_set_name(window_idx, "window_idx");
|
||||||
|
ggml_set_input(window_idx);
|
||||||
|
|
||||||
|
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
||||||
|
GGML_ASSERT(batch_size == 1);
|
||||||
|
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
|
||||||
|
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
|
||||||
|
embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// build the graph
|
||||||
|
ggml_build_forward_expand(gf, embeddings);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||||
const auto & model = ctx->vision_model;
|
const auto & model = ctx->vision_model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
@ -1356,6 +1593,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
GGML_ASSERT(imgs.entries.size() == 1);
|
GGML_ASSERT(imgs.entries.size() == 1);
|
||||||
res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
|
res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
|
||||||
} break;
|
} break;
|
||||||
|
case PROJECTOR_TYPE_QWEN25VL:
|
||||||
|
{
|
||||||
|
res = clip_image_build_graph_qwen25vl(ctx, imgs);
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
// TODO: we should have one build_* function per model
|
// TODO: we should have one build_* function per model
|
||||||
|
@ -1532,6 +1773,10 @@ struct clip_model_loader {
|
||||||
{
|
{
|
||||||
hparams.rope_theta = 10000.0f;
|
hparams.rope_theta = 10000.0f;
|
||||||
} break;
|
} break;
|
||||||
|
case PROJECTOR_TYPE_QWEN25VL:
|
||||||
|
{
|
||||||
|
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -1625,8 +1870,10 @@ struct clip_model_loader {
|
||||||
// legacy naming (the in and out is reversed! don't ask me why)
|
// legacy naming (the in and out is reversed! don't ask me why)
|
||||||
layer.ff_i_w = layer.ff_down_w;
|
layer.ff_i_w = layer.ff_down_w;
|
||||||
layer.ff_o_w = layer.ff_up_w;
|
layer.ff_o_w = layer.ff_up_w;
|
||||||
|
layer.ff_g_w = layer.ff_gate_w;
|
||||||
layer.ff_i_b = layer.ff_down_b;
|
layer.ff_i_b = layer.ff_down_b;
|
||||||
layer.ff_o_b = layer.ff_up_b;
|
layer.ff_o_b = layer.ff_up_b;
|
||||||
|
layer.ff_g_b = layer.ff_gate_b;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (ctx_clip.proj_type) {
|
switch (ctx_clip.proj_type) {
|
||||||
|
@ -1725,6 +1972,7 @@ struct clip_model_loader {
|
||||||
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
|
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
|
||||||
} break;
|
} break;
|
||||||
case PROJECTOR_TYPE_QWEN2VL:
|
case PROJECTOR_TYPE_QWEN2VL:
|
||||||
|
case PROJECTOR_TYPE_QWEN25VL:
|
||||||
{
|
{
|
||||||
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
||||||
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
|
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
|
||||||
|
@ -2767,7 +3015,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
||||||
else {
|
else {
|
||||||
GGML_ABORT("Unknown minicpmv version");
|
GGML_ABORT("Unknown minicpmv version");
|
||||||
}
|
}
|
||||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
|
||||||
int patch_size = params.patch_size * 2;
|
int patch_size = params.patch_size * 2;
|
||||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||||
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
||||||
|
@ -2908,6 +3156,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
const int pos_w = ctx->load_image_size.width / patch_size;
|
const int pos_w = ctx->load_image_size.width / patch_size;
|
||||||
const int pos_h = ctx->load_image_size.height / patch_size;
|
const int pos_h = ctx->load_image_size.height / patch_size;
|
||||||
|
|
||||||
|
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
|
||||||
|
|
||||||
{
|
{
|
||||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
||||||
std::vector<float> inp_data(ggml_nelements(inp_raw));
|
std::vector<float> inp_data(ggml_nelements(inp_raw));
|
||||||
|
@ -3006,31 +3256,93 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
// non-minicpmv models
|
// non-minicpmv models
|
||||||
|
|
||||||
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
||||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
// pw * ph = number of tokens output by ViT after apply patch merger
|
||||||
|
// ipw * ipw = number of vision token been processed inside ViT
|
||||||
|
const int merge_ratio = 2;
|
||||||
|
const int pw = image_size_width / patch_size / merge_ratio;
|
||||||
|
const int ph = image_size_height / patch_size / merge_ratio;
|
||||||
|
const int ipw = image_size_width / patch_size;
|
||||||
|
const int iph = image_size_height / patch_size;
|
||||||
|
|
||||||
const int pw = image_size_width / patch_size;
|
std::vector<int> idx (ph * pw);
|
||||||
const int ph = image_size_height / patch_size;
|
std::vector<int> inv_idx(ph * pw);
|
||||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
|
||||||
|
if (use_window_attn) {
|
||||||
|
const int attn_window_size = 112;
|
||||||
|
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
||||||
|
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
||||||
|
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
||||||
|
|
||||||
|
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
||||||
|
int dst = 0;
|
||||||
|
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||||
|
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||||
|
int mask_row = 0;
|
||||||
|
|
||||||
|
for (int y = 0; y < ph; y += grid_window)
|
||||||
|
{
|
||||||
|
for (int x = 0; x < pw; x += grid_window)
|
||||||
|
{
|
||||||
|
const int win_h = std::min(grid_window, ph - y);
|
||||||
|
const int win_w = std::min(grid_window, pw - x);
|
||||||
|
const int dst_0 = dst;
|
||||||
|
// group all tokens belong to the same window togather (to a continue range)
|
||||||
|
for (int dy = 0; dy < win_h; dy++) {
|
||||||
|
for (int dx = 0; dx < win_w; dx++) {
|
||||||
|
const int src = (y + dy) * pw + (x + dx);
|
||||||
|
assert(src < (int)idx.size());
|
||||||
|
assert(dst < (int)inv_idx.size());
|
||||||
|
idx [src] = dst;
|
||||||
|
inv_idx[dst] = src;
|
||||||
|
dst++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||||
|
int row_offset = mask_row * (ipw * iph);
|
||||||
|
std::fill(
|
||||||
|
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||||
|
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||||
|
0.0);
|
||||||
|
mask_row++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
||||||
|
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
||||||
|
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
||||||
|
} else {
|
||||||
|
std::iota(idx.begin(), idx.end(), 0);
|
||||||
|
std::iota(inv_idx.begin(), inv_idx.end(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||||
|
const int mpow = merge_ratio * merge_ratio;
|
||||||
|
std::vector<int> positions_data(ggml_nelements(positions));
|
||||||
|
int * data = positions_data.data();
|
||||||
|
|
||||||
int ptr = 0;
|
int ptr = 0;
|
||||||
for (int y = 0; y < ph; y+=2)
|
for (int y = 0; y < iph; y += merge_ratio)
|
||||||
{
|
{
|
||||||
for (int x = 0; x < pw; x+=2)
|
for (int x = 0; x < ipw; x += merge_ratio)
|
||||||
{
|
{
|
||||||
for (int dy = 0; dy < 2; dy++) {
|
for (int dy = 0; dy < 2; dy++) {
|
||||||
for (int dx = 0; dx < 2; dx++) {
|
for (int dx = 0; dx < 2; dx++) {
|
||||||
positions_data[ptr] = y + dy;
|
auto remap = idx[ptr / mpow];
|
||||||
positions_data[num_patches + ptr] = x + dx;
|
remap = remap * mpow + (ptr % mpow);
|
||||||
positions_data[num_patches * 2 + ptr] = y + dy;
|
|
||||||
positions_data[num_patches * 3 + ptr] = x + dx;
|
data[ remap] = y + dy;
|
||||||
|
data[ num_patches + remap] = x + dx;
|
||||||
|
data[2 * num_patches + remap] = y + dy;
|
||||||
|
data[3 * num_patches + remap] = x + dx;
|
||||||
ptr++;
|
ptr++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
|
||||||
free(positions_data);
|
|
||||||
}
|
}
|
||||||
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||||
// do nothing
|
// do nothing
|
||||||
|
@ -3083,6 +3395,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
|
||||||
|
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
||||||
|
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
||||||
|
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
||||||
|
|
||||||
|
const int merge_ratio = 2;
|
||||||
|
const int attn_window_size = 112;
|
||||||
|
const int pw = image_size_width / patch_size / merge_ratio;
|
||||||
|
const int ph = image_size_height / patch_size / merge_ratio;
|
||||||
|
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
||||||
|
const int ipw = image_size_width / patch_size;
|
||||||
|
const int iph = image_size_height / patch_size;
|
||||||
|
/*
|
||||||
|
pw * ph = number of tokens output by ViT after apply patch merger
|
||||||
|
ipw * ipw = number of vision token been processed inside ViT
|
||||||
|
*/
|
||||||
|
|
||||||
|
std::vector<int> idx(ph * pw);
|
||||||
|
std::vector<int> inv_idx(ph * pw);
|
||||||
|
int dst = 0;
|
||||||
|
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||||
|
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||||
|
int mask_row = 0;
|
||||||
|
|
||||||
|
for (int y = 0; y < ph; y+=grid_window)
|
||||||
|
{
|
||||||
|
for (int x = 0; x < pw; x+=grid_window)
|
||||||
|
{
|
||||||
|
const int win_h = std::min(grid_window, ph - y);
|
||||||
|
const int win_w = std::min(grid_window, pw - x);
|
||||||
|
const int dst_0 = dst;
|
||||||
|
// group all tokens belong to the same window togather (to a continue range)
|
||||||
|
for (int dy = 0; dy < win_h; dy++) {
|
||||||
|
for (int dx = 0; dx < win_w; dx++) {
|
||||||
|
const int src = (y + dy) * pw + (x + dx);
|
||||||
|
assert(src < (int)idx.size());
|
||||||
|
assert(dst < (int)inv_idx.size());
|
||||||
|
idx[src] = dst;
|
||||||
|
inv_idx[dst] = src;
|
||||||
|
dst++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||||
|
int row_offset = mask_row * (ipw * iph);
|
||||||
|
std::fill(
|
||||||
|
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||||
|
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||||
|
0.0);
|
||||||
|
mask_row++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
||||||
|
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
||||||
|
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
||||||
|
}
|
||||||
|
|
||||||
if (ggml_backend_is_cpu(ctx->backend)) {
|
if (ggml_backend_is_cpu(ctx->backend)) {
|
||||||
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
||||||
}
|
}
|
||||||
|
@ -3275,6 +3646,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||||
case PROJECTOR_TYPE_GLM_EDGE:
|
case PROJECTOR_TYPE_GLM_EDGE:
|
||||||
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
|
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
|
||||||
case PROJECTOR_TYPE_QWEN2VL:
|
case PROJECTOR_TYPE_QWEN2VL:
|
||||||
|
case PROJECTOR_TYPE_QWEN25VL:
|
||||||
return ctx->vision_model.mm_1_b->ne[0];
|
return ctx->vision_model.mm_1_b->ne[0];
|
||||||
case PROJECTOR_TYPE_GEMMA3:
|
case PROJECTOR_TYPE_GEMMA3:
|
||||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Dict
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -20,6 +20,20 @@ VISION = "clip.vision"
|
||||||
def k(raw_key: str, arch: str) -> str:
|
def k(raw_key: str, arch: str) -> str:
|
||||||
return raw_key.format(arch=arch)
|
return raw_key.format(arch=arch)
|
||||||
|
|
||||||
|
|
||||||
|
def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
|
||||||
|
if fullatt_block_indexes is None:
|
||||||
|
return 0
|
||||||
|
n_wa = fullatt_block_indexes[0]
|
||||||
|
for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
|
||||||
|
if b - a - 1 != n_wa:
|
||||||
|
raise ValueError(
|
||||||
|
f"window/full attention layer should have fix pattern of "
|
||||||
|
f"for each full-attention layer followed by {n_wa} window-attention layers"
|
||||||
|
)
|
||||||
|
return n_wa + 1
|
||||||
|
|
||||||
|
|
||||||
class VL2:
|
class VL2:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -140,7 +154,6 @@ def main(args):
|
||||||
fout.add_bool("clip.has_text_encoder", False)
|
fout.add_bool("clip.has_text_encoder", False)
|
||||||
fout.add_bool("clip.has_vision_encoder", True)
|
fout.add_bool("clip.has_vision_encoder", True)
|
||||||
fout.add_bool("clip.has_qwen2vl_merger", True)
|
fout.add_bool("clip.has_qwen2vl_merger", True)
|
||||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
|
||||||
|
|
||||||
print(cfg.vision_config)
|
print(cfg.vision_config)
|
||||||
if 'silu' in cfg.vision_config.hidden_act.lower():
|
if 'silu' in cfg.vision_config.hidden_act.lower():
|
||||||
|
@ -153,13 +166,12 @@ def main(args):
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
if args.model_type == "qwen2.5vl":
|
if args.model_type == "qwen2.5vl":
|
||||||
fout.add_bool("clip.use_glu_mlp", True) # gate linear unit MLP layer in vision model
|
fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
|
||||||
fout.add_bool("clip.use_rms_norm", True)
|
|
||||||
fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes)
|
|
||||||
fout.add_uint32("clip.vision.window_size", vcfg.window_size)
|
|
||||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
|
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
|
||||||
fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
|
fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
|
||||||
|
fout.add_string("clip.projector_type", "qwen2.5vl_merger")
|
||||||
else:
|
else:
|
||||||
|
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,7 @@ add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # mode
|
||||||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
||||||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
||||||
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
|
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
|
||||||
|
add_test "llama-qwen2vl-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
|
||||||
|
|
||||||
# to test the big models, run: ./tests.sh big
|
# to test the big models, run: ./tests.sh big
|
||||||
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue