support npu

djw 2025-07-21 04:05:15 +00:00
parent 1677e90092
commit dd0e41b3b8
14 changed files with 1453 additions and 5 deletions


@@ -0,0 +1,43 @@
import torch
import torch_npu
import torch.nn as nn
import torch.nn.functional as F

from ktransformers.operators.gate import KMoEGate
from ktransformers.util import utils


class KDeepseekV3GateA2(KMoEGate):
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        # Prefer the runtime-selected device; fall back to the module's own device.
        device = utils.CUR_DEVICE
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")

        # Gating runs in float32 on the NPU, so cast both tensors once at load time.
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device).to(torch.float32))
        self.orig_module.e_score_correction_bias = nn.Parameter(
            self.orig_module.e_score_correction_bias.to(device).to(torch.float32)
        )

    def forward(self, hidden_states) -> tuple[torch.Tensor, torch.Tensor]:
        h = hidden_states.shape[-1]
        # Compute gating scores: flatten to (num_tokens, hidden_size) and project onto expert logits.
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states.type(torch.float32), self.weight, None)
        # Fused NPU kernel: grouped top-k expert selection with score-correction bias.
        topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k(
            logits,
            k=self.top_k,
            bias=self.e_score_correction_bias,
            k_group=self.topk_group,
            group_count=self.n_group,
            group_select_mode=1,
            renorm=0,
            norm_type=1,
            routed_scaling_factor=self.routed_scaling_factor,
            eps=float(1e-20))
        return topk_idx.type(torch.int64), topk_weight
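
For reference, a minimal sketch of exercising the same fused kernel outside the class, with the keyword arguments used in forward() above. The token count, expert count, group settings, and scaling factor are illustrative DeepSeek-V3-style assumptions, not values taken from this commit.

# Hypothetical smoke test (not part of the commit): random logits, assumed sizes.
import torch
import torch_npu

num_tokens, n_experts = 8, 256          # assumed sizes
top_k, n_group, topk_group = 8, 8, 4    # assumed DeepSeek-V3-style gating config
logits = torch.randn(num_tokens, n_experts, dtype=torch.float32, device="npu")
bias = torch.zeros(n_experts, dtype=torch.float32, device="npu")

topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k(
    logits,
    k=top_k,
    bias=bias,
    k_group=topk_group,
    group_count=n_group,
    group_select_mode=1,
    renorm=0,
    norm_type=1,
    routed_scaling_factor=2.5,          # assumed; placeholder value
    eps=float(1e-20))

# Expected: one (index, weight) pair per token and per selected expert.
print(topk_idx.shape, topk_weight.shape)  # expected: (num_tokens, top_k) each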