mtmd: n_head_kv defaults to n_head (#23782)

removed AI-generated comment
This commit is contained in:
Saba Fallah 2026-05-28 16:44:36 +02:00 committed by GitHub
parent d6be3158e1
commit 0b56d283bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions

View file

@ -29,6 +29,7 @@ struct clip_graph {
const int n_patches;
const int n_embd;
const int n_head;
const int n_head_kv;
const int d_head;
const int n_layer;
const int n_mmproj_embd;

View file

@ -246,6 +246,7 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
n_patches(n_patches_x * n_patches_y),
n_embd(hparams.n_embd),
n_head(hparams.n_head),
n_head_kv(hparams.n_head_kv),
d_head(n_embd / n_head),
n_layer(hparams.n_layer),
n_mmproj_embd(clip_n_mmproj_embd(ctx)),
@ -401,9 +402,9 @@ ggml_tensor * clip_graph::build_vit(
}
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos);
if (norm_per_head) {
if (layer.q_norm) {
@ -1120,6 +1121,9 @@ struct clip_model_loader {
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
// n_head_kv is optional (for GQA), default to n_head
hparams.n_head_kv = hparams.n_head;
if (is_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);