mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-05 23:41:45 +00:00
kv-cache : support attention rotation for heterogeneous iSWA (#21513)
* kv-cache : support attention rotation for heterogeneous iSWA * cont : remove assert
This commit is contained in:
parent
957d717ce5
commit
4eb19514dd
4 changed files with 58 additions and 17 deletions
|
|
@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
if (self_v_rot) {
|
||||
mctx->get_base()->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
|
||||
if (self_k_rot_swa) {
|
||||
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
|
||||
}
|
||||
|
||||
if (self_v_rot_swa) {
|
||||
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
|
|
@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
|
||||
}
|
||||
|
||||
if (inp_attn->self_k_rot_swa) {
|
||||
attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
|
||||
}
|
||||
|
||||
if (inp_attn->self_v_rot_swa) {
|
||||
attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
|
|
@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
if (inp->self_k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
|
||||
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
|
||||
|
||||
if (k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
|
||||
if (k_cur) {
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
|
||||
}
|
||||
}
|
||||
if (inp->self_v_rot) {
|
||||
if (v_rot) {
|
||||
if (v_cur) {
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
|
||||
const auto * mctx_iswa = inp->mctx;
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
||||
|
||||
// optionally store to KV cache
|
||||
|
|
@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
|
||||
if (v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
|
|
@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
|
||||
|
||||
inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue