Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 06:14:58 +00:00)
support smt and glm4

Commit b66d96db97 (parent 1677e90092): adds support for SmallThinker ("smt") and GLM-4 MoE models.
18 changed files with 3519 additions and 16 deletions
@@ -729,6 +729,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock
from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE


class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
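Each wrapper below follows the same injection pattern as the existing Qwen and Mixtral blocks: it inherits from both BaseInjectedModule and the original modeling class, so the injected module keeps the original block's attributes and passes isinstance checks while overriding forward. A minimal standalone sketch of that pattern (toy classes, not repo API):

import torch

class OriginalBlock(torch.nn.Module):    # stands in for e.g. Glm4MoeMoE
    def forward(self, x):
        return x + 1

class InjectedBase:                      # stands in for BaseInjectedModule
    pass

class KInjectedBlock(InjectedBase, OriginalBlock):
    def forward(self, x):                # overrides the original forward
        return x * 2

blk = KInjectedBlock()
assert isinstance(blk, OriginalBlock)    # still usable where the HF block is expected
assert torch.equal(blk(torch.ones(2)), torch.full((2,), 2.0))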
@@ -1248,6 +1250,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)

        # operator names arrive from the rule files as strings; treat the
        # literal string 'None' as "no backend for this phase"
        if prefill_op == 'None':
            prefill_op = None
        if generate_op == 'None':
            generate_op = None

        if generate_op is not None:
            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
        else:
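The string comparison above guards a YAML quirk: an operator value written as None in a rule file parses as the string 'None', not as Python None; only null (or ~) is YAML's null. A standalone sketch, assuming PyYAML:

import yaml  # PyYAML

cfg = yaml.safe_load("generate_op: None\nprefill_op: null\n")
assert cfg["generate_op"] == "None"  # unquoted None stays a string
assert cfg["prefill_op"] is None     # null is the real YAML null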
@@ -1464,6 +1472,264 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
        # )

        # dispatch: CPUInfer experts if injected, batched path for larger
        # inputs, naive per-token path otherwise
        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO: there may be bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO: there may be bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        # hand the whole expert computation to the injected expert backend
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    # TODO: there may be bugs here
    @torch.no_grad()
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Naive fallback: loop over every (token, selected-expert) pair.

        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs
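For orientation, a standalone toy version of the same weighted mixture, with small linear layers standing in for the experts (names and sizes are illustrative only):

import torch

num_tokens, hidden, num_experts, top_k = 4, 8, 6, 2
experts = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)]
x = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(num_experts, (num_tokens, top_k))
topk_weight = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)

outs = torch.zeros_like(x)
for t in range(num_tokens):
    for k in range(top_k):
        outs[t] += experts[int(topk_ids[t, k])](x[t]) * topk_weight[t, k]
# each output row is a weighted combination of that token's top_k expert outputs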

    # TODO: there may be bugs here
    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        # count how many tokens route to each expert
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # sort the flattened (token, expert) pairs by expert id so each expert
        # sees one contiguous slice of tokens
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        # scatter results back to the original pair order, then apply the
        # routing weights and sum over each token's selected experts
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
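The batched path above hinges on one fact: after argsort of the flattened expert ids, every expert's (token, expert) pairs occupy one contiguous slice, so each expert runs once on a whole batch instead of once per token. A standalone toy illustration:

import torch

top_k = 2
topk_ids = torch.tensor([[0, 2], [2, 1], [0, 1]])  # 3 tokens, top-2 experts each
idxs = topk_ids.view(-1).argsort()                 # pair indices, grouped by expert id
print(topk_ids.view(-1)[idxs])                     # tensor([0, 0, 1, 1, 2, 2])
print(idxs // top_k)                               # which token to gather for each slot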

class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):
    def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):

        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # the router can score either the dedicated router input (early routing)
        # or the hidden states themselves
        if bsz_tensor is None:
            if self.enable_early_router:
                router_logits = self.primary_router(router_input)
            else:
                router_logits = self.primary_router(hidden_states)
        else:
            if self.enable_early_router:
                router_logits = self.primary_router(router_input, bsz_tensor)
            else:
                router_logits = self.primary_router(hidden_states, bsz_tensor)

        router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)

        if router_logits.device.type == "xpu":
            # TODO: support the self.moe_primary_router_apply_softmax == False case
            from ipex_llm.transformers.models.common import moe_softmax_topk
            selected_experts, routing_weights = moe_softmax_topk(
                router_logits.half(), self.top_k, self.norm_topk_prob
            )
        else:
            if self.moe_primary_router_apply_softmax:
                routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
            else:
                routing_weights = F.sigmoid(router_logits)
                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
            # cast back to the input dtype
            routing_weights = routing_weights.to(hidden_states.dtype)

        # only for the generate (decode) phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():  # TODO: this branch causes a JIT bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            # y += y_
            y.resize_(*orig_shape)
            return y

        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        # y_ = (
        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO: there may be bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO: there may be bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y
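The two non-xpu normalizations above differ in one step: softmax over the top-k logits already yields weights that sum to 1, while sigmoid scores need the explicit division. A standalone numeric sketch:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])   # top-k router logits for one token

soft = F.softmax(logits, dim=-1)           # sums to 1 by construction
sig = torch.sigmoid(logits)
sig = sig / sig.sum(dim=-1, keepdim=True)  # renormalize to a convex combination

print(soft.sum(dim=-1), sig.sum(dim=-1))   # both sum to 1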

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        # hand the whole expert computation to the injected expert backend
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    # TODO: there may be bugs here
    @torch.no_grad()
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Naive fallback: loop over every (token, selected-expert) pair.

        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    # TODO: there may be bugs here
    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        # count how many tokens route to each expert
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # sort the flattened (token, expert) pairs by expert id so each expert
        # sees one contiguous slice of tokens
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        # scatter results back to the original pair order, then apply the
        # routing weights and sum over each token's selected experts
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out

class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):

        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if bsz_tensor is None:
            router_logits = self.gate(hidden_states)
        else:
            router_logits = self.gate(hidden_states, bsz_tensor)

        if router_logits.device.type == "xpu":
            # TODO: support the self.moe_primary_router_apply_softmax == False case
            from ipex_llm.transformers.models.common import moe_softmax_topk
            selected_experts, routing_weights = moe_softmax_topk(
                router_logits.half(), self.top_k, self.norm_topk_prob
            )
        else:
            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
            if self.norm_topk_prob:
                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
            # cast back to the input dtype
            routing_weights = routing_weights.to(hidden_states.dtype)

        # only for the generate (decode) phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():  # TODO: this branch causes a JIT bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            # unlike the SmallThinker block above, the shared expert stays active here
            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            y += y_
            y.resize_(*orig_shape)
            return y

        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        # y_ = (
        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO: there may be bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO: there may be bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y
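One design difference from KSmallthinkerMoeBlock is visible above: the shared-expert path that is commented out for SmallThinker stays active for GLM-4, so the shared expert's output is always added to the routed mixture. A toy standalone sketch of that combination (illustrative names only):

import torch

hidden = 8
shared_expert = torch.nn.Linear(hidden, hidden, bias=False)  # stand-in for the GLM-4 shared expert
x = torch.randn(3, hidden)
routed_mix = torch.zeros_like(x)        # pretend output of the routed experts
y = routed_mix + shared_expert(x)       # shared expert contributes to every token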
[diff view truncated here; the rest of this commit's changes did not load]