support npu

This commit is contained in:
djw 2025-07-21 04:05:15 +00:00
parent 1677e90092
commit dd0e41b3b8
14 changed files with 1453 additions and 5 deletions

5
.gitignore vendored
View file

@ -26,4 +26,7 @@ ktransformers/tests/chat_txt.txt
mmlu_result*
ktransformers/ktransformers_ext/cuda_musa/
test_prompt.txt
csrc/demo
CMakeFiles
kvc2/
sched/

View file

@ -42,6 +42,7 @@ option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA"
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@ -306,6 +307,17 @@ elseif (UNIX)
endif()
elseif (KTRANSFORMERS_USE_XPU)
add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
elseif (KTRANSFORMERS_USE_NPU)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
else()
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")

View file

@ -9,7 +9,7 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU)
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"

View file

@ -0,0 +1,165 @@
# Deployment
## Hardware installation
Deploying the full DeepseekV3 model requires enough physical memory on the machine to hold the weights of all routed experts, roughly 400 GB.
Currently supported NPU model: **800I A2**.
Complete the hardware installation with support from technical staff.
## OS installation
Consult the [Ascend compatibility checker](https://www.hiascend.com/hardware/compatibility) to choose an operating system. We use Ubuntu 22.04 for aarch64 with kernel 5.15.0-25-generic and disable automatic system updates. OS image download: [ubuntu-old-releases](https://mirrors.aliyun.com/oldubuntu-releases/releases/22.04).
## HDK installation
Install [Ascend HDK 25.0.RC1](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=32&cann=8.1.RC1.beta1&driver=Ascend+HDK+25.0.RC1); for the procedure, see the [Ascend community HDK installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1beta1/softwareinst/instg/instg_0005.html?Mode=PmIns&InstallType=local&OS=Ubuntu&Software=cannToolKit).
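For reference, a minimal sketch of the driver and firmware installation, assuming placeholder `.run` package names (use the actual file names downloaded from the HDK page for your 800I A2 system):
```bash
# Placeholder file names; substitute the actual driver/firmware .run packages.
chmod +x ./Ascend-hdk-*-npu-driver_*_linux-aarch64.run ./Ascend-hdk-*-npu-firmware_*.run
./Ascend-hdk-*-npu-driver_*_linux-aarch64.run --full
./Ascend-hdk-*-npu-firmware_*.run --full
npu-smi info  # the NPUs should be listed once driver and firmware are installed
```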
## Conda setup
We recommend setting up the development environment following the latest [Installation Guide - kTransformers](https://kvcache-ai.github.io/ktransformers/en/install.html). Note that Python 3.11 is required (other versions are unverified), and the cpufeature package does not need to be installed on ARM platforms.
Install conda/miniconda:
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
bash ~/Miniconda3-latest-Linux-aarch64.sh
```
Set up the Python environment:
```bash
conda create -n py311 python=3.11
conda activate py311
conda install -c conda-forge libstdcxx-ng # provides `GLIBCXX_3.4.32`
pip3 install numpy==1.26.4 # required by torch/torch_npu
pip3 install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
pip3 install packaging ninja transformers==4.43.2 fire protobuf attrs decorator cloudpickle ml-dtypes scipy tornado absl-py psutil
#pip3 install cpufeature # only for x86
```
## CANN installation
Install [CANN 8.1.RC1.beta1](https://www.hiascend.com/developer/download/community/result?from=firmware&product=4&model=32&cann=8.1.RC1.beta1); for the procedure, see the [Ascend community CANN installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1beta1/softwareinst/instg/instg_0007.html?Mode=PmIns&InstallType=local&OS=Ubuntu&Software=cannToolKit).
The Toolkit, Kernels, and NNAL packages all need to be installed.
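A minimal sketch of installing the three packages, assuming placeholder `.run` file names (use the actual packages downloaded from the CANN page), with the environment scripts sourced afterwards:
```bash
# Placeholder file names; substitute the actual toolkit/kernels/nnal packages.
./Ascend-cann-toolkit_8.1.RC1_linux-aarch64.run --install
source /usr/local/Ascend/ascend-toolkit/set_env.sh
./Ascend-cann-kernels-*_8.1.RC1_linux.run --install
./Ascend-cann-nnal_8.1.RC1_linux-aarch64.run --install
source /usr/local/Ascend/nnal/atb/set_env.sh
```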
## torch_npu (op-plugin) installation
Get the latest repository code: [op-plugin Gitee](https://gitee.com/ascend/op-plugin)
Because new operators are involved, the torch_npu package published on public PyPI cannot be used directly for now; the required torch_npu package must be built from the adapted op-plugin, which is not yet publicly available. # TODO
With working network access to GitHub and Gitee, run the following to build and install torch_npu:
```bash
cd op-plugin
source /usr/local/Ascend/ascend-toolkit/set_env.sh # adjust to the actual CANN install path
source /usr/local/Ascend/nnal/atb/set_env.sh # adjust to the actual NNAL install path
bash install.sh --python=3.11 --pytorch=v2.3.1-7.0.0 # the generated torch_npu wheel is placed in {op-plugin repo path}/dist
```
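A quick sanity check after the build, assuming the wheel was produced under the op-plugin `dist/` directory mentioned above (run inside the `py311` environment):
```bash
pip3 install dist/torch_npu-*.whl  # wheel name is a placeholder pattern; skip if install.sh already installed it
python3 -c "import torch, torch_npu; print(torch.__version__, torch_npu.npu.is_available())"
```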
## Weight preparation
To meet both performance and accuracy requirements, two sets of weights currently need to be prepared and merged with the provided merge script; only the merged weights are used in the end.
Q4 weights: [DeepSeek-R1-Q4_K_M](https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/files)
W8A8 weights: [DeepSeek-R1-W8A8](https://modelers.cn/models/MindSpore-Lab/DeepSeek-R1-W8A8/tree/main)
Use [merge_safetensor_gguf.py](../../merge_tensors/merge_safetensor_gguf.py) to merge the Q4 and W8A8 weights:
```bash
python merge_safetensor_gguf.py --safetensor_path /mnt/weights/DeepSeek-R1-Q4_K_M --gguf_path /mnt/weights/DeepSeek-R1-W8A8 --output_path /mnt/weights/DeepSeek-R1-q4km-w8a8 --safetensors_format w8a8
```
## Graph sinking deployment
The binaries required for graph sinking (图下沉) ship with the repository: [ktransformers/util/npu_graph_so](../../ktransformers/util/npu_graph_so).
Deploying graph sinking requires replacing the corresponding files; taking the ARM platform as an example:
```bash
mv /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so.bak
cp ktransformers/util/npu_graph_so/arm/libruntime.so /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so
```
To enable graph sinking, add the following environment variables:
```bash
export CAPTURE_PLUGIN_PATH=ktransformers/util/npu_graph_so/arm
export TASK_QUEUE_ENABLE=0 # keep operator dispatch ordered
```
## kTransformers deployment
Deploy the project files onto the machine:
- On ARM platforms, comment out `cpufeature` in `./requirements-local_chat.txt`.
- On ARM platforms, make the following replacements:
```bash
cp ./for_arm/CMakeLists.txt ./csrc/ktransformers_ext/CMakeLists.txt
cp ./for_arm/iqk_mul_mat.inc ./third_party/llamafile/iqk_mul_mat.inc
cp ./for_arm/sgemm.cpp ./third_party/llamafile/sgemm.cpp
cp ./for_arm/tinyblas_cpu_sgemm.inc ./third_party/llamafile/tinyblas_cpu_sgemm.inc
cp ./for_arm/setup.py ./setup.py
```
- Run `source /usr/local/Ascend/ascend-toolkit/set_env.sh` (adjust to the actual CANN toolkit install path).
- Run `apt install cmake libhwloc-dev pkg-config` to install dependencies.
- Run `bash install.sh` and wait for the installation to complete.
Below is an example launch script for local_chat; because it uses relative paths, place the script in the project root directory:
```bash
#!/bin/bash
export CAPTURE_PLUGIN_PATH=ktransformers/util/npu_graph_so/arm
export USE_MERGE=0
export INF_NAN_MODE_FORCE_DISABLE=1
export TASK_QUEUE_ENABLE=0
#export PROF_DECODE=1
#export PROF_PREFILL=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
torchrun --nproc_per_node 1 \
--master_port 25565 \
-m ktransformers.local_chat \
--cpu_infer 20 \
--model_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
--gguf_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
--optimize_config_path ./ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml \
--use_cuda_graph True \
--max_new_tokens 500 \
--tp 1
```
Parameter notes:
- `--model_path`: native kTransformers parameter, str; here it points to the merged model directory.
- `--gguf_path`: native kTransformers parameter, str; here it also points to the merged model directory.
- `--cpu_infer`: native kTransformers parameter, int; controls the actual number of CPU-side worker threads (optional).
- `--optimize_config_path`: native kTransformers parameter, str; specifies the model optimization rule file to use (mind the relative path); **required** here.
- `--use_cuda_graph`: native kTransformers parameter, bool; True enables graph sinking, False disables it (optional).
- `--max_new_tokens`: native kTransformers parameter, int; generation stops once the number of output tokens reaches this value (optional).
- `--tp`: new parameter, int; enables tensor model parallelism. Currently local_chat only supports tp equal to the world size; running local_chat with multiple dp is not supported (optional).
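For reference, a hedged sketch of a 2-way tensor-parallel launch derived from the script above (assumes two NPUs are available; as noted, tp must equal the torchrun world size):
```bash
torchrun --nproc_per_node 2 \
         --master_port 25565 \
         -m ktransformers.local_chat \
         --cpu_infer 20 \
         --model_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
         --gguf_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
         --optimize_config_path ./ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml \
         --use_cuda_graph True \
         --max_new_tokens 500 \
         --tp 2
```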
# Other issues
## Possible additional dependency issues
ImportError: libhccl.so: cannot open shared object file: No such file or directory
```bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh # adjust to the actual CANN install path
```
ImportError: libascend_hal.so: cannot open shared object file: No such file or directory
```bash
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH # adjust to the actual driver install path
```

21
install_for_npu.sh Normal file
View file

@ -0,0 +1,21 @@
#!/bin/bash
set -e
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf csrc/build
rm -rf csrc/ktransformers_ext/build
rm -rf csrc/ktransformers_ext/cuda/build
rm -rf csrc/ktransformers_ext/cuda/dist
rm -rf csrc/ktransformers_ext/cuda/*.egg-info
rm -rf ~/.ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
pip install -r ktransformers/server/requirements.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
pip install third_party/custom_flashinfer/
echo "Installation completed successfully"

View file

@ -67,7 +67,7 @@ attn:
page_size: 256
chunk_size: 256
kvc2:
gpu_only: false
gpu_only: true
utilization_percentage: 1.0
cpu_memory_size_GB: 500
disk_path: /mnt/data/kvc

View file

@ -0,0 +1,467 @@
import os
import warnings
from typing import Optional, Tuple
import torch
import torch_npu
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, allreduce_wrapper
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import get_compute_capability, get_use_npu_graph, CUR_DEVICE
from ktransformers.util.vendors import device_manager, GPUVendor
from ktransformers.util import utils
def apply_rotary_pos_emb_fusion(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
class MatMulOps(object):
def execute(self, x_input):
"""
:param x, weight, quant_bia, deq_scale
:return:
"""
quant_out = x_input[0]
weight = x_input[1]
quant_bia = x_input[2]
deq_scale = x_input[3]
return [torch_npu.npu_quant_matmul(quant_out, weight.T, deq_scale, bias=quant_bia, output_dtype=torch.float16)]
class MatMulOpsAtb(object):
def execute(self, x_input):
"""
:param x, weight, quant_bia, deq_scale
:return:
"""
x = x_input[0]
weight = x_input[1]
quant_bia = x_input[2]
deq_scale = x_input[3]
target_shape = (x.shape[0], x.shape[-2], weight.shape[-2])
target_tensor = torch.zeros(target_shape, dtype=torch.float16, device=x.device)
torch_npu.torch_npu._npu_matmul_dequant(x, weight, quant_bia, deq_scale, target_tensor)
return [target_tensor]
class DynamicQuantOps(object):
"""
:param x, scale, offset
:return
"""
def execute(self, x_input):
out = torch.empty_like(x_input[0], dtype=torch.int8)
torch_npu._npu_quantize_per_tensor(x_input[0], x_input[1], x_input[2], out)
return [out]
class KDeepseekV2AttentionW8A8A2(BaseInjectedModule, DeepseekV2Attention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
attn_mask: Optional[torch.Tensor] = None
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
absorb_for_prefill: bool = False,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device,
**kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
self.mla_wrapper = None
tp = get_tensor_parallel_size()
if tp > 1:
self.num_heads //= tp
self.absorb_for_prefill = absorb_for_prefill
self.use_merge = os.getenv("USE_MERGE", "0")
if self.use_merge == "0":
print("--Use ATB FA-MLA and PA-MLA OP !--")
self.elewise_quant = DynamicQuantOps()
self.matmulDequant_operation = MatMulOpsAtb()
self.matmulDequant_operation_aclnn = MatMulOps()
elif self.use_merge == "1":
print("--Use torch npu FA OP !--")
else:
print("--Use default op! --")
@allreduce_wrapper
def forward_chunck(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0]
q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight,
self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0]
q_a_proj_out = self.q_a_layernorm(q_a_proj_out)
q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0]
q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight,
self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0]
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0]
compressed_kv = self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight,
self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0]
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv_seq_len = k_pe.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin)
# update KV
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
k_pe = k_pe.transpose(1, 2) # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
compressed_kv = compressed_kv.unsqueeze(2) # compressed_kv [bsz, q_len, self.kv_lora_rank]
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
compressed_kv, k_pe = torch.split(
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:, :, :attention_mask.size(-1), :]
compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:, :, :attention_mask.size(-1), :]
weight_uk = self.q_absorb
weight_uv = self.out_absorb
# ATB-MLA-FA+PA
if self.use_merge == "0" and q_len != 1:
current_sqenLen = past_key_value.get_seq_length(self.layer_idx)
attention_mask = attention_mask[0, :, :, :current_sqenLen].squeeze(0).squeeze(0)
compressed_kv = compressed_kv[:, :, :current_sqenLen, :] # all KV until current chunk
k_pe = k_pe[:, :, :current_sqenLen, :]
k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1)
k_up = torch.matmul(compressed_kv, weight_uk.mT)
v_up = torch.matmul(compressed_kv, weight_uv)
qTensor = torch.cat((q_nope, q_pe), dim=-1).transpose(1, 2).contiguous().view(
bsz * q_len, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))
kTensor = torch.cat((k_up, k_pe_repeated), dim=-1).transpose(1, 2).contiguous().view(
bsz * current_sqenLen, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))
vTensor = v_up.transpose(1, 2).contiguous().view(bsz * current_sqenLen, self.num_heads, self.v_head_dim)
seq_len_data = [q_len] * bsz
seq_len = torch.tensor(seq_len_data, dtype=torch.int32, device=vTensor.device)
seq_len_host = torch.tensor(seq_len_data, dtype=torch.int32)
attn_output = torch.ones((qTensor.shape[0], qTensor.shape[1], vTensor.shape[-1]),
dtype=qTensor.dtype, device=vTensor.device)
torch_npu._npu_flash_attention_mla(qTensor, kTensor, vTensor, attention_mask, seq_len, seq_len_host,
self.softmax_scale, self.num_heads, self.num_heads, attn_output)
if attn_output.size() != (bsz * q_len, self.num_heads, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]
attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,
self.o_proj.quant_bias, self.o_proj.deq_scale])[0]
return attn_output, None, past_key_value
elif self.use_merge == "0" and q_len == 1:
return self.forward_paged(q_pe=q_pe,
q_nope=q_nope,
compressed_kv_with_k_pe=compressed_kv_with_k_pe,
past_key_value=past_key_value,
cache_position=cache_position)
if self.use_merge == "1":
k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1)
k_up = torch.matmul(compressed_kv, weight_uk.mT)
v_up = torch.matmul(compressed_kv, weight_uv)
qTensor = torch.cat((q_nope, q_pe), dim=-1)
kTensor = torch.cat((k_up, k_pe_repeated), dim=-1)
vTensor = torch.cat((v_up, k_pe_repeated), dim=-1)
if q_len != 1:
attn_output = torch_npu.npu_prompt_flash_attention(
qTensor, kTensor, vTensor,
num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD")
else:
attn_output = torch_npu.npu_incre_flash_attention(
qTensor, kTensor, vTensor,
num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD")
attn_output = attn_output[:, :, :, :self.v_head_dim]
else:
q_nope = torch.matmul(q_nope, self.q_absorb)
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
compressed_kv = compressed_kv.squeeze(1)
"""
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
assert attention_mask is not None
"""
if attention_mask is not None:
"""
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
"""
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q_pe.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.matmul(attn_output, self.out_absorb)
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward_paged(
self,
q_pe: torch.Tensor,
q_nope: torch.Tensor,
compressed_kv_with_k_pe: torch.Tensor,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, _, q_len, _ = q_nope.size()
q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb) # torch.Size([1, 128, 1, 512])
compressed_kv = compressed_kv_with_k_pe.permute(0, 2, 1, 3)
kvCache = compressed_kv[:, :, :, :self.kv_lora_rank].contiguous()
kRopeCache = compressed_kv[:, :, :, self.kv_lora_rank:].contiguous()
if get_use_npu_graph():
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
stream = npu_graph_runner.main_stream
if npu_graph_runner.past_key_value is None:
npu_graph_runner.past_key_value = past_key_value
if npu_graph_runner.workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope,
kvCache,
kvCache,
query_rope=q_pe,
key_rope=kRopeCache,
num_heads=self.num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=None,
scale=self.softmax_scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=past_key_value.page_table_list[self.layer_idx],
block_size=past_key_value.page_size,
actual_seq_lengths_kv=past_key_value.position
)
npu_graph_runner.workspace = workspace
attn_output = torch.zeros_like(q_nope, dtype=torch.float16, device=CUR_DEVICE)
softmax_lse = torch.empty(1, dtype=torch.float16, device=CUR_DEVICE)
npu_graph_runner.ifa_param.append((q_nope, kvCache, q_pe, kRopeCache, self.num_heads,
self.softmax_scale, self.layer_idx, attn_output, softmax_lse))
eventTmp = torch.npu.ExternalEvent()
npu_graph_runner.event.append(eventTmp)
eventTmp.wait(stream)
eventTmp.reset(stream)
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
kvCache,
kvCache,
workspace=npu_graph_runner.workspace,
query_rope=q_pe,
key_rope=kRopeCache,
num_heads=self.num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=None,
scale=self.softmax_scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=past_key_value.page_table_list[self.layer_idx],
block_size=past_key_value.page_size,
actual_seq_lengths_kv=past_key_value.position,
out=[attn_output, softmax_lse])
handle = torch.npu.graph_task_group_end(stream)
npu_graph_runner.handle.append(handle)
else:
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q_nope,
kvCache,
kvCache,
query_rope=q_pe,
key_rope=kRopeCache,
num_heads=self.num_heads,
num_key_value_heads=1,
input_layout="BNSD",
atten_mask=None,
scale=self.softmax_scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=past_key_value.page_table_list[self.layer_idx],
block_size=past_key_value.page_size,
actual_seq_lengths_kv=past_key_value.position,
)
attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb)
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]
attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,
self.o_proj.quant_bias, self.o_proj.deq_scale])[0]
return attn_output, None, past_key_value
def forward_windows(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
if q_len <= self.chunck_size:
return self.forward_chunck(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs
)
assert output_attentions is False, "output_attentions is not supported when using chunked attention"
attn_output = None
cur_idx = 0
while cur_idx < q_len:
if attention_mask is not None:
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
else:
# generate chunk_mask automatically.
self.attn_mask = \
torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
if self.attn_mask is None \
else self.attn_mask
self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \
-65504.0 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \
[:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]
self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -65504.0
self.attn_mask[:, :, :, :cur_idx] = 0
chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))
cur_output, _, _ = self.forward_chunck(
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
chunk_mask,
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
past_key_value,
output_attentions,
use_cache,
cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
**kwargs
)
cur_idx += self.chunck_size
if attn_output is None:
attn_output = cur_output
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
return self.forward_windows(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs,
)

View file

@ -0,0 +1,192 @@
import bisect
import acl
import torch
import numpy as np
from torch import nn
from transformers import PretrainedConfig
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, get_tensor_parallel_group
from ktransformers.operators.experts import KExpertsCPU, KTransformersExperts, EXPERTS_MAP, KDeepseekV3MoE, cuda_graphs
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import CUR_DEVICE, get_use_npu_graph, InferenceState
class KExpertsCPUW8A8(KExpertsCPU):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cpu",
**kwargs
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.input_tensor_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
self.expert_ids_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
self.weights_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
self.output_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
self.bsz_tensor_cpu_graph = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
if get_use_npu_graph():
self.cpu_infer.submit(self.moe.forward(self.expert_ids_cpu_graph.size(0),
self.expert_ids_cpu_graph.size(1),
self.expert_ids_cpu_graph.data_ptr(),
self.weights_cpu_graph.data_ptr(),
self.input_tensor_cpu_graph.data_ptr(),
self.output_cpu_graph.data_ptr(),
self.bsz_tensor_cpu_graph.data_ptr()))
self.cpu_infer.sync()
else:
if bsz_tensor is None:
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
org_type = input_tensor.dtype
input_tensor = input_tensor.contiguous().cpu()
input_tensor = input_tensor.to(torch.bfloat16)
expert_ids = expert_ids.contiguous().cpu()
weights = weights.contiguous().to(torch.float32).cpu()
bsz_tensor = bsz_tensor.contiguous().cpu()
output = torch.empty_like(input_tensor).contiguous()
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
self.cpu_infer.sync()
return output.to(org_type).to(device=CUR_DEVICE)
EXPERTS_MAP["KExpertsCPUW8A8"] = KExpertsCPUW8A8
class KTransformersExpertsW8A8(KTransformersExperts):
def forward(self, input_tensor, expert_ids, weights):
if self.mode == InferenceState.GENERATE:
assert self.generate_experts is not None, "generate_experts is None"
return self.generate_experts.forward(input_tensor, expert_ids, weights)
elif self.mode == InferenceState.PREFILL:
assert self.prefill_experts is not None, "prefill_experts is None"
return self.prefill_experts.forward(input_tensor, expert_ids, weights)
else:
raise ValueError("load or set_inference_mode before forward")
class KDeepseekV3MoEW8A8(KDeepseekV3MoE):
def forward_tp(self, hidden_states):
identity = hidden_states
orig_shape = hidden_states.shape
rank = torch.distributed.get_rank()
def share_experts_forward():
if self.config.n_shared_experts is not None:
return self.shared_experts(identity).squeeze(0)
if rank == 0:
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1)
if get_use_npu_graph():
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
event = torch.npu.Event()
event.record(npu_graph_runner.main_stream)
with torch.npu.stream(npu_graph_runner.update_stream):
event.wait(npu_graph_runner.update_stream)
y_ = share_experts_forward() if share_experts_forward is not None else None
event.record(npu_graph_runner.update_stream)
org_type = hidden_states.dtype
input_tensor = hidden_states.to(torch.bfloat16)
topk_weight = topk_weight.contiguous().to(torch.float32)
self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight)
self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True)
self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True)
self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True)
npu_graph_runner.launch_callback(
self.cpu_moe_kexperts,
self.moe_kexperts_param,
1, npu_graph_runner.stream)
output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True)
y = output_npu_graph.to(org_type)
event.wait(npu_graph_runner.main_stream)
else:
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)
y_ = share_experts_forward() if share_experts_forward is not None else None
y = y.view(*orig_shape).to(device=hidden_states.device)
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
y = torch.zeros(orig_shape, dtype=torch.float16, device=CUR_DEVICE)
y_ = share_experts_forward() if share_experts_forward is not None else None
torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group())
if self.config.n_shared_experts is not None:
y += y_
return y
def forward(self, hidden_states):
tp_size = get_tensor_parallel_size()
world_size = torch.distributed.get_world_size()
if tp_size > 1 and world_size == tp_size:
return self.forward_tp(hidden_states)
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
y_ = None
# only for generate phase
# if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and False:
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
def share_experts_forward():
if self.config.n_shared_experts is not None:
return self.shared_experts(identity).squeeze(0)
cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1)
if get_use_npu_graph():
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
event = torch.npu.Event()
event.record(npu_graph_runner.main_stream)
with torch.npu.stream(npu_graph_runner.update_stream):
event.wait(npu_graph_runner.update_stream)
y_ = share_experts_forward() if share_experts_forward is not None else None
event.record(npu_graph_runner.update_stream)
org_type = hidden_states.dtype
input_tensor = hidden_states.to(torch.bfloat16)
topk_weight = topk_weight.contiguous().to(torch.float32)
self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight)
self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True)
self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True)
self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True)
npu_graph_runner.launch_callback(
self.cpu_moe_kexperts,
self.moe_kexperts_param,
1, npu_graph_runner.stream)
output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True)
y = output_npu_graph.to(org_type)
event.wait(npu_graph_runner.main_stream)
else:
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)
y_ = share_experts_forward() if share_experts_forward is not None else None
y = y.view(*orig_shape).to(device=hidden_states.device)
if self.config.n_shared_experts is not None:
y += y_
return y
@torch.no_grad()
def cpu_moe_kexperts(self, moe_kexperts_param) -> torch.Tensor:
x, topk_ids, topk_weight = moe_kexperts_param
self.moe_kexperts(x, topk_ids, topk_weight)
@torch.no_grad()
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs

View file

@ -0,0 +1,43 @@
import torch
import torch_npu
import torch.nn as nn
import torch.nn.functional as F
from ktransformers.operators.gate import KMoEGate
from ktransformers.util import utils
class KDeepseekV3GateA2(KMoEGate):
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
device = utils.CUR_DEVICE
if device is None:
device = self.device
if w is None:
w = self.load_weights(device=device)
if isinstance(w, dict):
self.weight_type = w["weight_type"]
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device).to(torch.float32))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device).to(torch.float32))
def forward(self, hidden_states) -> torch.Tensor:
h = hidden_states.shape[-1]
# compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states.type(torch.float32), self.weight, None)
topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k(
logits,
k=self.top_k,
bias=self.e_score_correction_bias,
k_group=self.topk_group,
group_count=self.n_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=self.routed_scaling_factor,
eps=float(1e-20))
return topk_idx.type(torch.int64), topk_weight

View file

@ -0,0 +1,38 @@
import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util import utils
from ktransformers.util.custom_gguf import GGUFLoader
class KDeepseekV3RMSNormW8A8(BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "npu",
generate_device: str = "npu",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.weight = nn.Parameter(torch.ones(self.orig_module.hidden_size))
self.bias = nn.Parameter(torch.ones(self.orig_module.hidden_size))
self.variance_epsilon = self.orig_module.variance_epsilon
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
out = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + self.bias
return out.to(input_dtype)
def load(self):
self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(utils.CUR_DEVICE)
self.bias = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".bias").to(utils.CUR_DEVICE)
def unload(self):
if self.weight is not None:
self.weight = None
if self.bias is not None:
self.bias = None

View file

@ -0,0 +1,298 @@
from abc import abstractmethod
import torch
import torch_npu
import torch.distributed as dist
from torch import nn
from transformers import PretrainedConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.linear import KLinearBase, LINEAR_MAP
from ktransformers.util.ascend.ascend_utils import (
get_safetensors_cut_weight,
get_tensor_parallel_size,
get_tensor_parallel_group
)
from ktransformers.util import utils
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
class KLinearW8A8(KLinearBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
def load_weight(self, override_key: str | None = None, device: str | None = None):
if override_key is not None:
keys = override_key
else:
keys = [self.key]
fake_tensor = torch.tensor([1])
for key in keys:
if device is None:
device = utils.CUR_DEVICE
if key + ".weight" in self.gguf_loader.safetensor_loader.tensor_file_map:
if key + ".deq_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
deq_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.deq_scale")
quant_bias = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.quant_bias")
input_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_scale")
input_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_offset")
tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset)
return tensors
elif key + ".weight_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
if key.endswith("ffn_gate_shexp"):
parts = key.split(".")
layer = parts[1]
gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight")
gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t()
up_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight")
up_weight = get_safetensors_cut_weight(self.key, up_weight).t()
gate_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_scale")
gate_scale = get_safetensors_cut_weight(self.key, gate_scale)
up_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_scale")
up_scale = get_safetensors_cut_weight(self.key, up_scale)
gate_up_weight = torch.cat((gate_weight, up_weight), 1)
gate_up_scale = torch.cat((gate_scale, up_scale), 0)
gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_offset")
up_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_offset")
gate_up_offset = torch.cat((gate_offset, up_offset), 0)
tensors = (gate_up_weight, gate_up_scale, gate_up_offset)
elif key.endswith("ffn_up_shexp"):
return fake_tensor
else:
qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_scale")
weight_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_offset")
tensors = (qweight, weight_scale, weight_offset)
return tensors
else:
weight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
weight = get_safetensors_cut_weight(self.key, weight)
return weight
else:
raise FileNotFoundError(f"Weight file not found for key {key}")
@abstractmethod
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = "cuda"):
pass
@abstractmethod
def unload(self):
pass
class KLinearTorchW8A8A2(KLinearW8A8):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.weight = None
self.input_scale = None
self.input_offset = None
self.quant_bias = None
self.deq_scale = None
self.weight_scale = None
self.weight_offset = None
def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor:
tp = get_tensor_parallel_size()
if tp == 1:
out = torch.zeros((x.shape[0], x.shape[1], self.weight.shape[-1]), dtype=torch.float16, device=x.device)
torch_npu._npu_matmul_pp(x, self.weight, out)
else:
tp_size = get_tensor_parallel_size()
tp_group = get_tensor_parallel_group()
batch_size = x.shape[0]
seq_length = x.shape[1]
lm_sep_size = tp_size
lm_head_group = tp_group
gathered_list = [torch.empty_like(x) for _ in range(lm_sep_size)]
dist.all_gather(gathered_list, x, group=lm_head_group)
input_full = torch.stack(gathered_list, dim=0)
input_full = input_full.squeeze(dim=1)
torch_npu.npu_format_cast_(input_full, 2)
local_logits = torch.zeros((input_full.shape[0], input_full.shape[1], self.weight.shape[-1]),
dtype=torch.float16, device=input_full.device)
torch_npu._npu_matmul_pp(input_full, self.weight, local_logits)
local_logits_transpose = local_logits.transpose(2, 1).reshape(-1, batch_size * seq_length)
del local_logits
output_tensor = torch.empty_like(local_logits_transpose)
dist.all_to_all_single(output_tensor, local_logits_transpose, group=lm_head_group)
del local_logits_transpose
output_tensor = output_tensor.transpose(1, 0)
out = output_tensor.view(batch_size, seq_length, -1)
del output_tensor
return out
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
if device is None:
device = self.device
device = utils.CUR_DEVICE
if w is None:
w = self.load_weight()
if isinstance(w, nn.Parameter):
try:
self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T.contiguous()
except:
self.weight = w.to(dtype=self.dtype).T.contiguous()
self.weight = self.weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
elif isinstance(w, tuple):
w_list = list(w)
if len(w_list) == 3:
self.weight = w_list[0]
self.weight_scale = w_list[1].view(-1)
self.weight_offset = w_list[2]
self.weight = self.weight.to(utils.CUR_DEVICE)
self.weight_scale = self.weight_scale.to(utils.CUR_DEVICE)
if self.key.endswith("ffn_gate_shexp") is not True:
self.weight = get_safetensors_cut_weight(self.key, self.weight).t()
weight_scale = get_safetensors_cut_weight(self.key, self.weight_scale)
self.weight_scale = weight_scale.clone()
del weight_scale
self.weight_offset = self.weight_offset.to(utils.CUR_DEVICE)
else:
for i in range(len(w_list)):
w_list[i] = get_safetensors_cut_weight(self.key, w_list[i])
w_list[i] = w_list[i].to(utils.CUR_DEVICE)
self.weight = w_list[0]
self.deq_scale = w_list[1]
self.quant_bias = w_list[2]
if "attn_output" in self.key or "ffn_down" in self.key:
if torch.distributed.get_rank(get_tensor_parallel_group()) != 0:
self.quant_bias = torch.zeros_like(self.quant_bias, dtype=self.quant_bias.dtype, device=self.quant_bias.device)
self.input_scale = w_list[3]
self.input_offset = w_list[4]
elif isinstance(w, torch.Tensor):
self.weight = w.T.contiguous()
self.weight.to(device)
if "kv_b" not in self.key:
self.weight = self.weight.to(device)
torch_npu.npu_format_cast_(self.weight, 29)
else:
raise ValueError(f"Invalid weight type {self.key=} {type(w)=}")
def unload(self):
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
self.input_scale = None
self.input_offset = None
self.quant_bias = None
self.deq_scale = None
self.weight_scale = None
self.weight_offset = None
LINEAR_MAP["KLinearTorchW8A8A2"] = KLinearTorchW8A8A2
class KTransformersLinearW8A8A2(BaseInjectedModule, KLinearW8A8):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
generate_device: str = "cuda",
generate_op: str | None = "KLinearMarlin",
prefill_device: str = "cuda",
prefill_op: str | None = "KLinearTorch",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KLinearW8A8.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
# build all the linear operators
if prefill_op is not None:
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
else:
self.prefill_linear = None
if generate_op is not None:
assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = None
self.mode = InferenceState.UNLOAD
def forward(self, x, bsz_tensor=None):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
y = self.prefill_linear.forward(x, bsz_tensor)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
y = self.generate_linear.forward(x, bsz_tensor)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
if not mode:
mode = InferenceState.GENERATE
# load to device
if mode == InferenceState.PREFILL:
self.generate_linear.unload()
self.prefill_linear.load(w=w)
self.device = self.prefill_linear.device
self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
self.input_scale = self.prefill_linear.input_scale
self.input_offset = self.prefill_linear.input_offset
self.quant_bias = self.prefill_linear.quant_bias
self.deq_scale = self.prefill_linear.deq_scale
self.weight_scale = self.prefill_linear.weight_scale
self.weight_offset = self.prefill_linear.weight_offset
elif mode == InferenceState.GENERATE:
self.prefill_linear.unload()
self.generate_linear.load(w=w)
self.device = self.generate_linear.device
self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
self.input_scale = self.generate_linear.input_scale
self.input_offset = self.generate_linear.input_offset
self.quant_bias = self.generate_linear.quant_bias
self.deq_scale = self.generate_linear.deq_scale
self.weight_scale = self.generate_linear.weight_scale
self.weight_offset = self.generate_linear.weight_offset
elif mode == InferenceState.UNLOAD:
self.prefill_linear.unload()
self.generate_linear.unload()
self.device = "cpu"
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
self.mode = mode
def unload(self):
if self.prefill_linear is not None:
self.prefill_linear.unload()
if self.generate_linear is not None:
self.generate_linear.unload()
self.device = self.generate_linear.device
def set_inference_mode(self, mode: InferenceState):
if not mode:
mode = InferenceState.GENERATE
if mode == InferenceState.GENERATE:
self.load(mode=InferenceState.GENERATE)
elif mode == InferenceState.PREFILL:
self.load(mode=InferenceState.PREFILL)
elif mode == InferenceState.UNLOAD:
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

View file

@ -0,0 +1,72 @@
import torch
import torch_npu
from ktransformers.util.ascend.ascend_utils import allreduce_wrapper
from ktransformers.util.utils import CUR_DEVICE
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP):
@allreduce_wrapper
def forward(self, x, is_prefill=None, use_cuda_graph=False):
original_dtype = x.dtype
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
dynamic_scale = dynamic_scale.view(-1)
gate_x = torch_npu.npu_quant_matmul(
quant_out,
self.orig_module.gate_proj.weight,
self.orig_module.gate_proj.weight_scale,
pertoken_scale=dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
up_x = torch_npu.npu_quant_matmul(
quant_out,
self.orig_module.up_proj.weight,
self.orig_module.up_proj.weight_scale,
pertoken_scale=dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
down_x = self.act_fn(gate_x) * up_x
down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
down_dynamic_scale = down_dynamic_scale.view(-1)
down_proj = torch_npu.npu_quant_matmul(
down_quant_out,
self.orig_module.down_proj.weight,
self.orig_module.down_proj.weight_scale,
pertoken_scale=down_dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
return down_proj
class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP):
@allreduce_wrapper
def forward(self, x, is_prefill=None, use_cuda_graph=False):
original_dtype = x.dtype
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
dynamic_scale = dynamic_scale.view(-1)
gate_up_x = torch_npu.npu_quant_matmul(
quant_out,
self.orig_module.gate_proj.weight,
self.orig_module.gate_proj.weight_scale,
pertoken_scale=dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
down_x = torch_npu.npu_swiglu(gate_up_x, -1)
down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
down_dynamic_scale = down_dynamic_scale.view(-1)
down_proj = torch_npu.npu_quant_matmul(
down_quant_out,
self.orig_module.down_proj.weight,
self.orig_module.down_proj.weight_scale,
pertoken_scale=down_dynamic_scale,
bias=None,
output_dtype=original_dtype,
)
return down_proj

View file

@ -0,0 +1,114 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorchW8A8A2"
prefill_op: "KLinearTorchW8A8A2"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types
kwargs:
generate_device: "npu"
prefill_device: "npu"
generate_op: "KLinearTorchW8A8A2"
prefill_op: "KLinearTorchW8A8A2"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8 # mlp module with custom forward function
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
name: "^model\\.layers\\.([0-2])\\.mlp$"
class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP"
replace:
class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1"
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
name: "^model\\.layers\\..*\\.mlp\\.shared_experts$"
class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP"
replace:
class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2"
kwargs:
generate_device: "npu"
prefill_device: "npu"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2
kwargs:
generate_device: "npu:0"
prefill_device: "npu:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8
kwargs:
prefill_device: "npu"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPUW8A8"
out_device: "npu"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
class: ktransformers.operators.experts.KExpertsCPU
replace:
class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2 # optimized MLA implementation
kwargs:
generate_device: "npu"
prefill_device: "npu"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model..*norm"
replace:
class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8
kwargs:
generate_device: "npu"
prefill_device: "npu"

View file

@ -41,6 +41,13 @@ except ImportError:
MUSA_HOME=None
KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
try:
import torch_npu
KTRANSFORMERS_BUILD_NPU = torch_npu.npu.is_available()
except (ModuleNotFoundError, ImportError):
KTRANSFORMERS_BUILD_NPU = False
# Detect the DEV_BACKEND environment variable
dev_backend = os.environ.get("DEV_BACKEND", "").lower()
if dev_backend == "xpu":
@ -237,6 +244,8 @@ class VersionInfo:
backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
elif torch.xpu.is_available():
backend_version = f"xpu"
elif KTRANSFORMERS_BUILD_NPU:
backend_version = f"npu{torch_npu.__version__}"
else:
raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
@ -509,6 +518,8 @@ class CMakeBuild(BuildExtension):
cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
elif KTRANSFORMERS_BUILD_XPU:
cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
elif KTRANSFORMERS_BUILD_NPU:
cmake_args += ["-DKTRANSFORMERS_USE_NPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
else:
raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.")
@ -636,10 +647,12 @@ elif MUSA_HOME is not None:
)
elif torch.xpu.is_available(): #XPUExtension is not available now.
ops_module = None
elif KTRANSFORMERS_BUILD_NPU:
pass
else:
raise ValueError("Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.")
if not torch.xpu.is_available():
if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
ops_module,
@ -660,10 +673,20 @@ if not torch.xpu.is_available():
ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
else:
elif torch.xpu.is_available():
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
]
elif KTRANSFORMERS_BUILD_NPU:
ext_modules = [
CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
]
if with_balance:
print("using balance_serve")
ext_modules.append(
CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
)
setup(
name=VersionInfo.PACKAGE_NAME,