mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-12 08:09:42 +00:00
support glm4moe
This commit is contained in:
parent
1677e90092
commit
d03d92ba53
31 changed files with 2265 additions and 74 deletions
|
@ -9,6 +9,8 @@ from torch import nn
|
|||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerAttention
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention
|
||||
from typing import Optional, Tuple
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
|
@ -454,4 +456,191 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
|
|||
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
|
||||
final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
|
||||
return final_attention_output
|
||||
return final_attention_output
|
||||
|
||||
class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
|
||||
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KGQACache,
|
||||
freqs_cis: torch.Tensor,
|
||||
wrapper: flashInferAttn,
|
||||
bsz_tensors: torch.Tensor,
|
||||
position_ids: torch.Tensor = None,
|
||||
):
|
||||
|
||||
if self.use_qk_norm:
|
||||
raise NotImplementedError("use_qk_norm is not implemented yet")
|
||||
|
||||
q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states, bsz_tensors)
|
||||
key_states = self.k_proj(hidden_states, bsz_tensors)
|
||||
value_states = self.v_proj(hidden_states, bsz_tensors)
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# cos, sin = freqs_cis
|
||||
"""
|
||||
print(query_states.shape)
|
||||
print(key_states.shape)
|
||||
print(cos.shape)
|
||||
print(sin.shape)
|
||||
"""
|
||||
if freqs_cis:
|
||||
cos, sin = freqs_cis
|
||||
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)
|
||||
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
k_cache = kv_cache.get_k_cache(self.layer_idx)
|
||||
v_cache = kv_cache.get_v_cache(self.layer_idx)
|
||||
|
||||
|
||||
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
|
||||
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
|
||||
|
||||
class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
||||
unsqueeze_dim=2
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Keep half or full tensor for later concatenation
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
rotary_dim = cos.shape[-1]
|
||||
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
|
||||
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
||||
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
||||
|
||||
# Apply rotary embeddings on the first half or full tensor
|
||||
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
||||
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
||||
|
||||
# Concatenate back to full shape
|
||||
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
||||
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
||||
return q_embed, k_embed
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KGQACache,
|
||||
freqs_cis: torch.Tensor,
|
||||
wrapper: flashInferAttn,
|
||||
bsz_tensors: torch.Tensor,
|
||||
position_ids: torch.Tensor = None,
|
||||
):
|
||||
|
||||
q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states, bsz_tensors)
|
||||
key_states = self.k_proj(hidden_states, bsz_tensors)
|
||||
value_states = self.v_proj(hidden_states, bsz_tensors)
|
||||
|
||||
|
||||
if self.use_qk_norm:
|
||||
query_states = self.q_norm(query_states, bsz_tensors)
|
||||
key_states = self.k_norm(key_states, bsz_tensors)
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
|
||||
# cos, sin = freqs_cis
|
||||
"""
|
||||
print(query_states.shape)
|
||||
print(key_states.shape)
|
||||
print(cos.shape)
|
||||
print(sin.shape)
|
||||
"""
|
||||
if freqs_cis is not None:
|
||||
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
|
||||
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
|
||||
k_cache = kv_cache.get_k_cache(self.layer_idx)
|
||||
v_cache = kv_cache.get_v_cache(self.layer_idx)
|
||||
|
||||
|
||||
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
|
||||
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
|
||||
return attn_output
|
Loading…
Add table
Add a link
Reference in a new issue