mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
support npu
This commit is contained in:
parent
1677e90092
commit
dd0e41b3b8
14 changed files with 1453 additions and 5 deletions
43
ktransformers/operators/ascend/ascend_gate.py
Normal file
43
ktransformers/operators/ascend/ascend_gate.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import torch
|
||||
import torch_npu
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ktransformers.operators.gate import KMoEGate
|
||||
from ktransformers.util import utils
|
||||
|
||||
|
||||
class KDeepseekV3GateA2(KMoEGate):
    """MoE gate whose routing runs through ``torch_npu.npu_moe_gating_top_k``.

    Gate parameters are stored on ``self.orig_module`` in float32.
    """

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        """Install the gate weight and score-correction bias on the target device.

        Args:
            w: Dict with the keys ``"weight_type"``,
                ``"e_score_correction_bias_type"``, ``"weight"`` and
                ``"e_score_correction_bias"``. When ``None`` the dict is
                obtained from ``self.load_weights``.
            device: Accepted for signature compatibility; its value is
                replaced by ``utils.CUR_DEVICE`` before it is ever read.

        Raises:
            ValueError: If the resolved weights are not a ``dict``.
        """
        # NOTE(review): the caller-supplied `device` is always discarded in
        # favour of utils.CUR_DEVICE; self.device is only the fallback when
        # that is None. Presumably intentional for the NPU path - confirm.
        target = utils.CUR_DEVICE
        if target is None:
            target = self.device

        weights = self.load_weights(device=target) if w is None else w

        # Only the dict form is supported here.
        if not isinstance(weights, dict):
            raise ValueError("Invalid weight type")

        self.weight_type = weights["weight_type"]
        self.e_score_correction_bias_type = weights["e_score_correction_bias_type"]

        gate = self.orig_module
        gate.weight = nn.Parameter(weights["weight"])
        gate.e_score_correction_bias = nn.Parameter(weights["e_score_correction_bias"])

        # Second pass: move each parameter to the target device, cast it to
        # float32, and wrap the result in a fresh Parameter.
        for attr in ("weight", "e_score_correction_bias"):
            moved = getattr(gate, attr).to(target).to(torch.float32)
            setattr(gate, attr, nn.Parameter(moved))

    def forward(self, hidden_states) -> torch.Tensor:
        """Route tokens to experts.

        The input is flattened to 2-D over its last dimension, projected with
        the gate weight in float32, and passed to the NPU top-k gating
        operator.

        Returns:
            A ``(topk_idx, topk_weight)`` tuple, with ``topk_idx`` cast to
            int64. (The annotation on the signature names a single tensor,
            but two values are returned.)
        """
        hidden_dim = hidden_states.shape[-1]
        tokens = hidden_states.view(-1, hidden_dim)

        # compute gating score
        scores = F.linear(tokens.type(torch.float32), self.weight)

        # Operator settings are passed through unchanged. The meaning of the
        # integer mode flags (group_select_mode, renorm, norm_type) is defined
        # by torch_npu - TODO confirm against the operator documentation.
        gating_options = {
            "k": self.top_k,
            "bias": self.e_score_correction_bias,
            "k_group": self.topk_group,
            "group_count": self.n_group,
            "group_select_mode": 1,
            "renorm": 0,
            "norm_type": 1,
            "routed_scaling_factor": self.routed_scaling_factor,
            "eps": 1e-20,
        }
        # The operator's third output is not used.
        expert_weights, expert_ids, _ = torch_npu.npu_moe_gating_top_k(scores, **gating_options)
        return expert_ids.type(torch.int64), expert_weights
|
Loading…
Add table
Add a link
Reference in a new issue