support qwen3

This commit is contained in:
djw 2025-04-28 14:05:24 +00:00
parent 3f9bbf1181
commit 0da3792b27
5 changed files with 9 additions and 3 deletions

View file

@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
):
q_len, _ = hidden_states.size()
-query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
-key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
+bsz_tensors_q = bsz_tensors * self.num_heads
+bsz_tensors_kv = bsz_tensors * self.num_key_value_heads
+query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)
+key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)
value_states = self.v_proj(hidden_states, bsz_tensors)