add XPU support for qwen3moe local chat

This commit is contained in:
rnwang04 2025-05-21 18:33:41 +08:00
parent 25893366b6
commit adc0906967
9 changed files with 223 additions and 25 deletions

View file

@ -1421,19 +1421,28 @@ class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
return final_out
class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
router_logits = self.gate(hidden_states, bsz_tensor)
if bsz_tensor is None:
router_logits = self.gate(hidden_states)
else:
router_logits = self.gate(hidden_states, bsz_tensor)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
if router_logits.device.type == "xpu":
from ipex_llm.transformers.models.common import moe_softmax_topk
selected_experts, routing_weights = moe_softmax_topk(
router_logits.half(), self.top_k, self.norm_topk_prob
)
else:
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)