mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-13 00:29:59 +00:00

support smt and glm4

This commit is contained in:
parent 1677e90092
commit b66d96db97

18 changed files with 3519 additions and 16 deletions
@@ -26,6 +26,8 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding
from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding
import torch

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe

@@ -437,4 +439,93 @@ class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
    def load(self):
        self.orig_module.__init__(
            self.orig_module.config
        )
        )


class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            config
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(
            self.orig_module.config,
            device = self.generate_device,
        )

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)
        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        # print(inv_freq_expanded.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
            freqs_cis = freqs_cis * self.attention_scaling
        return freqs_cis


class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            config
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(
            self.orig_module.config,
            device = self.generate_device,
        )

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)
        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        # print(inv_freq_expanded.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
            freqs_cis = freqs_cis * self.attention_scaling
        return freqs_cis
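The two rotary classes above return the embedding as a single complex-valued tensor: torch.polar builds unit-modulus numbers from the angle table instead of the usual (cos, sin) pair. A standalone toy sketch of how such a freqs_cis tensor is consumed (sizes are invented, attention_scaling is assumed to be 1.0, and the channel pairing mirrors the apply_rotary_pos_emb helper defined in the attention classes below):

import torch

batch, seq, heads, head_dim = 1, 5, 2, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(seq)[None, :].float()

# same matmul/transpose as the forward above: one rotation angle per (position, frequency)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)  # [batch, seq, head_dim/2]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)                        # unit-modulus complex numbers

# adjacent channel pairs are viewed as complex numbers and rotated by multiplication
xq = torch.randn(batch, seq, heads, head_dim)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xq_rot = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)

# rotating by unit-modulus values leaves per-head norms unchanged
assert torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5)

Because every entry of freqs_cis has magnitude one, the multiplication is a pure rotation of each channel pair, which is what the final assertion checks.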
@@ -9,6 +9,8 @@ from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from ktransformers.models.modeling_smallthinker import SmallthinkerAttention
from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
@@ -454,4 +456,231 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
        final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
        return final_attention_output
        return final_attention_output


class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

    def apply_rotary_pos_emb(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                freqs_cis: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                position_ids: torch.Tensor = None,
                ):

        if self.use_qk_norm:
            raise NotImplementedError("use_qk_norm is not implemented yet")

        q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states, bsz_tensors)
        key_states = self.k_proj(hidden_states, bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        # cos, sin = freqs_cis
        """
        print(query_states.shape)
        print(key_states.shape)
        print(cos.shape)
        print(sin.shape)
        """
        if freqs_cis:
            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)

        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)

        return attn_output


class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

    def apply_rotary_pos_emb(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                freqs_cis: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                position_ids: torch.Tensor = None,
                ):

        if self.use_qk_norm:
            raise NotImplementedError("use_qk_norm is not implemented yet")

        q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states, bsz_tensors)
        key_states = self.k_proj(hidden_states, bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        # cos, sin = freqs_cis
        """
        print(query_states.shape)
        print(key_states.shape)
        print(cos.shape)
        print(sin.shape)
        """
        if freqs_cis:
            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)

        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)

        return attn_output


class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.

    def apply_rotary_pos_emb(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                freqs_cis: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                position_ids: torch.Tensor = None,
                ):

        if self.use_qk_norm:
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states, bsz_tensors)
        key_states = self.k_proj(hidden_states, bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        # cos, sin = freqs_cis
        """
        print(query_states.shape)
        print(key_states.shape)
        print(cos.shape)
        print(sin.shape)
        """
        if freqs_cis:
            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)

        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)

        return attn_output
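One thing to note in KGlm4MoeAttention.forward above: the use_qk_norm branch reads query_states and key_states before they are assigned, and the guard "if freqs_cis:" tests the truthiness of a tensor, which raises for anything with more than one element. A minimal standalone sketch of the ordering that branch presumably intends (hypothetical sizes; nn.RMSNorm stands in for the module's q_norm/k_norm, and an explicit None check replaces the tensor-truthiness guard):

import torch
import torch.nn as nn

q_len, num_heads, num_kv_heads, head_dim, hidden = 4, 8, 2, 16, 128
q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)
q_norm = nn.RMSNorm(head_dim)   # stand-in for self.q_norm
k_norm = nn.RMSNorm(head_dim)   # stand-in for self.k_norm

hidden_states = torch.randn(q_len, hidden)
query_states = q_proj(hidden_states).view(q_len, num_heads, head_dim)
key_states = k_proj(hidden_states).view(q_len, num_kv_heads, head_dim)

use_qk_norm, freqs_cis = True, None
if use_qk_norm:                 # the norm now sees defined, per-head tensors
    query_states = q_norm(query_states)
    key_states = k_norm(key_states)
if freqs_cis is not None:       # explicit None check instead of tensor truthiness
    pass                        # apply_rotary_pos_emb(...) would run here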
@@ -729,6 +729,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock
from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE


class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):

@@ -1248,6 +1250,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)

        if prefill_op == 'None':
            prefill_op = None
        if generate_op == 'None':
            generate_op = None

        if generate_op is not None:
            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
        else:

@@ -1464,6 +1472,264 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO may bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = torch.empty_like(x)
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out


class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):
    def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):

        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if bsz_tensor is None:
            if self.enable_early_router:
                router_logits = self.primary_router(router_input)
            else:
                router_logits = self.primary_router(hidden_states)
        else:
            if self.enable_early_router:
                router_logits = self.primary_router(router_input, bsz_tensor)
            else:
                router_logits = self.primary_router(hidden_states, bsz_tensor)

        router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)

        if router_logits.device.type == "xpu":
            # TODO: support self.moe_primary_router_apply_softmax False case
            from ipex_llm.transformers.models.common import moe_softmax_topk
            selected_experts, routing_weights = moe_softmax_topk(
                router_logits.half(), self.top_k, self.norm_topk_prob
            )
        else:
            if self.moe_primary_router_apply_softmax:
                routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
            else:
                routing_weights = F.sigmoid(router_logits)
                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
            # we cast back to the input dtype
            routing_weights = routing_weights.to(hidden_states.dtype)

        # only for generate phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            # y += y_
            y.resize_(*orig_shape)
            return y

        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        # y_ = (
        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO may bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = torch.empty_like(x)
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out


class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):

        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if bsz_tensor is None:
            router_logits = self.gate(hidden_states)
        else:
            router_logits = self.gate(hidden_states, bsz_tensor)

        if router_logits.device.type == "xpu":
            # TODO: support self.moe_primary_router_apply_softmax False case
            from ipex_llm.transformers.models.common import moe_softmax_topk
            selected_experts, routing_weights = moe_softmax_topk(
                router_logits.half(), self.top_k, self.norm_topk_prob
            )
        else:
            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
            if self.norm_topk_prob:
                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
            # we cast back to the input dtype
            routing_weights = routing_weights.to(hidden_states.dtype)

        # only for generate phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            y += y_
            y.resize_(*orig_shape)
            return y

        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        # y_ = (
        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
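The moe_infer methods above all rely on the same dispatch trick: flatten the (token, expert) assignments, sort them by expert id, and run each expert once on a contiguous batch. A standalone sketch with invented toy sizes (the expert call itself is left as a comment):

import torch

num_tokens, num_experts, top_k, hidden = 6, 4, 2, 8
x = torch.randn(num_tokens, hidden)
topk_ids = torch.rand(num_tokens, num_experts).topk(top_k, dim=-1).indices   # stand-in router output

cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)                 # one-hot of the selected experts per token
tokens_per_expert = cnts.sum(dim=0)           # how many (token, expert) pairs each expert receives
idxs = topk_ids.view(-1).argsort()            # flat assignment positions, grouped by expert id
sorted_tokens = x[idxs // top_k]              # the token row behind each sorted assignment

start = 0
for expert_id, n in enumerate(tokens_per_expert.tolist()):
    batch = sorted_tokens[start:start + n]    # contiguous batch routed to this expert
    # expert_out = expert(batch) would run here
    start += n

The inverse scatter (new_x[idxs] = outs) then restores token order before the top-k weights are applied, exactly as in the code above.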
@@ -212,4 +212,5 @@ class KMoEGateIPEXLLM(KMoEGate):
        topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,
                                               self.n_group, self.topk_group, self.top_k,
                                               self.norm_topk_prob, self.routed_scaling_factor)
        return topk_idx, topk_weight.to(x.dtype)
        return topk_idx, topk_weight.to(x.dtype)
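moe_group_topk is imported from ipex_llm and is not shown in this diff; judging from its arguments it performs the DeepseekV3-style group-limited selection that KMoEGate otherwise computes in PyTorch. A plain-PyTorch sketch of that selection under that assumption (toy sizes; the correction bias and routed_scaling_factor are omitted):

import torch

num_tokens, n_experts, n_group, topk_group, top_k = 3, 16, 4, 2, 4
scores = torch.rand(num_tokens, n_experts)

group_scores = scores.view(num_tokens, n_group, -1).amax(dim=-1)      # best expert per group
group_idx = group_scores.topk(topk_group, dim=-1).indices             # groups that stay eligible
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
score_mask = group_mask.unsqueeze(-1).expand(-1, -1, n_experts // n_group).reshape(num_tokens, n_experts)

masked_scores = scores.masked_fill(score_mask == 0, float("-inf"))
topk_weight, topk_idx = masked_scores.topk(top_k, dim=-1)             # final expert choice
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)     # the norm_topk_prob=True case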
@@ -28,6 +28,8 @@ import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm
from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
if not torch.xpu.is_available():

@@ -164,6 +166,94 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        #return self.forward_native(x, residual)
        bsz, hidden_size = x.shape
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            #residual = x + residual
            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        #return self.forward_native(x, residual)
        bsz, hidden_size = x.shape
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            #residual = x + residual
            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
    def __init__(self,
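The fast path in both RMSNorm classes above calls the rmsnorm / fused_add_rmsnorm kernels, while forward_native keeps a pure PyTorch reference. A standalone sketch of what the residual path presumably computes, using only the reference math (assumption: the fused kernel adds the residual in place and then normalizes the sum, as the commented-out lines suggest):

import torch

def rmsnorm_ref(x, weight, eps):
    # same math as forward_native above, carried out in float32
    dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(dtype)

hidden = torch.randn(4, 64, dtype=torch.float16)
residual = torch.randn(4, 64, dtype=torch.float16)
weight = torch.ones(64, dtype=torch.float16)

summed = hidden + residual                   # what fused_add_rmsnorm folds into one kernel launch
out = rmsnorm_ref(summed, weight, eps=1e-6)  # normalized input for the next sublayer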
@@ -5,6 +5,8 @@ from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP
from ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock
from ktransformers.models.modeling_glm4_moe import Glm4MoeMLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
    def __init__(self,
                 key: str,

@@ -32,6 +34,37 @@ class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.intermediate_size)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj


class KSmallthinkerDenseMlpBlock(SmallthinkerDenseMlpBlock, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config)
    def forward(self, x, bsz_tensor):
        down_proj = self.down(nn.functional.relu(self.gate(x, bsz_tensor)) * self.up(x, bsz_tensor), bsz_tensor)
        return down_proj

class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj
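All three MLP wrappers above compute the same gated pattern, down(act(gate(x)) * up(x)); SmallthinkerDenseMlpBlock gates with ReLU, while the Qwen2/GLM-4 variants use the configured act_fn (typically SiLU). A self-contained sketch of the pattern with invented sizes and without the bsz_tensor plumbing:

import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    def __init__(self, hidden_size=64, intermediate_size=256, act=nn.SiLU()):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = act

    def forward(self, x):
        # the gate branch modulates the up branch, then project back to hidden_size
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

y = GatedMLP()(torch.randn(2, 64))   # -> shape [2, 64]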