From b453333f606bc6860fe2beab5415c50b01d304af Mon Sep 17 00:00:00 2001
From: Atream <80757050+Atream@users.noreply.github.com>
Date: Wed, 19 Mar 2025 16:14:54 +0800
Subject: [PATCH] Update gate.py

---
 ktransformers/operators/gate.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py
index 271a144..dc93c96 100644
--- a/ktransformers/operators/gate.py
+++ b/ktransformers/operators/gate.py
@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
 
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
+                                    #float("-inf"))  # [n, e]
 
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                             dim=-1,
                                             sorted=False)
 
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+    topk_weights = topk_weights * routed_scaling_factor # must multiply the scaling factor
     return topk_ids.to(torch.long), topk_weights.to(torch.float32)
 
 class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
@@ -204,6 +206,7 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                            gguf_loader, config, self.gate_linear, #orig_module
@@ -219,14 +222,13 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
            logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )
-
-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
-                           self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
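
For reference, below is a minimal standalone sketch (plain PyTorch, not the ktransformers module itself) of the routing math this patch converges on: sigmoid gating scores, experts in dropped groups masked to 0.0 rather than -inf, renormalization guarded by an epsilon and skipped when topk == 1, and a final multiply by routed_scaling_factor. The toy sizes and the 2.5 scaling factor are illustrative assumptions, not values taken from the repository.

# sketch_grouped_topk.py -- illustrative only; mirrors the patched control flow.
import torch

def sketch_grouped_topk(scores, topk, num_expert_group, topk_group,
                        renormalize=True, routed_scaling_factor=2.5):
    num_token, num_expert = scores.shape
    # Keep only the top-scoring expert groups per token.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group, num_expert // num_expert_group
    ).reshape(num_token, -1)
    # Masked experts get 0.0 (not -inf), as in the patch.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)

    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    if topk > 1 and renormalize:
        # Epsilon guards against an all-zero row after masking.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    topk_weights = topk_weights * routed_scaling_factor
    return topk_ids.to(torch.long), topk_weights.to(torch.float32)

scores = torch.sigmoid(torch.randn(4, 16))   # 4 tokens, 16 experts in 4 groups
ids, w = sketch_grouped_topk(scores, topk=2, num_expert_group=4, topk_group=2)
print(ids.shape, w.sum(dim=-1))              # each row sums to routed_scaling_factor

With renormalize=True and topk > 1, the per-token weights sum to routed_scaling_factor, which is why the patch applies the multiply unconditionally after the (now guarded) renormalization.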