Update gate.py

Atream 2025-03-19 16:14:54 +08:00 committed by GitHub
parent 6ca233cca3
commit b453333f60

@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
+                                    #float("-inf"))  # [n, e]
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
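Why filling with 0.0 is safe here: with scoring_func="sigmoid" every real score lies in (0, 1), so a masked score of 0.0 can never be selected by top-k, while float("-inf") can leak into later arithmetic and produce NaNs (e.g. 0 * -inf). A minimal standalone sketch, not part of the commit, illustrating the behavior:

import torch

scores = torch.sigmoid(torch.randn(2, 8))                # [n_token, n_expert], all in (0, 1)
score_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0],
                           [0, 0, 0, 0, 1, 1, 1, 1]])

tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # 0.0 < any sigmoid score
topk_ids = torch.topk(tmp_scores, k=2, dim=-1, sorted=False)[1]
assert score_mask.bool().gather(-1, topk_ids).all()       # only unmasked experts selected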
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                           dim=-1,
                                           sorted=False)
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+    topk_weights = topk_weights * routed_scaling_factor  # must multiply the scaling factor
     return topk_ids.to(torch.long), topk_weights.to(torch.float32)
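The new weight post-processing, pulled out of the diff into a self-contained sketch (not part of the commit). The topk > 1 guard skips a pointless division for single-expert routing, the 1e-20 epsilon keeps the division finite if every gathered weight is 0.0 (now possible since masked scores are 0.0 rather than -inf), and routed_scaling_factor mirrors DeepSeek-V3's config value; the concrete numbers below are illustrative only:

import torch

topk, renormalize, routed_scaling_factor = 4, True, 2.5  # illustrative values
topk_weights = torch.rand(3, topk)                       # gathered top-k gate scores

if topk > 1 and renormalize:
    # epsilon keeps the division finite even if all gathered weights are zero
    denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
    topk_weights = topk_weights / denominator
topk_weights = topk_weights * routed_scaling_factor      # applied whether or not we renormalize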
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
@@ -204,6 +206,7 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                    gguf_loader, config, self.gate_linear, #orig_module
@@ -219,14 +222,13 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
             logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )
-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
-                            self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
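The forward change fixes a use-before-assignment bug: on the quantized path, gate_linear was fed logits, which is not yet defined at that point, instead of hidden_states. A hedged sketch of the corrected control flow, with the injected KTransformersLinear stubbed out as a plain nn.Linear (TinyGate and its dimensions are hypothetical, not the repo's class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGate(nn.Module):
    """Stub of the gating path; dims and the plain nn.Linear are stand-ins."""
    def __init__(self, gating_dim=16, n_routed_experts=8, use_quant=True):
        super().__init__()
        self.use_quant = use_quant
        self.gate_linear = nn.Linear(gating_dim, n_routed_experts, bias=False)
        self.weight = self.gate_linear.weight

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.use_quant:
            # fixed line: the gate consumes hidden_states, not the undefined logits
            logits = self.gate_linear.forward(hidden_states)
        else:
            logits = F.linear(hidden_states.type(torch.float32),
                              self.weight.type(torch.float32), None)
        return logits

logits = TinyGate()(torch.randn(2, 3, 16))               # -> [6, 8]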