support smt and qlm4

This commit is contained in:
djw 2025-07-25 12:48:51 +00:00
parent 712ad1fa3c
commit 48bc6185b5
9 changed files with 65 additions and 74 deletions

View file

@ -1840,31 +1840,13 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if bsz_tensor is None:
router_logits = self.gate(hidden_states)
else:
router_logits = self.gate(hidden_states, bsz_tensor)
if router_logits.device.type == "xpu":
# TODO: support self.moe_primary_router_apply_softmax False case
from ipex_llm.transformers.models.common import moe_softmax_topk
selected_experts, routing_weights = moe_softmax_topk(
router_logits.half(), self.top_k, self.norm_topk_prob
)
else:
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
@ -1873,29 +1855,29 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
y.resize_(*orig_shape)
return y
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
# y_ = (
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
# )
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, selected_experts, routing_weights)
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
# y += y_
y += y_
return y
@torch.no_grad()