mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 12:40:02 +00:00
Update gate.py
This commit is contained in:
parent
6ca233cca3
commit
b453333f60
1 changed files with 11 additions and 9 deletions
|
@ -132,6 +132,7 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: int = 0,
|
||||||
topk_group: int = 0,
|
topk_group: int = 0,
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
scoring_func: str = "sigmoid",
|
scoring_func: str = "sigmoid",
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None):
|
e_score_correction_bias: Optional[torch.Tensor] = None):
|
||||||
|
|
||||||
|
@ -163,8 +164,8 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||||
score_mask = group_mask.unsqueeze(-1).expand(
|
score_mask = group_mask.unsqueeze(-1).expand(
|
||||||
num_token, num_expert_group,
|
num_token, num_expert_group,
|
||||||
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
|
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
|
||||||
tmp_scores = scores.masked_fill(~score_mask.bool(),
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
|
||||||
float("-inf")) # [n, e]
|
#float("-inf")) # [n, e]
|
||||||
|
|
||||||
if e_score_correction_bias is not None:
|
if e_score_correction_bias is not None:
|
||||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
|
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
|
||||||
|
@ -176,9 +177,10 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
sorted=False)
|
sorted=False)
|
||||||
|
|
||||||
if renormalize:
|
if topk > 1 and renormalize:
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
|
topk_weights = topk_weights / denominator
|
||||||
|
topk_weights = topk_weights * routed_scaling_factor # must multiply the scaling factor
|
||||||
return topk_ids.to(torch.long), topk_weights.to(torch.float32)
|
return topk_ids.to(torch.long), topk_weights.to(torch.float32)
|
||||||
|
|
||||||
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
|
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
|
||||||
|
@ -204,6 +206,7 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
|
||||||
self.is_windows = os.name == 'nt'
|
self.is_windows = os.name == 'nt'
|
||||||
self.use_quant = use_quant
|
self.use_quant = use_quant
|
||||||
if not self.is_windows and use_quant:
|
if not self.is_windows and use_quant:
|
||||||
|
print("injecting gate_linear")
|
||||||
self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
|
self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
|
||||||
self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
|
self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
|
||||||
gguf_loader, config, self.gate_linear, #orig_module
|
gguf_loader, config, self.gate_linear, #orig_module
|
||||||
|
@ -219,14 +222,13 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
|
||||||
### compute gating score
|
### compute gating score
|
||||||
hidden_states = hidden_states.view(-1, h)
|
hidden_states = hidden_states.view(-1, h)
|
||||||
if self.use_quant:
|
if self.use_quant:
|
||||||
logits = self.gate_linear.forward(logits)
|
logits = self.gate_linear.forward(hidden_states)
|
||||||
else:
|
else:
|
||||||
logits = F.linear(
|
logits = F.linear(
|
||||||
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
|
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
|
||||||
)
|
)
|
||||||
|
return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
|
||||||
return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
|
self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)
|
||||||
self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
|
|
||||||
|
|
||||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
||||||
if device is None: device = self.device
|
if device is None: device = self.device
|
||||||
|
|
Loading…
Add table
Reference in a new issue