support glm4moe

This commit is contained in:
djw 2025-07-25 17:22:20 +00:00
parent 1677e90092
commit d03d92ba53
31 changed files with 2265 additions and 74 deletions

View file

@ -9,6 +9,8 @@ from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from ktransformers.models.modeling_smallthinker import SmallthinkerAttention
from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
@ -454,4 +456,191 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
return final_attention_output
return final_attention_output
class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
freqs_cis: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
position_ids: torch.Tensor = None,
):
if self.use_qk_norm:
raise NotImplementedError("use_qk_norm is not implemented yet")
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
# cos, sin = freqs_cis
"""
print(query_states.shape)
print(key_states.shape)
print(cos.shape)
print(sin.shape)
"""
if freqs_cis:
cos, sin = freqs_cis
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
return attn_output
class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def apply_rotary_pos_emb(
self,
q: torch.Tensor,
k: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
unsqueeze_dim=2
) -> Tuple[torch.Tensor, torch.Tensor]:
# Keep half or full tensor for later concatenation
cos = freqs_cis[0]
sin = freqs_cis[1]
rotary_dim = cos.shape[-1]
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
# Apply rotary embeddings on the first half or full tensor
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
freqs_cis: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
position_ids: torch.Tensor = None,
):
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
if self.use_qk_norm:
query_states = self.q_norm(query_states, bsz_tensors)
key_states = self.k_norm(key_states, bsz_tensors)
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
# cos, sin = freqs_cis
"""
print(query_states.shape)
print(key_states.shape)
print(cos.shape)
print(sin.shape)
"""
if freqs_cis is not None:
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)
return attn_output