diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index d3aa215..02848ac 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -125,7 +125,7 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): # adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071 # This is used by the Deepseek-V2 and Deepseek-V3 model -#@torch.compile(dynamic=True) +@torch.compile(dynamic=True) def grouped_topk(hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, @@ -225,9 +225,8 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase): hidden_states.type(torch.float32), self.weight.type(torch.float32), None ) - return grouped_topk(hidden_states, logits, - self.top_k, self.norm_topk_prob, - self.n_group, self.topk_group) + return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, + self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias) def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device