Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-11 07:44:35 +00:00)
support smt and qlm4

commit 48bc6185b5 (parent 712ad1fa3c)
9 changed files with 65 additions and 74 deletions
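The diff below reworks KGlm4MoeMoE.forward: the Qwen-style routing block (explicit softmax plus torch.topk over router logits, with an ipex-llm fast path on XPU) is deleted in favor of a gate that returns (topk_idx, topk_weight) directly, and the shared-expert call is renamed from self.shared_expert to self.shared_experts. As orientation, here is a minimal sketch of a gate with that return contract; TopkGateSketch and its fields are hypothetical stand-ins for whatever self.gate actually is in the repo, whose scoring may well differ:

    import torch
    import torch.nn as nn

    class TopkGateSketch(nn.Module):
        # Hypothetical stand-in for self.gate after this commit: instead of
        # raw router logits, forward() returns (topk_idx, topk_weight).
        def __init__(self, hidden_size, num_experts, top_k, norm_topk_prob=True):
            super().__init__()
            self.proj = nn.Linear(hidden_size, num_experts, bias=False)
            self.top_k = top_k
            self.norm_topk_prob = norm_topk_prob

        def forward(self, hidden_states):
            # flatten [bsz, seq, hidden] -> [tokens, hidden] before scoring
            logits = self.proj(hidden_states.view(-1, hidden_states.shape[-1]))
            scores = logits.softmax(dim=-1, dtype=torch.float)
            topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1)
            if self.norm_topk_prob:
                topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
            return topk_idx, topk_weight.to(hidden_states.dtype)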
@@ -1840,31 +1840,13 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
         orig_shape = hidden_states.shape
-        sequence_length = orig_shape[1]
-
+        topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-
-        if bsz_tensor is None:
-            router_logits = self.gate(hidden_states)
-        else:
-            router_logits = self.gate(hidden_states, bsz_tensor)
-
-        if router_logits.device.type == "xpu":
-            # TODO: support self.moe_primary_router_apply_softmax False case
-            from ipex_llm.transformers.models.common import moe_softmax_topk
-            selected_experts, routing_weights = moe_softmax_topk(
-                router_logits.half(), self.top_k, self.norm_topk_prob
-            )
-        else:
-            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-            if self.norm_topk_prob:
-                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-            # we cast back to the input dtype
-            routing_weights = routing_weights.to(hidden_states.dtype)
 
         # only for generate phase
         if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():  # TODO: this branch cause jit bug
-            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
-            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
+            y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
             # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
 
             y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
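For reference, the non-XPU routing branch deleted above distills to this small helper; it restates the removed lines rather than introducing new behavior:

    import torch

    def softmax_topk_route(router_logits, top_k, norm_topk_prob, out_dtype):
        # softmax over the expert dimension, in float32 for stability
        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
        # keep only the top_k experts per token
        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
        if norm_topk_prob:
            # renormalize the kept weights so they sum to 1 per token
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # cast back to the input dtype, as the deleted code did
        return selected_experts, routing_weights.to(out_dtype)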
@@ -1873,29 +1855,29 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
             y.resize_(*orig_shape)
             return y
 
         # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
         y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
         # y_ = (
         #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
         # )
 
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
-                self.moe_infer(hidden_states, selected_experts, routing_weights)
+                self.moe_infer(hidden_states, topk_idx, topk_weight)
                 .view(*orig_shape)
                 .to(device=hidden_states.device)
             )
         else:
             # TODO may bugs here
             y = (
-                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
+                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
                 .view(*orig_shape)
                 .to(device=hidden_states.device)
             )
-        # y += y_
+        y += y_
         return y
 
     @torch.no_grad()
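All three call sites above end by adding the shared-expert output to the routed expert output (the `y += y_` step). A naive dense sketch of that combine, under assumed shapes (hidden_states is [tokens, hidden]; topk_idx and topk_weight are [tokens, top_k]) and with a hypothetical experts list standing in for the repo's batched moe_infer / moe_on_cpuinfer kernels:

    import torch
    import torch.nn as nn

    def combine_sketch(hidden_states, topk_idx, topk_weight, experts, shared_out):
        # naive per-token loop; the real kernels batch and fuse this work
        y = torch.zeros_like(hidden_states)
        for t in range(hidden_states.shape[0]):
            for k in range(topk_idx.shape[-1]):
                e = int(topk_idx[t, k])
                y[t] += topk_weight[t, k] * experts[e](hidden_states[t])
        return y + shared_out  # the `y += y_` step in the diff

    # toy usage: 4 experts, top-2 routing over 3 tokens of width 8
    experts = [nn.Linear(8, 8) for _ in range(4)]
    hs = torch.randn(3, 8)
    idx = torch.randint(0, 4, (3, 2))
    w = torch.softmax(torch.randn(3, 2), dim=-1)
    out = combine_sketch(hs, idx, w, experts, shared_out=torch.zeros_like(hs))

The loop form is only for readability; it makes explicit that each token's output is the weight-scaled sum of its top_k experts plus the shared-expert term.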