kv-cache : support attention rotation for heterogeneous iSWA (#21513)

* kv-cache : support attention rotation for heterogeneous iSWA

* cont : remove assert
Georgi Gerganov, 2026-04-07 20:31:28 +03:00, committed by GitHub
parent 957d717ce5
commit 4eb19514dd
4 changed files with 58 additions and 17 deletions
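
For orientation before the hunks: the change adds a second pair of rotation-tensor inputs (self_k_rot_swa / self_v_rot_swa) that are built from and bound to the SWA KV cache, and build_attn() now selects per layer which pair to apply: layers that hparams.is_swa(il) marks as sliding-window use the SWA tensors, every other layer keeps the base-cache tensors. The sketch below is a minimal, standalone C++ illustration of that selection pattern only; the types and helpers in it (Rotation, RotInputs, pick_k_rot, pick_v_rot, the swa_layers layout) are made-up stand-ins, not llama.cpp API.

// Minimal sketch of the per-layer rotation-tensor selection this commit adds.
// All names below (Rotation, RotInputs, pick_k_rot, pick_v_rot, swa_layers)
// are illustrative stand-ins, not actual llama.cpp types or functions.
#include <cstdio>

struct Rotation { const char * name; };       // placeholder for a ggml_tensor *

struct RotInputs {
    const Rotation * k_rot     = nullptr;     // base-cache rotations
    const Rotation * v_rot     = nullptr;
    const Rotation * k_rot_swa = nullptr;     // SWA-cache rotations (the new inputs)
    const Rotation * v_rot_swa = nullptr;
};

// mirrors: auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
static const Rotation * pick_k_rot(const RotInputs & inp, bool is_swa) {
    return is_swa ? inp.k_rot_swa : inp.k_rot;
}

// mirrors: auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
static const Rotation * pick_v_rot(const RotInputs & inp, bool is_swa) {
    return is_swa ? inp.v_rot_swa : inp.v_rot;
}

int main() {
    const Rotation base_k = { "k_rot(base)" };
    const Rotation base_v = { "v_rot(base)" };
    const Rotation swa_k  = { "k_rot(swa)"  };
    const Rotation swa_v  = { "v_rot(swa)"  };

    RotInputs inp;
    inp.k_rot     = &base_k;
    inp.v_rot     = &base_v;
    inp.k_rot_swa = &swa_k;
    inp.v_rot_swa = &swa_v;

    // hypothetical heterogeneous layout: odd layers use sliding-window attention
    const bool swa_layers[4] = { false, true, false, true };

    for (int il = 0; il < 4; ++il) {
        const bool is_swa = swa_layers[il];
        std::printf("layer %d -> %s / %s\n", il,
                    pick_k_rot(inp, is_swa)->name,
                    pick_v_rot(inp, is_swa)->name);
    }
    return 0;
}

With that toy layout, layers 1 and 3 resolve to the SWA pair and layers 0 and 2 to the base pair, which is the same branch build_attn() takes via hparams.is_swa(il); the attention output is then multiplied by the same selected v_rot after build_attn_mha.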


@@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_v_rot) {
         mctx->get_base()->set_input_v_rot(self_v_rot);
     }
+
+    if (self_k_rot_swa) {
+        mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
+    }
+
+    if (self_v_rot_swa) {
+        mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
+    }
 }
 
 bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {

@@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
         attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
     }
+
+    if (inp_attn->self_k_rot_swa) {
+        attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
+    }
+
+    if (inp_attn->self_v_rot_swa) {
+        attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
+    }
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (inp_rs->s_copy) {

@@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
                 float kq_scale,
                   int il) const {
-    if (inp->self_k_rot) {
-        q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+    const bool is_swa = hparams.is_swa(il);
+
+    auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
+    auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
+
+    if (k_rot) {
+        q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
         if (k_cur) {
-            k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+            k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
         }
     }
 
-    if (inp->self_v_rot) {
+    if (v_rot) {
         if (v_cur) {
-            v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+            v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
         }
     }

@@ -2259,8 +2280,6 @@
     const auto * mctx_iswa = inp->mctx;
 
-    const bool is_swa = hparams.is_swa(il);
-
     const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
     // optionally store to KV cache

@@ -2285,8 +2304,8 @@
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
-    if (inp->self_v_rot) {
-        cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+    if (v_rot) {
+        cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
     }
 
     if (wo) {

@@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
     inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
 
+    inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
+    inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
+
     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }