Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-06 12:40:02 +00:00
Update gate.py

parent 6ca233cca3
commit b453333f60
1 changed file with 11 additions and 9 deletions
@@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
                  renormalize: bool,
                  num_expert_group: int = 0,
                  topk_group: int = 0,
+                 routed_scaling_factor: float = 1.0,
                  scoring_func: str = "sigmoid",
                  e_score_correction_bias: Optional[torch.Tensor] = None):
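This hunk threads DeepSeek-V3's gate scaling into grouped_topk as a new parameter between topk_group and scoring_func, so positional callers must be updated (the forward() change further down does exactly that). A minimal sketch of a call under the new signature; the tensor shapes and config values are illustrative assumptions (roughly DeepSeek-V3-sized), and grouped_topk is the function patched above:

    import torch

    hidden_states = torch.randn(4, 7168)   # [n_tokens, hidden_dim] (assumed size)
    gating_output = torch.randn(4, 256)    # [n_tokens, n_routed_experts] (assumed)

    topk_ids, topk_weights = grouped_topk(
        hidden_states, gating_output, 8, True,   # topk=8, renormalize=True
        num_expert_group=8, topk_group=4,
        routed_scaling_factor=2.5,               # DeepSeek-V3 ships 2.5 in its config
        scoring_func="sigmoid",
        e_score_correction_bias=None,
    )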
@@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
     score_mask = group_mask.unsqueeze(-1).expand(
         num_token, num_expert_group,
         scores.shape[-1] // num_expert_group).reshape(num_token, -1)  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(),
-                                    float("-inf"))  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
+                                    #float("-inf"))  # [n, e]

     if e_score_correction_bias is not None:
         topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
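Swapping the fill value from float("-inf") to 0.0 matters because this path scores experts with a sigmoid, whose outputs are strictly positive: a 0.0-filled entry still loses every top-k comparison, but it can no longer leak -inf (and thus NaN) into downstream sums and the renormalization below. A standalone sketch of the behavior, using toy shapes rather than the module's real dimensions:

    import torch

    scores = torch.sigmoid(torch.randn(2, 8))          # sigmoid scores lie in (0, 1)
    score_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0],
                               [0, 0, 0, 0, 1, 1, 1, 1]], dtype=torch.bool)

    inf_fill  = scores.masked_fill(~score_mask, float("-inf"))
    zero_fill = scores.masked_fill(~score_mask, 0.0)

    # Both fills select the same experts, since every unmasked sigmoid score > 0:
    same = torch.equal(torch.topk(inf_fill, 4, -1)[1].sort(-1).values,
                       torch.topk(zero_fill, 4, -1)[1].sort(-1).values)
    print(same)                  # True

    # But only the 0.0 fill keeps later arithmetic finite:
    print(inf_fill.sum(-1))      # tensor([-inf, -inf])
    print(zero_fill.sum(-1))     # finite sums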
@@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
                                         dim=-1,
                                         sorted=False)

-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    if topk > 1 and renormalize:
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights = topk_weights / denominator
+        topk_weights = topk_weights * routed_scaling_factor  # must multiply the scaling factor

     return topk_ids.to(torch.long), topk_weights.to(torch.float32)

 class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
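The rewritten normalization adds a 1e-20 epsilon so an all-zero row of selected weights (now possible, since masked scores are filled with 0.0 rather than -inf) divides to 0 instead of NaN, skips normalization when only one expert is routed (the same topk > 1 guard DeepSeek's reference MoE gate uses), and restores the routed_scaling_factor that DeepSeek-V3 applies to routed-expert weights. A small numeric sketch, independent of the module:

    import torch

    topk_weights = torch.tensor([[0.60, 0.30, 0.10],
                                 [0.00, 0.00, 0.00]])    # degenerate all-zero row
    routed_scaling_factor = 2.5                          # DeepSeek-V3's config value
    topk, renormalize = 3, True

    if topk > 1 and renormalize:
        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20  # avoids 0/0
        topk_weights = topk_weights / denominator
        topk_weights = topk_weights * routed_scaling_factor

    print(topk_weights)          # row 0 sums to 2.5; row 1 stays 0.0, not NaN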
@@ -204,6 +206,7 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         self.is_windows = os.name == 'nt'
         self.use_quant = use_quant
         if not self.is_windows and use_quant:
+            print("injecting gate_linear")
             self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
             self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                    gguf_loader, config, self.gate_linear, #orig_module
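The two consecutive assignments above implement a wrap-and-replace injection: a plain nn.Linear is built first, then handed to KTransformersLinear as its orig_module so the quantized implementation can take over. A generic sketch of that pattern; QuantLinearWrapper is a hypothetical stand-in (KTransformersLinear's real constructor takes the loader/config arguments shown in the diff), and the sizes are illustrative:

    import torch
    import torch.nn as nn

    class QuantLinearWrapper(nn.Module):
        """Hypothetical stand-in for KTransformersLinear: owns the original
        module; a real wrapper would dispatch to a quantized kernel."""
        def __init__(self, orig_module: nn.Linear):
            super().__init__()
            self.orig_module = orig_module

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Placeholder: a real implementation would run a quantized matmul.
            return self.orig_module(x)

    gating_dim, n_routed_experts = 7168, 256          # illustrative DeepSeek-V3 sizes
    gate_linear = nn.Linear(gating_dim, n_routed_experts, bias=False)
    gate_linear = QuantLinearWrapper(gate_linear)     # rebind, as the diff does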
@@ -219,14 +222,13 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
         if self.use_quant:
-            logits = self.gate_linear.forward(logits)
+            logits = self.gate_linear.forward(hidden_states)
         else:
             logits = F.linear(
                 hidden_states.type(torch.float32), self.weight.type(torch.float32), None
             )

-        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
-                            self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
+                            self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)

     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
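Two fixes land in forward: the quantized path previously called self.gate_linear.forward(logits) before logits was ever assigned, which raises UnboundLocalError at runtime, and the grouped_topk call now passes self.routed_scaling_factor in the slot ahead of scoring_func to match the new signature. A condensed sketch of the fixed control flow; `gate` stands in for the module's attributes, and grouped_topk is the function patched above:

    import torch
    import torch.nn.functional as F

    def gate_forward(gate, hidden_states: torch.Tensor):
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if gate.use_quant:
            # Fixed: the old code read `logits` here before assigning it,
            # an UnboundLocalError on the quantized path.
            logits = gate.gate_linear.forward(hidden_states)
        else:
            logits = F.linear(hidden_states.type(torch.float32),
                              gate.weight.type(torch.float32), None)
        # routed_scaling_factor now sits ahead of scoring_func in the call.
        return grouped_topk(hidden_states, logits, gate.top_k, gate.norm_topk_prob,
                            gate.n_group, gate.topk_group, gate.routed_scaling_factor,
                            "sigmoid", gate.e_score_correction_bias)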