support smt and glm4

This commit is contained in:
djw 2025-07-24 08:40:58 +00:00
parent 1677e90092
commit b66d96db97
18 changed files with 3519 additions and 16 deletions

View file

@ -9,6 +9,8 @@ from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from ktransformers.models.modeling_smallthinker import SmallthinkerAttention
from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
@ -454,4 +456,231 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
return final_attention_output
return final_attention_output
class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def apply_rotary_pos_emb(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
freqs_cis: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
position_ids: torch.Tensor = None,
):
if self.use_qk_norm:
raise NotImplementedError("use_qk_norm is not implemented yet")
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
# cos, sin = freqs_cis
"""
print(query_states.shape)
print(key_states.shape)
print(cos.shape)
print(sin.shape)
"""
if freqs_cis:
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
return attn_output
class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def apply_rotary_pos_emb(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
freqs_cis: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
position_ids: torch.Tensor = None,
):
if self.use_qk_norm:
raise NotImplementedError("use_qk_norm is not implemented yet")
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
# cos, sin = freqs_cis
"""
print(query_states.shape)
print(key_states.shape)
print(cos.shape)
print(sin.shape)
"""
if freqs_cis:
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
return attn_output
class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def apply_rotary_pos_emb(
self,
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
freqs_cis: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
position_ids: torch.Tensor = None,
):
if self.use_qk_norm:
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
# cos, sin = freqs_cis
"""
print(query_states.shape)
print(key_states.shape)
print(cos.shape)
print(sin.shape)
"""
if freqs_cis:
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
return attn_output