support smt and qlm4

This commit is contained in:
djw 2025-07-25 12:48:51 +00:00
parent 712ad1fa3c
commit 48bc6185b5
9 changed files with 65 additions and 74 deletions

View file

@ -568,15 +568,31 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
def apply_rotary_pos_emb(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
unsqueeze_dim=2
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# Keep half or full tensor for later concatenation
cos = freqs_cis[0]
sin = freqs_cis[1]
rotary_dim = cos.shape[-1]
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
def forward(self,
hidden_states: torch.Tensor,
@ -587,18 +603,20 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
position_ids: torch.Tensor = None,
):
if self.use_qk_norm:
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
if self.use_qk_norm:
query_states = self.q_norm(query_states, bsz_tensors)
key_states = self.k_norm(key_states, bsz_tensors)
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
# cos, sin = freqs_cis
"""
@ -607,14 +625,14 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
print(cos.shape)
print(sin.shape)
"""
if freqs_cis:
if freqs_cis is not None:
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
@ -623,6 +641,6 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)
return attn_output