llama : rotate activations for better quantization (#21038)

* llama : rotate activations for better quantization

* cont : rotate V more + refactor

* cont : rotate caches separately + support non-power-of-2 head sizes

* cont : simplify

* cont : add reference for V rotation

* cont : refactor

* cont : support context shift

* cont : consolidate

* cont : dedup + allow different types for the rotation matrix

* cont : add env variable to disable rotation

* cont : simplify attn rot kv cache logic + rename env

* cont : pre-compute the Hadamard matrices
This commit is contained in:
Georgi Gerganov 2026-04-01 16:58:01 +03:00 committed by GitHub
parent 0356e33aaf
commit 744c0c7310
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 337 additions and 26 deletions

View file

@ -19,7 +19,7 @@
// dedup helpers
static ggml_tensor * build_kq_mask(
static ggml_tensor * build_attn_inp_kq_mask(
ggml_context * ctx,
const llama_kv_cache_context * mctx,
const llama_ubatch & ubatch,
@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
return res;
}
static bool can_reuse_kq_mask(
@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
// impl
static ggml_tensor * ggml_mul_mat_aux(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];
ggml_tensor * res;
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_mul_mat (ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
return res;
}
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
mctx->set_input_v_rot(self_v_rot);
}
}
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
if (inp_attn->self_k_rot) {
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
}
if (inp_attn->self_v_rot) {
mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
if (inp_attn->self_k_rot) {
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
}
if (inp_attn->self_v_rot) {
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
return inp;
}
@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
int il) const {
GGML_ASSERT(v_mla == nullptr);
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
if (inp->self_v_rot) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
if (k_cur) {
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
}
if (inp->self_v_rot) {
if (v_cur) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}
if (wo) {
cur = build_lora_mm(wo, cur);
}
@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}
{
@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask_swa);
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}