mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
43 lines
No EOL
1.8 KiB
Python
43 lines
No EOL
1.8 KiB
Python
import torch
|
|
import torch_npu
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from ktransformers.operators.gate import KMoEGate
|
|
from ktransformers.util import utils
|
|
|
|
|
|
class KDeepseekV3GateA2(KMoEGate):
    """DeepseekV3 MoE gate for Ascend NPU (A2), using torch_npu's fused grouped top-k op."""

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        """Load gate weights and install them as float32 parameters on the target device.

        NOTE(review): the ``device`` argument is immediately overridden by the
        global ``utils.CUR_DEVICE`` (with ``self.device`` as fallback when that
        is None) — presumably a deliberate global device override; confirm.

        Raises:
            ValueError: if the resolved ``w`` is not a dict of weight tensors.
        """
        device = utils.CUR_DEVICE
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)

        # Guard clause: only the dict form is supported here.
        if not isinstance(w, dict):
            raise ValueError("Invalid weight type")
        self.weight_type = w["weight_type"]
        self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
        self.orig_module.weight = nn.Parameter(w["weight"])
        self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])

        # Re-wrap on the chosen device in float32 so the routing math below
        # runs in full precision.
        self.orig_module.weight = nn.Parameter(
            self.orig_module.weight.to(device).to(torch.float32))
        self.orig_module.e_score_correction_bias = nn.Parameter(
            self.orig_module.e_score_correction_bias.to(device).to(torch.float32))

    def forward(self, hidden_states) -> torch.Tensor:
        """Compute routing: returns (top-k expert indices as int64, top-k weights)."""
        hidden_dim = hidden_states.shape[-1]
        # Flatten to (num_tokens, hidden_dim) and score experts in float32.
        flat_states = hidden_states.view(-1, hidden_dim)
        router_logits = F.linear(flat_states.type(torch.float32), self.weight, None)
        # Fused NPU gating kernel: grouped top-k selection with score-correction
        # bias and routed scaling folded in (keyword values unchanged).
        weights, indices, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=self.top_k,
            bias=self.e_score_correction_bias,
            k_group=self.topk_group,
            group_count=self.n_group,
            group_select_mode=1,
            renorm=0,
            norm_type=1,
            routed_scaling_factor=self.routed_scaling_factor,
            eps=float(1e-20),
        )
        return indices.type(torch.int64), weights