import torch
import torch_npu

from ktransformers.util.ascend.ascend_utils import allreduce_wrapper
from ktransformers.util.utils import CUR_DEVICE
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP

class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP):
    """W8A8 DeepseekV3 MLP for Ascend NPUs: separate int8 gate and up matmuls."""

    @allreduce_wrapper
    def forward(self, x, is_prefill=None, use_cuda_graph=False):
        original_dtype = x.dtype
        # Per-token dynamic quantization of the activations to int8.
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        dynamic_scale = dynamic_scale.view(-1)
        gate_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.gate_proj.weight,
            self.orig_module.gate_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        up_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.up_proj.weight,
            self.orig_module.up_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        # SwiGLU: activation on the gate branch, elementwise product with up.
        down_x = self.act_fn(gate_x) * up_x
        # Re-quantize the intermediate activations before the down projection.
        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
        down_dynamic_scale = down_dynamic_scale.view(-1)
        down_proj = torch_npu.npu_quant_matmul(
            down_quant_out,
            self.orig_module.down_proj.weight,
            self.orig_module.down_proj.weight_scale,
            pertoken_scale=down_dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        return down_proj
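
# A minimal CPU reference sketch for the quantized path above (illustrative
# only, not used by the modules): assumes symmetric per-token int8 activation
# quantization and per-output-channel int8 weights. The exact rounding mode
# and weight layout of torch_npu.npu_dynamic_quant / npu_quant_matmul on
# Ascend hardware may differ.
def _dynamic_quant_ref(x):
    # One fp32 scale per token (row), mapping that row into the int8 range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp_(-128, 127).to(torch.int8)
    return x_q, scale.squeeze(-1)


def _quant_matmul_ref(x_q, w_q, w_scale, pertoken_scale, output_dtype):
    # int8 x int8 matmul accumulated in int32, then rescaled back to float:
    # y = (x_q @ w_q.T) * pertoken_scale[:, None] * w_scale[None, :]
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).T
    y = acc.to(torch.float32) * pertoken_scale[:, None] * w_scale[None, :]
    return y.to(output_dtype)
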
class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP):
    """W8A8 DeepseekV3 MLP for Ascend NPUs: fused gate+up matmul plus npu_swiglu.

    The single matmul producing `gate_up_x` and the npu_swiglu call imply that
    gate_proj holds the merged gate/up weight in this variant.
    """

    @allreduce_wrapper
    def forward(self, x, is_prefill=None, use_cuda_graph=False):
        original_dtype = x.dtype
        # Per-token dynamic quantization of the activations to int8.
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        dynamic_scale = dynamic_scale.view(-1)
        # One int8 matmul for the concatenated gate/up projection.
        gate_up_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.gate_proj.weight,
            self.orig_module.gate_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        # Fused SwiGLU over the last dim: silu(gate half) * up half.
        down_x = torch_npu.npu_swiglu(gate_up_x, -1)
        # Re-quantize the intermediate activations before the down projection.
        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
        down_dynamic_scale = down_dynamic_scale.view(-1)
        down_proj = torch_npu.npu_quant_matmul(
            down_quant_out,
            self.orig_module.down_proj.weight,
            self.orig_module.down_proj.weight_scale,
            pertoken_scale=down_dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        return down_proj
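
# Plain-PyTorch sketch of what torch_npu.npu_swiglu(x, -1) computes in the V2
# path above (illustrative only; assumes the usual chunk-then-silu SwiGLU
# semantics): the fused gate/up output is split in half along the last dim,
# silu is applied to the gate half, and the result scales the up half.
def _swiglu_ref(gate_up_x):
    gate, up = gate_up_x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up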