model : refactor QKV into common build_qkv and create_tensor_qkv helpers (#21245)

* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
This commit is contained in:
PikaPikachu 2026-04-16 23:41:34 +08:00 committed by GitHub
parent f772f6e434
commit 9db77a020c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
88 changed files with 351 additions and 1764 deletions

View file

@ -1,6 +1,7 @@
#include "llama-graph.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-batch.h"
#include "llama-cparams.h"
@ -1059,6 +1060,84 @@ ggml_tensor * llm_graph_context::build_norm(
return cur;
}
llm_graph_qkv llm_graph_context::build_qkv(
const llama_layer & layer,
ggml_tensor * cur,
int64_t n_embd_head,
int64_t n_head,
int64_t n_head_kv,
int il) const {
const int64_t n_embd_q = n_embd_head * n_head;
const int64_t n_embd_kv = n_embd_head * n_head_kv;
ggml_tensor * Qcur, * Kcur, * Vcur;
if (layer.wqkv) {
// fused QKV path
ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s);
cb(qkv, "wqkv", il);
if (layer.bqkv) {
qkv = ggml_add(ctx0, qkv, layer.bqkv);
cb(qkv, "bqkv", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(qkv, "wqkv_clamped", il);
}
Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens,
ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], 0);
Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens,
ggml_row_size(qkv->type, n_embd_head), qkv->nb[1],
ggml_row_size(qkv->type, n_embd_q));
Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens,
ggml_row_size(qkv->type, n_embd_head), qkv->nb[1],
ggml_row_size(qkv->type, n_embd_q + n_embd_kv));
} else {
// separate Q/K/V path
Qcur = build_lora_mm(layer.wq, cur, layer.wq_s);
cb(Qcur, "Qcur", il);
if (layer.bq) {
Qcur = ggml_add(ctx0, Qcur, layer.bq);
cb(Qcur, "Qcur", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(Qcur, "Qcur_clamped", il);
}
Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
cb(Kcur, "Kcur", il);
if (layer.bk) {
Kcur = ggml_add(ctx0, Kcur, layer.bk);
cb(Kcur, "Kcur", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(Kcur, "Kcur_clamped", il);
}
Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
cb(Vcur, "Vcur", il);
if (layer.bv) {
Vcur = ggml_add(ctx0, Vcur, layer.bv);
cb(Vcur, "Vcur", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(Vcur, "Vcur_clamped", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
return { Qcur, Kcur, Vcur };
}
ggml_tensor * llm_graph_context::build_ffn(
ggml_tensor * cur,
ggml_tensor * up,