Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-08 13:39:48 +00:00)

Commit dd0e41b3b8 (parent 1677e90092): support npu

14 changed files with 1453 additions and 5 deletions
.gitignore (vendored, 5 changed lines)
@@ -26,4 +26,7 @@ ktransformers/tests/chat_txt.txt
 mmlu_result*
 ktransformers/ktransformers_ext/cuda_musa/
 test_prompt.txt
 csrc/demo
+CMakeFiles
+kvc2/
+sched/
@@ -42,6 +42,7 @@ option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA"
 option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
 option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
 option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
+option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)

 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures
@@ -306,6 +307,17 @@ elseif (UNIX)
     endif()
 elseif (KTRANSFORMERS_USE_XPU)
     add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
+elseif (KTRANSFORMERS_USE_NPU)
+    include(CheckLanguage)
+    check_language(CUDA)
+    if(CMAKE_CUDA_COMPILER)
+        message(STATUS "CUDA detected")
+        find_package(CUDAToolkit REQUIRED)
+        include_directories(${CUDAToolkit_INCLUDE_DIRS})
+    endif()
+    message(STATUS "enabling CUDA")
+    enable_language(CUDA)
+    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
 else()
     find_package(CUDA REQUIRED)
     include_directories("${CUDA_INCLUDE_DIRS}")
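For orientation, a minimal sketch of how the new switch could be exercised when configuring the C++ extension directly with CMake; the build directory name and the practice of invoking CMake by hand (rather than through the project's install scripts) are assumptions for illustration only.

```bash
# Hypothetical manual configure/build enabling the NPU path added above.
cmake -B build -DKTRANSFORMERS_USE_NPU=ON .
cmake --build build -j
```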
@@ -9,7 +9,7 @@
 **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
+#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU)
 #include "device_launch_parameters.h"
 #endif
 #include "llamafile/flags.h"
doc/zh/DeepSeekR1_tutorial_zh_for_Ascend_NPU.md (new file, 165 lines)

@@ -0,0 +1,165 @@
# Deployment

## Bare-metal installation

Deploying the full DeepseekV3 model requires enough physical memory on the machine to hold the weights of all routed experts, roughly 400 GB.

Currently supported NPU model: **800I A2**.

Complete the hardware installation with the support of technical staff.

## OS installation

Based on the [Ascend compatibility checker](https://www.hiascend.com/hardware/compatibility), use Ubuntu 22.04 for aarch64 with kernel 5.15.0-25-generic, and disable automatic system updates. The system image can be downloaded from [ubuntu-old-releases](https://mirrors.aliyun.com/oldubuntu-releases/releases/22.04).

## HDK installation

Install [Ascend HDK 25.0.RC1](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=32&cann=8.1.RC1.beta1&driver=Ascend+HDK+25.0.RC1), following the [Ascend community HDK installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1beta1/softwareinst/instg/instg_0005.html?Mode=PmIns&InstallType=local&OS=Ubuntu&Software=cannToolKit).


## Conda setup

It is recommended to set up the development environment following the latest [Installation Guide - kTransformers](https://kvcache-ai.github.io/ktransformers/en/install.html). Note that Python 3.11 is required (other versions have not been verified), and the cpufeature package is not needed on ARM platforms.

Install conda/miniconda:

```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
bash ~/Miniconda3-latest-Linux-aarch64.sh
```

Set up the Python environment:

```bash
conda create -n py311 python=3.11
conda activate py311
conda install -c conda-forge libstdcxx-ng  # provides `GLIBCXX-3.4.32`
pip3 install numpy==1.26.4  # required by torch/torch_npu
pip3 install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
pip3 install packaging ninja transformers==4.43.2 fire protobuf attrs decorator cloudpickle ml-dtypes scipy tornado absl-py psutil
#pip3 install cpufeature  # only for x86
```

## CANN installation

Install [CANN 8.1.RC1.beta1](https://www.hiascend.com/developer/download/community/result?from=firmware&product=4&model=32&cann=8.1.RC1.beta1), following the [Ascend community CANN installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1beta1/softwareinst/instg/instg_0007.html?Mode=PmIns&InstallType=local&OS=Ubuntu&Software=cannToolKit).

ToolKit, Kernel, and NNAL all need to be installed.

## torch_npu (op-plugin) installation

Get the latest repository code from [op-plugin Gitee](https://gitee.com/ascend/op-plugin).

Because new operators are involved, the torch_npu package published on public PyPI cannot be used directly yet; an adapted op-plugin is needed to build the required torch_npu package, which is not yet available from the public network. # TODO

With working network access to GitHub and Gitee, run the following to build and install torch_npu:

```bash
cd op-plugin
source /usr/local/Ascend/ascend-toolkit/set_env.sh  # adjust to the actual CANN installation path
source /usr/local/Ascend/nnal/atb/set_env.sh        # adjust to the actual NNAL installation path
bash install.sh --python=3.11 --pytorch=v2.3.1-7.0.0  # the generated torch_npu wheel ends up in {op-plugin project dir}/dist
```

## Weight preparation

To meet both performance and accuracy requirements, two sets of weights currently need to be prepared and merged with the provided merging script; only the merged weights are used in the end.

Q4 weights: [DeepSeek-R1-Q4_K_M](https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/files)

W8A8 weights: [DeepSeek-R1-W8A8](https://modelers.cn/models/MindSpore-Lab/DeepSeek-R1-W8A8/tree/main)

Use [merge_safetensor_gguf.py](../../merge_tensors/merge_safetensor_gguf.py) to merge the Q4 and W8A8 weights:

```bash
python merge_safetensor_gguf.py --safetensor_path /mnt/weights/DeepSeek-R1-Q4_K_M --gguf_path /mnt/weights/DeepSeek-R1-W8A8 --output_path /mnt/weights/DeepSeek-R1-q4km-w8a8 --safetensors_format w8a8
```

## Graph-sinking deployment

The binaries required for graph sinking ship with the repository: [ktransformers/util/npu_graph_so](../../ktransformers/util/npu_graph_so).

Deploying graph sinking requires replacing the corresponding files. Taking the ARM platform as an example:

```bash
mv /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so.bak
cp ktransformers/util/npu_graph_so/arm/libruntime.so /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so
```

To enable graph sinking, add the following environment variables:

```bash
export CAPTURE_PLUGIN_PATH=ktransformers/util/npu_graph_so/arm
export TASK_QUEUE_ENABLE=0  # keep operator dispatch ordered
```


## kTransformers deployment

Deploy the project files to the machine:

- On ARM platforms, comment out `cpufeature` in `./requirements-local_chat.txt`.
- On ARM platforms, apply the following replacements:
```bash
cp ./for_arm/CMakeLists.txt ./csrc/ktransformers_ext/CMakeLists.txt
cp ./for_arm/iqk_mul_mat.inc ./third_party/llamafile/iqk_mul_mat.inc
cp ./for_arm/sgemm.cpp ./third_party/llamafile/sgemm.cpp
cp ./for_arm/tinyblas_cpu_sgemm.inc ./third_party/llamafile/tinyblas_cpu_sgemm.inc
cp ./for_arm/setup.py ./setup.py
```
- Run `source /usr/local/Ascend/ascend-toolkit/set_env.sh` (adjust to the actual CANN-TOOLKIT installation path).
- Run `apt install cmake libhwloc-dev pkg-config` to install dependencies.
- Run `bash install.sh` and wait for the installation to finish.

An example launch script for local_chat is given below (it uses relative paths, so place it in the project root):

```bash
#!/bin/bash
export CAPTURE_PLUGIN_PATH=ktransformers/util/npu_graph_so/arm
export USE_MERGE=0
export INF_NAN_MODE_FORCE_DISABLE=1
export TASK_QUEUE_ENABLE=0
#export PROF_DECODE=1
#export PROF_PREFILL=1

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

torchrun --nproc_per_node 1 \
         --master_port 25565 \
         -m ktransformers.local_chat \
         --cpu_infer 20 \
         --model_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
         --gguf_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
         --optimize_config_path ./ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml \
         --use_cuda_graph True \
         --max_new_tokens 500 \
         --tp 1
```

Parameter notes:

- `--model_path`: native kTransformers argument, str; here it points to the merged model files
- `--gguf_path`: native kTransformers argument, str; here it also points to the merged model files
- `--cpu_infer`: native kTransformers argument, int; controls the actual number of CPU worker threads; optional
- `--optimize_config_path`: native kTransformers argument, str; selects the model optimization rule file to use (mind the relative path); **required** here
- `--use_cuda_graph`: native kTransformers argument, bool; True enables graph sinking, False disables it; optional
- `--max_new_tokens`: native kTransformers argument, int; generation stops as soon as this many output tokens have been produced; optional
- `--tp`: new argument, int; enables tensor model parallelism; local_chat currently only supports tp equal to the world size (multi-dp local_chat is not supported); optional (a two-card sketch follows this list)
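As referenced in the `--tp` note above, a multi-card variant of the launch script is sketched here. It assumes two NPUs are visible; per the constraint that tp must equal the world size, `--nproc_per_node` and `--tp` are set to the same value, with all other arguments left unchanged from the single-card script.

```bash
# Hypothetical two-NPU launch: world size and --tp must match for local_chat.
torchrun --nproc_per_node 2 \
         --master_port 25565 \
         -m ktransformers.local_chat \
         --cpu_infer 20 \
         --model_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
         --gguf_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
         --optimize_config_path ./ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml \
         --use_cuda_graph True \
         --max_new_tokens 500 \
         --tp 2
```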

# Other issues

## Other dependency problems that may occur

ImportError: libhccl.so: cannot open shared object file: No such file or directory

```bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh  # adjust to the actual CANN installation path
```

ImportError: libascend_hal.so: cannot open shared object file: No such file or directory

```bash
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH  # adjust to the actual Driver installation path
```
install_for_npu.sh (new file, 21 lines)
@@ -0,0 +1,21 @@
#!/bin/bash
set -e
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf csrc/build
rm -rf csrc/ktransformers_ext/build
rm -rf csrc/ktransformers_ext/cuda/build
rm -rf csrc/ktransformers_ext/cuda/dist
rm -rf csrc/ktransformers_ext/cuda/*.egg-info
rm -rf ~/.ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
pip install -r ktransformers/server/requirements.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
pip install third_party/custom_flashinfer/

echo "Installation completed successfully"
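A minimal way to drive the script above, assuming it is run from the repository root inside the prepared conda environment:

```bash
# One-shot clean build and install for the NPU setup; removes old build artifacts first.
bash install_for_npu.sh
```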
@@ -67,7 +67,7 @@ attn:
   page_size: 256
   chunk_size: 256
 kvc2:
-  gpu_only: false
+  gpu_only: true
   utilization_percentage: 1.0
   cpu_memory_size_GB: 500
   disk_path: /mnt/data/kvc
ktransformers/operators/ascend/ascend_attention.py (new file, 467 lines)

@@ -0,0 +1,467 @@
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, allreduce_wrapper
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.util.utils import get_compute_capability, get_use_npu_graph, CUR_DEVICE
|
||||
from ktransformers.util.vendors import device_manager, GPUVendor
|
||||
from ktransformers.util import utils
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_fusion(q, k, cos, sin, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
b, h, s, d = q.shape
|
||||
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
b, h, s, d = k.shape
|
||||
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class MatMulOps(object):
|
||||
def execute(self, x_input):
|
||||
"""
|
||||
:param x, weight, quant_bia, deq_scale
|
||||
:return:
|
||||
"""
|
||||
quant_out = x_input[0]
|
||||
weight = x_input[1]
|
||||
quant_bia = x_input[2]
|
||||
deq_scale = x_input[3]
|
||||
return [torch_npu.npu_quant_matmul(quant_out, weight.T, deq_scale, bias=quant_bia, output_dtype=torch.float16)]
|
||||
|
||||
|
||||
class MatMulOpsAtb(object):
|
||||
def execute(self, x_input):
|
||||
"""
|
||||
:param x, weight, quant_bia, deq_scale
|
||||
:return:
|
||||
"""
|
||||
x = x_input[0]
|
||||
weight = x_input[1]
|
||||
quant_bia = x_input[2]
|
||||
deq_scale = x_input[3]
|
||||
target_shape = (x.shape[0], x.shape[-2], weight.shape[-2])
|
||||
target_tensor = torch.zeros(target_shape, dtype=torch.float16, device=x.device)
|
||||
torch_npu._npu_matmul_dequant(x, weight, quant_bia, deq_scale, target_tensor)
|
||||
return [target_tensor]
|
||||
|
||||
|
||||
class DynamicQuantOps(object):
|
||||
"""
|
||||
:param x, scale, offset
|
||||
:return
|
||||
"""
|
||||
|
||||
def execute(self, x_input):
|
||||
out = torch.empty_like(x_input[0], dtype=torch.int8)
|
||||
torch_npu._npu_quantize_per_tensor(x_input[0], x_input[1], x_input[2], out)
|
||||
return [out]
|
||||
|
||||
|
||||
class KDeepseekV2AttentionW8A8A2(BaseInjectedModule, DeepseekV2Attention):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
absorb_for_prefill: bool = False,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device,
|
||||
**kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
self.mla_wrapper = None
|
||||
tp = get_tensor_parallel_size()
|
||||
if tp > 1:
|
||||
self.num_heads //= tp
|
||||
self.absorb_for_prefill = absorb_for_prefill
|
||||
|
||||
self.use_merge = os.getenv("USE_MERGE", "0")
|
||||
if self.use_merge == "0":
|
||||
print("--Use ATB FA-MLA and PA-MLA OP !--")
|
||||
self.elewise_quant = DynamicQuantOps()
|
||||
self.matmulDequant_operation = MatMulOpsAtb()
|
||||
self.matmulDequant_operation_aclnn = MatMulOps()
|
||||
elif self.use_merge == "1":
|
||||
print("--Use torch npu FA OP !--")
|
||||
else:
|
||||
print("--Use default op! --")
|
||||
|
||||
@allreduce_wrapper
|
||||
def forward_chunck(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0]
|
||||
q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight,
|
||||
self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0]
|
||||
q_a_proj_out = self.q_a_layernorm(q_a_proj_out)
|
||||
q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0]
|
||||
q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight,
|
||||
self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0]
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0]
|
||||
compressed_kv = self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight,
|
||||
self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0]
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = k_pe.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||
q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin)
|
||||
|
||||
# update KV
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
k_pe = k_pe.transpose(1, 2) # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
|
||||
compressed_kv = compressed_kv.unsqueeze(2) # compressed_kv [bsz, q_len, self.kv_lora_rank]
|
||||
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:, :, :attention_mask.size(-1), :]
|
||||
compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:, :, :attention_mask.size(-1), :]
|
||||
|
||||
weight_uk = self.q_absorb
|
||||
weight_uv = self.out_absorb
|
||||
|
||||
# ATB-MLA-FA+PA
|
||||
if self.use_merge == "0" and q_len != 1:
|
||||
current_sqenLen = past_key_value.get_seq_length(self.layer_idx)
|
||||
attention_mask = attention_mask[0, :, :, :current_sqenLen].squeeze(0).squeeze(0)
|
||||
|
||||
compressed_kv = compressed_kv[:, :, :current_sqenLen, :] # all KV until current chunk
|
||||
k_pe = k_pe[:, :, :current_sqenLen, :]
|
||||
|
||||
k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1)
|
||||
k_up = torch.matmul(compressed_kv, weight_uk.mT)
|
||||
v_up = torch.matmul(compressed_kv, weight_uv)
|
||||
|
||||
qTensor = torch.cat((q_nope, q_pe), dim=-1).transpose(1, 2).contiguous().view(
|
||||
bsz * q_len, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
kTensor = torch.cat((k_up, k_pe_repeated), dim=-1).transpose(1, 2).contiguous().view(
|
||||
bsz * current_sqenLen, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
vTensor = v_up.transpose(1, 2).contiguous().view(bsz * current_sqenLen, self.num_heads, self.v_head_dim)
|
||||
|
||||
seq_len_data = [q_len] * bsz
|
||||
seq_len = torch.tensor(seq_len_data, dtype=torch.int32, device=vTensor.device)
|
||||
seq_len_host = torch.tensor(seq_len_data, dtype=torch.int32)
|
||||
|
||||
attn_output = torch.ones((qTensor.shape[0], qTensor.shape[1], vTensor.shape[-1]),
|
||||
dtype=qTensor.dtype, device=vTensor.device)
|
||||
torch_npu._npu_flash_attention_mla(qTensor, kTensor, vTensor, attention_mask, seq_len, seq_len_host,
|
||||
self.softmax_scale, self.num_heads, self.num_heads, attn_output)
|
||||
|
||||
if attn_output.size() != (bsz * q_len, self.num_heads, self.v_head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.v_head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]
|
||||
attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,
|
||||
self.o_proj.quant_bias, self.o_proj.deq_scale])[0]
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
elif self.use_merge == "0" and q_len == 1:
|
||||
return self.forward_paged(q_pe=q_pe,
|
||||
q_nope=q_nope,
|
||||
compressed_kv_with_k_pe=compressed_kv_with_k_pe,
|
||||
past_key_value=past_key_value,
|
||||
cache_position=cache_position)
|
||||
|
||||
if self.use_merge == "1":
|
||||
k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1)
|
||||
k_up = torch.matmul(compressed_kv, weight_uk.mT)
|
||||
v_up = torch.matmul(compressed_kv, weight_uv)
|
||||
qTensor = torch.cat((q_nope, q_pe), dim=-1)
|
||||
kTensor = torch.cat((k_up, k_pe_repeated), dim=-1)
|
||||
vTensor = torch.cat((v_up, k_pe_repeated), dim=-1)
|
||||
|
||||
if q_len != 1:
|
||||
attn_output = torch_npu.npu_prompt_flash_attention(
|
||||
qTensor, kTensor, vTensor,
|
||||
num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD")
|
||||
else:
|
||||
attn_output = torch_npu.npu_incre_flash_attention(
|
||||
qTensor, kTensor, vTensor,
|
||||
num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD")
|
||||
attn_output = attn_output[:, :, :, :self.v_head_dim]
|
||||
else:
|
||||
q_nope = torch.matmul(q_nope, self.q_absorb)
|
||||
|
||||
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
|
||||
|
||||
compressed_kv = compressed_kv.squeeze(1)
|
||||
"""
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
assert attention_mask is not None
|
||||
"""
|
||||
if attention_mask is not None:
|
||||
"""
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
"""
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(q_pe.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.attention_dropout, training=self.training
|
||||
)
|
||||
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
|
||||
|
||||
attn_output = torch.matmul(attn_output, self.out_absorb)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward_paged(
|
||||
self,
|
||||
q_pe: torch.Tensor,
|
||||
q_nope: torch.Tensor,
|
||||
compressed_kv_with_k_pe: torch.Tensor,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, _, q_len, _ = q_nope.size()
|
||||
q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb) # torch.Size([1, 128, 1, 512])
|
||||
compressed_kv = compressed_kv_with_k_pe.permute(0, 2, 1, 3)
|
||||
kvCache = compressed_kv[:, :, :, :self.kv_lora_rank].contiguous()
|
||||
kRopeCache = compressed_kv[:, :, :, self.kv_lora_rank:].contiguous()
|
||||
|
||||
if get_use_npu_graph():
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner
|
||||
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
|
||||
stream = npu_graph_runner.main_stream
|
||||
if npu_graph_runner.past_key_value is None:
|
||||
npu_graph_runner.past_key_value = past_key_value
|
||||
if npu_graph_runner.workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
q_nope,
|
||||
kvCache,
|
||||
kvCache,
|
||||
query_rope=q_pe,
|
||||
key_rope=kRopeCache,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=None,
|
||||
scale=self.softmax_scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=past_key_value.page_table_list[self.layer_idx],
|
||||
block_size=past_key_value.page_size,
|
||||
actual_seq_lengths_kv=past_key_value.position
|
||||
)
|
||||
npu_graph_runner.workspace = workspace
|
||||
attn_output = torch.zeros_like(q_nope, dtype=torch.float16, device=CUR_DEVICE)
|
||||
softmax_lse = torch.empty(1, dtype=torch.float16, device=CUR_DEVICE)
|
||||
npu_graph_runner.ifa_param.append((q_nope, kvCache, q_pe, kRopeCache, self.num_heads,
|
||||
self.softmax_scale, self.layer_idx, attn_output, softmax_lse))
|
||||
eventTmp = torch.npu.ExternalEvent()
|
||||
npu_graph_runner.event.append(eventTmp)
|
||||
eventTmp.wait(stream)
|
||||
eventTmp.reset(stream)
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
kvCache,
|
||||
kvCache,
|
||||
workspace=npu_graph_runner.workspace,
|
||||
query_rope=q_pe,
|
||||
key_rope=kRopeCache,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=None,
|
||||
scale=self.softmax_scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=past_key_value.page_table_list[self.layer_idx],
|
||||
block_size=past_key_value.page_size,
|
||||
actual_seq_lengths_kv=past_key_value.position,
|
||||
out=[attn_output, softmax_lse])
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
npu_graph_runner.handle.append(handle)
|
||||
else:
|
||||
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
kvCache,
|
||||
kvCache,
|
||||
query_rope=q_pe,
|
||||
key_rope=kRopeCache,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=1,
|
||||
input_layout="BNSD",
|
||||
atten_mask=None,
|
||||
scale=self.softmax_scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=past_key_value.page_table_list[self.layer_idx],
|
||||
block_size=past_key_value.page_size,
|
||||
actual_seq_lengths_kv=past_key_value.position,
|
||||
)
|
||||
|
||||
attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb)
|
||||
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]
|
||||
attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,
|
||||
self.o_proj.quant_bias, self.o_proj.deq_scale])[0]
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward_windows(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if q_len <= self.chunck_size:
|
||||
return self.forward_chunck(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
assert output_attentions is False, "output_attentions is not supported when using chunked attention"
|
||||
attn_output = None
|
||||
cur_idx = 0
|
||||
while cur_idx < q_len:
|
||||
if attention_mask is not None:
|
||||
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
|
||||
else:
|
||||
# generate chunk_mask automatically.
|
||||
self.attn_mask = \
|
||||
torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
|
||||
if self.attn_mask is None \
|
||||
else self.attn_mask
|
||||
self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \
|
||||
-65504.0 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \
|
||||
[:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]
|
||||
self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -65504.0
|
||||
self.attn_mask[:, :, :, :cur_idx] = 0
|
||||
chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))
|
||||
|
||||
cur_output, _, _ = self.forward_chunck(
|
||||
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
|
||||
chunk_mask,
|
||||
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
|
||||
**kwargs
|
||||
)
|
||||
cur_idx += self.chunck_size
|
||||
if attn_output is None:
|
||||
attn_output = cur_output
|
||||
else:
|
||||
attn_output = torch.cat((attn_output, cur_output), dim=-2)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
return self.forward_windows(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
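As a reading aid for the file above: the attention kernel path is selected once in `KDeepseekV2AttentionW8A8A2.__init__` from the `USE_MERGE` environment variable. The sketch below only restates the values the code checks; the sample launch script in the tutorial exports `USE_MERGE=0`.

```bash
# Kernel selection read by KDeepseekV2AttentionW8A8A2.__init__:
export USE_MERGE=0    # ATB FA-MLA / PA-MLA ops (value used by the sample launch script)
# export USE_MERGE=1  # torch_npu prompt/incre flash-attention ops
# any other value     # eager softmax fallback
```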
ktransformers/operators/ascend/ascend_experts.py (new file, 192 lines)

@@ -0,0 +1,192 @@
import bisect
|
||||
|
||||
import acl
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, get_tensor_parallel_group
|
||||
from ktransformers.operators.experts import KExpertsCPU, KTransformersExperts, EXPERTS_MAP, KDeepseekV3MoE, cuda_graphs
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.util.utils import CUR_DEVICE, get_use_npu_graph, InferenceState
|
||||
|
||||
|
||||
class KExpertsCPUW8A8(KExpertsCPU):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module = None,
|
||||
device: str = "cpu",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
self.input_tensor_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
|
||||
self.expert_ids_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
|
||||
self.weights_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
|
||||
self.output_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
|
||||
self.bsz_tensor_cpu_graph = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True)
|
||||
|
||||
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
|
||||
if get_use_npu_graph():
|
||||
self.cpu_infer.submit(self.moe.forward(self.expert_ids_cpu_graph.size(0),
|
||||
self.expert_ids_cpu_graph.size(1),
|
||||
self.expert_ids_cpu_graph.data_ptr(),
|
||||
self.weights_cpu_graph.data_ptr(),
|
||||
self.input_tensor_cpu_graph.data_ptr(),
|
||||
self.output_cpu_graph.data_ptr(),
|
||||
self.bsz_tensor_cpu_graph.data_ptr()))
|
||||
self.cpu_infer.sync()
|
||||
else:
|
||||
if bsz_tensor is None:
|
||||
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
|
||||
org_type = input_tensor.dtype
|
||||
input_tensor = input_tensor.contiguous().cpu()
|
||||
input_tensor = input_tensor.to(torch.bfloat16)
|
||||
expert_ids = expert_ids.contiguous().cpu()
|
||||
weights = weights.contiguous().to(torch.float32).cpu()
|
||||
bsz_tensor = bsz_tensor.contiguous().cpu()
|
||||
output = torch.empty_like(input_tensor).contiguous()
|
||||
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
|
||||
self.cpu_infer.sync()
|
||||
return output.to(org_type).to(device=CUR_DEVICE)
|
||||
|
||||
|
||||
EXPERTS_MAP["KExpertsCPUW8A8"] = KExpertsCPUW8A8
|
||||
|
||||
|
||||
class KTransformersExpertsW8A8(KTransformersExperts):
|
||||
def forward(self, input_tensor, expert_ids, weights):
|
||||
if self.mode == InferenceState.GENERATE:
|
||||
assert self.generate_experts is not None, "generate_experts is None"
|
||||
return self.generate_experts.forward(input_tensor, expert_ids, weights)
|
||||
elif self.mode == InferenceState.PREFILL:
|
||||
assert self.prefill_experts is not None, "prefill_experts is None"
|
||||
return self.prefill_experts.forward(input_tensor, expert_ids, weights)
|
||||
else:
|
||||
raise ValueError("load or set_inference_mode before forward")
|
||||
|
||||
|
||||
class KDeepseekV3MoEW8A8(KDeepseekV3MoE):
|
||||
def forward_tp(self, hidden_states):
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
rank = torch.distributed.get_rank()
|
||||
def share_experts_forward():
|
||||
if self.config.n_shared_experts is not None:
|
||||
return self.shared_experts(identity).squeeze(0)
|
||||
if rank == 0:
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1)
|
||||
if get_use_npu_graph():
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner
|
||||
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
|
||||
event = torch.npu.Event()
|
||||
event.record(npu_graph_runner.main_stream)
|
||||
with torch.npu.stream(npu_graph_runner.update_stream):
|
||||
event.wait(npu_graph_runner.update_stream)
|
||||
y_ = share_experts_forward() if share_experts_forward is not None else None
|
||||
event.record(npu_graph_runner.update_stream)
|
||||
org_type = hidden_states.dtype
|
||||
input_tensor = hidden_states.to(torch.bfloat16)
|
||||
topk_weight = topk_weight.contiguous().to(torch.float32)
|
||||
self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight)
|
||||
self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True)
|
||||
self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True)
|
||||
self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True)
|
||||
|
||||
npu_graph_runner.launch_callback(
|
||||
self.cpu_moe_kexperts,
|
||||
self.moe_kexperts_param,
|
||||
1, npu_graph_runner.stream)
|
||||
|
||||
output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True)
|
||||
y = output_npu_graph.to(org_type)
|
||||
event.wait(npu_graph_runner.main_stream)
|
||||
else:
|
||||
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)
|
||||
y_ = share_experts_forward() if share_experts_forward is not None else None
|
||||
y = y.view(*orig_shape).to(device=hidden_states.device)
|
||||
else:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
y = torch.zeros(orig_shape, dtype=torch.float16, device=CUR_DEVICE)
|
||||
y_ = share_experts_forward() if share_experts_forward is not None else None
|
||||
torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group())
|
||||
if self.config.n_shared_experts is not None:
|
||||
y += y_
|
||||
return y
|
||||
|
||||
def forward(self, hidden_states):
|
||||
tp_size = get_tensor_parallel_size()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if tp_size > 1 and world_size == tp_size:
|
||||
return self.forward_tp(hidden_states)
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
y_ = None
|
||||
|
||||
# only for generate phase
|
||||
# if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
|
||||
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and False:
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
|
||||
if self.config.n_shared_experts is not None:
|
||||
y_ = self.shared_experts(identity).squeeze(0)
|
||||
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
|
||||
y += y_
|
||||
y.resize_(*orig_shape)
|
||||
return y
|
||||
|
||||
def share_experts_forward():
|
||||
if self.config.n_shared_experts is not None:
|
||||
return self.shared_experts(identity).squeeze(0)
|
||||
|
||||
cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1)
|
||||
if get_use_npu_graph():
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner
|
||||
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
|
||||
event = torch.npu.Event()
|
||||
event.record(npu_graph_runner.main_stream)
|
||||
with torch.npu.stream(npu_graph_runner.update_stream):
|
||||
event.wait(npu_graph_runner.update_stream)
|
||||
y_ = share_experts_forward() if share_experts_forward is not None else None
|
||||
event.record(npu_graph_runner.update_stream)
|
||||
org_type = hidden_states.dtype
|
||||
input_tensor = hidden_states.to(torch.bfloat16)
|
||||
topk_weight = topk_weight.contiguous().to(torch.float32)
|
||||
self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight)
|
||||
self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True)
|
||||
self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True)
|
||||
self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True)
|
||||
|
||||
npu_graph_runner.launch_callback(
|
||||
self.cpu_moe_kexperts,
|
||||
self.moe_kexperts_param,
|
||||
1, npu_graph_runner.stream)
|
||||
|
||||
output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True)
|
||||
y = output_npu_graph.to(org_type)
|
||||
event.wait(npu_graph_runner.main_stream)
|
||||
else:
|
||||
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)
|
||||
y_ = share_experts_forward() if share_experts_forward is not None else None
|
||||
y = y.view(*orig_shape).to(device=hidden_states.device)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
y += y_
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def cpu_moe_kexperts(self, moe_kexperts_param) -> torch.Tensor:
|
||||
x, topk_ids, topk_weight = moe_kexperts_param
|
||||
self.moe_kexperts(x, topk_ids, topk_weight)
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||
outs = self.experts(x, topk_ids, topk_weight)
|
||||
return outs
|
ktransformers/operators/ascend/ascend_gate.py (new file, 43 lines)

@@ -0,0 +1,43 @@
import torch
import torch_npu
import torch.nn as nn
import torch.nn.functional as F
from ktransformers.operators.gate import KMoEGate
from ktransformers.util import utils


class KDeepseekV3GateA2(KMoEGate):
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        device = utils.CUR_DEVICE
        if device is None:
            device = self.device
        if w is None:
            w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device).to(torch.float32))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device).to(torch.float32))

    def forward(self, hidden_states) -> torch.Tensor:
        h = hidden_states.shape[-1]
        # compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states.type(torch.float32), self.weight, None)
        topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k(
            logits,
            k=self.top_k,
            bias=self.e_score_correction_bias,
            k_group=self.topk_group,
            group_count=self.n_group,
            group_select_mode=1,
            renorm=0,
            norm_type=1,
            routed_scaling_factor=self.routed_scaling_factor,
            eps=float(1e-20))
        return topk_idx.type(torch.int64), topk_weight
ktransformers/operators/ascend/ascend_layernorm.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig

from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util import utils
from ktransformers.util.custom_gguf import GGUFLoader


class KDeepseekV3RMSNormW8A8(BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "npu",
                 generate_device: str = "npu",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.weight = nn.Parameter(torch.ones(self.orig_module.hidden_size))
        self.bias = nn.Parameter(torch.ones(self.orig_module.hidden_size))
        self.variance_epsilon = self.orig_module.variance_epsilon

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        out = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + self.bias
        return out.to(input_dtype)

    def load(self):
        self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(utils.CUR_DEVICE)
        self.bias = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".bias").to(utils.CUR_DEVICE)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.bias is not None:
            self.bias = None
ktransformers/operators/ascend/ascend_linear.py (new file, 298 lines)

@@ -0,0 +1,298 @@
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.operators.linear import KLinearBase, LINEAR_MAP
|
||||
from ktransformers.util.ascend.ascend_utils import (
|
||||
get_safetensors_cut_weight,
|
||||
get_tensor_parallel_size,
|
||||
get_tensor_parallel_group
|
||||
)
|
||||
from ktransformers.util import utils
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.util.utils import InferenceState
|
||||
|
||||
|
||||
class KLinearW8A8(KLinearBase):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module = None,
|
||||
device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
|
||||
def load_weight(self, override_key: str | None = None, device: str | None = None):
|
||||
if override_key is not None:
|
||||
keys = override_key
|
||||
else:
|
||||
keys = [self.key]
|
||||
fake_tensor = torch.tensor([1])
|
||||
for key in keys:
|
||||
if device is None:
|
||||
device = utils.CUR_DEVICE
|
||||
if key + ".weight" in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
if key + ".deq_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
|
||||
deq_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.deq_scale")
|
||||
quant_bias = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.quant_bias")
|
||||
input_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_scale")
|
||||
input_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_offset")
|
||||
tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset)
|
||||
return tensors
|
||||
elif key + ".weight_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
if key.endswith("ffn_gate_shexp"):
|
||||
parts = key.split(".")
|
||||
layer = parts[1]
|
||||
gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight")
|
||||
gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t()
|
||||
up_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight")
|
||||
up_weight = get_safetensors_cut_weight(self.key, up_weight).t()
|
||||
gate_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_scale")
|
||||
gate_scale = get_safetensors_cut_weight(self.key, gate_scale)
|
||||
up_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_scale")
|
||||
up_scale = get_safetensors_cut_weight(self.key, up_scale)
|
||||
gate_up_weight = torch.cat((gate_weight, up_weight), 1)
|
||||
gate_up_scale = torch.cat((gate_scale, up_scale), 0)
|
||||
gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_offset")
|
||||
up_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_offset")
|
||||
gate_up_offset = torch.cat((gate_offset, up_offset), 0)
|
||||
tensors = (gate_up_weight, gate_up_scale, gate_up_offset)
|
||||
elif key.endswith("ffn_up_shexp"):
|
||||
return fake_tensor
|
||||
else:
|
||||
qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
|
||||
weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_scale")
|
||||
weight_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_offset")
|
||||
tensors = (qweight, weight_scale, weight_offset)
|
||||
return tensors
|
||||
else:
|
||||
weight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
|
||||
weight = get_safetensors_cut_weight(self.key, weight)
|
||||
return weight
|
||||
else:
|
||||
raise FileNotFoundError(f"Weight file not found for key {key}")
|
||||
|
||||
@abstractmethod
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = "cuda"):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unload(self):
|
||||
pass
|
||||
|
||||
|
||||
class KLinearTorchW8A8A2(KLinearW8A8):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module = None,
|
||||
device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
self.has_bias = False
|
||||
self.dtype = torch.get_default_dtype()
|
||||
self.weight = None
|
||||
self.input_scale = None
|
||||
self.input_offset = None
|
||||
self.quant_bias = None
|
||||
self.deq_scale = None
|
||||
self.weight_scale = None
|
||||
self.weight_offset = None
|
||||
|
||||
def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor:
|
||||
tp = get_tensor_parallel_size()
|
||||
if tp == 1:
|
||||
out = torch.zeros((x.shape[0], x.shape[1], self.weight.shape[-1]), dtype=torch.float16, device=x.device)
|
||||
torch_npu._npu_matmul_pp(x, self.weight, out)
|
||||
else:
|
||||
tp_size = get_tensor_parallel_size()
|
||||
tp_group = get_tensor_parallel_group()
|
||||
batch_size = x.shape[0]
|
||||
seq_length = x.shape[1]
|
||||
lm_sep_size = tp_size
|
||||
lm_head_group = tp_group
|
||||
gathered_list = [torch.empty_like(x) for _ in range(lm_sep_size)]
|
||||
dist.all_gather(gathered_list, x, group=lm_head_group)
|
||||
input_full = torch.stack(gathered_list, dim=0)
|
||||
input_full = input_full.squeeze(dim=1)
|
||||
torch_npu.npu_format_cast_(input_full, 2)
|
||||
local_logits = torch.zeros((input_full.shape[0], input_full.shape[1], self.weight.shape[-1]),
|
||||
dtype=torch.float16, device=input_full.device)
|
||||
torch_npu._npu_matmul_pp(input_full, self.weight, local_logits)
|
||||
local_logits_transpose = local_logits.transpose(2, 1).reshape(-1, batch_size * seq_length)
|
||||
del local_logits
|
||||
output_tensor = torch.empty_like(local_logits_transpose)
|
||||
dist.all_to_all_single(output_tensor, local_logits_transpose, group=lm_head_group)
|
||||
del local_logits_transpose
|
||||
output_tensor = output_tensor.transpose(1, 0)
|
||||
out = output_tensor.view(batch_size, seq_length, -1)
|
||||
del output_tensor
|
||||
return out
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
|
||||
if device is None:
|
||||
device = self.device
|
||||
device = utils.CUR_DEVICE
|
||||
if w is None:
|
||||
w = self.load_weight()
|
||||
if isinstance(w, nn.Parameter):
|
||||
try:
|
||||
self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T.contiguous()
|
||||
except:
|
||||
self.weight = w.to(dtype=self.dtype).T.contiguous()
|
||||
self.weight = self.weight.to(device)
|
||||
if self.has_bias:
|
||||
self.bias = self.bias.to(device)
|
||||
elif isinstance(w, tuple):
|
||||
w_list = list(w)
|
||||
if len(w_list) == 3:
|
||||
self.weight = w_list[0]
|
||||
self.weight_scale = w_list[1].view(-1)
|
||||
self.weight_offset = w_list[2]
|
||||
self.weight = self.weight.to(utils.CUR_DEVICE)
|
||||
self.weight_scale = self.weight_scale.to(utils.CUR_DEVICE)
|
||||
if self.key.endswith("ffn_gate_shexp") is not True:
|
||||
self.weight = get_safetensors_cut_weight(self.key, self.weight).t()
|
||||
weight_scale = get_safetensors_cut_weight(self.key, self.weight_scale)
|
||||
self.weight_scale = weight_scale.clone()
|
||||
del weight_scale
|
||||
self.weight_offset = self.weight_offset.to(utils.CUR_DEVICE)
|
||||
else:
|
||||
for i in range(len(w_list)):
|
||||
w_list[i] = get_safetensors_cut_weight(self.key, w_list[i])
|
||||
w_list[i] = w_list[i].to(utils.CUR_DEVICE)
|
||||
self.weight = w_list[0]
|
||||
self.deq_scale = w_list[1]
|
||||
self.quant_bias = w_list[2]
|
||||
if "attn_output" in self.key or "ffn_down" in self.key:
|
||||
if torch.distributed.get_rank(get_tensor_parallel_group()) != 0:
|
||||
self.quant_bias = torch.zeros_like(self.quant_bias, dtype=self.quant_bias.dtype, device=self.quant_bias.device)
|
||||
self.input_scale = w_list[3]
|
||||
self.input_offset = w_list[4]
|
||||
elif isinstance(w, torch.Tensor):
|
||||
self.weight = w.T.contiguous()
|
||||
self.weight.to(device)
|
||||
if "kv_b" not in self.key:
|
||||
self.weight = self.weight.to(device)
|
||||
torch_npu.npu_format_cast_(self.weight, 29)
|
||||
else:
|
||||
raise ValueError(f"Invalid weight type {self.key=} {type(w)=}")
|
||||
|
||||
def unload(self):
|
||||
if self.weight is not None:
|
||||
self.weight = None
|
||||
if self.has_bias:
|
||||
self.bias = None
|
||||
self.input_scale = None
|
||||
self.input_offset = None
|
||||
self.quant_bias = None
|
||||
self.deq_scale = None
|
||||
self.weight_scale = None
|
||||
self.weight_offset = None
|
||||
|
||||
|
||||
LINEAR_MAP["KLinearTorchW8A8A2"] = KLinearTorchW8A8A2
|
||||
|
||||
|
||||
class KTransformersLinearW8A8A2(BaseInjectedModule, KLinearW8A8):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
generate_device: str = "cuda",
|
||||
generate_op: str | None = "KLinearMarlin",
|
||||
prefill_device: str = "cuda",
|
||||
prefill_op: str | None = "KLinearTorch",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||
        KLinearW8A8.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        # build all the linear operators
        if prefill_op is not None:
            assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        else:
            self.prefill_linear = None

        if generate_op is not None:
            assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
        else:
            self.generate_linear = None
        self.mode = InferenceState.UNLOAD

    def forward(self, x, bsz_tensor=None):
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
            y = self.prefill_linear.forward(x, bsz_tensor)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
            y = self.generate_linear.forward(x, bsz_tensor)
        return y

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
        if not mode:
            mode = InferenceState.GENERATE
        # load to device
        if mode == InferenceState.PREFILL:
            self.generate_linear.unload()
            self.prefill_linear.load(w=w)
            self.device = self.prefill_linear.device
            self.weight = self.prefill_linear.weight  # modeling_xxx.py may use linear.weight
            self.input_scale = self.prefill_linear.input_scale
            self.input_offset = self.prefill_linear.input_offset
            self.quant_bias = self.prefill_linear.quant_bias
            self.deq_scale = self.prefill_linear.deq_scale
            self.weight_scale = self.prefill_linear.weight_scale
            self.weight_offset = self.prefill_linear.weight_offset
        elif mode == InferenceState.GENERATE:
            self.prefill_linear.unload()
            self.generate_linear.load(w=w)
            self.device = self.generate_linear.device
            self.weight = self.generate_linear.weight  # modeling_xxx.py may use linear.weight
            self.input_scale = self.generate_linear.input_scale
            self.input_offset = self.generate_linear.input_offset
            self.quant_bias = self.generate_linear.quant_bias
            self.deq_scale = self.generate_linear.deq_scale
            self.weight_scale = self.generate_linear.weight_scale
            self.weight_offset = self.generate_linear.weight_offset
        elif mode == InferenceState.UNLOAD:
            self.prefill_linear.unload()
            self.generate_linear.unload()
            self.device = "cpu"
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
        self.mode = mode

    def unload(self):
        if self.prefill_linear is not None:
            self.prefill_linear.unload()
        if self.generate_linear is not None:
            self.generate_linear.unload()
        self.device = self.generate_linear.device

    def set_inference_mode(self, mode: InferenceState):
        if not mode:
            mode = InferenceState.GENERATE
        if mode == InferenceState.GENERATE:
            self.load(mode=InferenceState.GENERATE)
        elif mode == InferenceState.PREFILL:
            self.load(mode=InferenceState.PREFILL)
        elif mode == InferenceState.UNLOAD:
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
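The class above keeps one backend object per inference phase and materializes weights only for the one currently selected. As a rough, self-contained sketch of that pattern (not the real operator: the FakeBackend class, devices, and shapes below are made up for illustration):

```python
from enum import Enum

import torch
import torch.nn as nn


class InferenceState(Enum):
    UNLOAD = 0
    PREFILL = 1
    GENERATE = 2


class FakeBackend(nn.Module):
    """Hypothetical stand-in for a device-specific linear kernel."""
    def __init__(self, in_features, out_features, device):
        super().__init__()
        self.device = device
        self.in_features, self.out_features = in_features, out_features
        self.weight = None  # nothing resident until load() is called

    def load(self):
        self.weight = torch.randn(self.out_features, self.in_features, device=self.device)

    def unload(self):
        self.weight = None

    def forward(self, x):
        return x.to(self.device) @ self.weight.T


class SwitchableLinear(nn.Module):
    """Holds a prefill backend and a generate backend; only the active one is loaded."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.prefill_linear = FakeBackend(in_features, out_features, "cpu")
        self.generate_linear = FakeBackend(in_features, out_features, "cpu")  # "npu"/"cuda" on real hardware
        self.mode = InferenceState.UNLOAD

    def set_inference_mode(self, mode):
        self.prefill_linear.unload()
        self.generate_linear.unload()
        if mode == InferenceState.PREFILL:
            self.prefill_linear.load()
        elif mode == InferenceState.GENERATE:
            self.generate_linear.load()
        self.mode = mode

    def forward(self, x):
        backend = self.prefill_linear if self.mode == InferenceState.PREFILL else self.generate_linear
        return backend(x)


layer = SwitchableLinear(16, 8)
layer.set_inference_mode(InferenceState.GENERATE)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])
```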
72
ktransformers/operators/ascend/ascend_mlp.py
Normal file
@ -0,0 +1,72 @@
import torch
import torch_npu

from ktransformers.util.ascend.ascend_utils import allreduce_wrapper
from ktransformers.util.utils import CUR_DEVICE
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP


class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP):
    @allreduce_wrapper
    def forward(self, x, is_prefill=None, use_cuda_graph=False):
        original_dtype = x.dtype
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        dynamic_scale = dynamic_scale.view(-1)
        gate_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.gate_proj.weight,
            self.orig_module.gate_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        up_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.up_proj.weight,
            self.orig_module.up_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        down_x = self.act_fn(gate_x) * up_x
        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
        down_dynamic_scale = down_dynamic_scale.view(-1)
        down_proj = torch_npu.npu_quant_matmul(
            down_quant_out,
            self.orig_module.down_proj.weight,
            self.orig_module.down_proj.weight_scale,
            pertoken_scale=down_dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        return down_proj


class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP):
    @allreduce_wrapper
    def forward(self, x, is_prefill=None, use_cuda_graph=False):
        original_dtype = x.dtype
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        dynamic_scale = dynamic_scale.view(-1)
        gate_up_x = torch_npu.npu_quant_matmul(
            quant_out,
            self.orig_module.gate_proj.weight,
            self.orig_module.gate_proj.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        down_x = torch_npu.npu_swiglu(gate_up_x, -1)

        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)
        down_dynamic_scale = down_dynamic_scale.view(-1)
        down_proj = torch_npu.npu_quant_matmul(
            down_quant_out,
            self.orig_module.down_proj.weight,
            self.orig_module.down_proj.weight_scale,
            pertoken_scale=down_dynamic_scale,
            bias=None,
            output_dtype=original_dtype,
        )
        return down_proj
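Both classes follow the same W8A8 flow: per-token dynamic quantization of the activations, an int8 matmul against pre-quantized weights, and dequantization via the per-token activation scale and the per-channel weight scale. V1 runs separate gate/up projections and the module's own act_fn; V2 assumes a fused gate_up weight and lets npu_swiglu split and activate it. As a reference for what the NPU kernels are assumed to compute (this is not the kernel itself), here is a plain-PyTorch sketch of the quantize–matmul–dequantize step:

```python
import torch

def dynamic_quant(x):
    # Per-token (per-row) symmetric int8 quantization: one scale per token.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale.squeeze(-1)

def quant_matmul(q_x, x_scale, q_w, w_scale):
    # Emulates int32 accumulation (values here are small enough to be exact in float),
    # then dequantizes with the per-token and per-output-channel scales.
    acc = q_x.float() @ q_w.float().T
    return acc * x_scale[:, None] * w_scale[None, :]

torch.manual_seed(0)
x = torch.randn(4, 64)                    # 4 tokens, hidden size 64 (illustrative)
w = torch.randn(128, 64)                  # intermediate size 128 (illustrative)
w_scale = w.abs().amax(dim=-1) / 127.0    # per-output-channel weight scale
q_w = torch.round(w / w_scale[:, None]).clamp(-128, 127).to(torch.int8)

q_x, x_scale = dynamic_quant(x)
y_quant = quant_matmul(q_x, x_scale, q_w, w_scale)
print((y_quant - x @ w.T).abs().max())    # quantization error, small relative to the fp32 result
```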
@ -0,0 +1,114 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2  # optimized kernel on quantized data types
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      generate_op: "KLinearTorchW8A8A2"
      prefill_op: "KLinearTorchW8A8A2"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2  # optimized kernel on quantized data types
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      generate_op: "KLinearTorchW8A8A2"
      prefill_op: "KLinearTorchW8A8A2"

- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8  # MLP module with custom forward function
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"

- match:
    name: "^model\\.layers\\.([0-2])\\.mlp$"
    class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP"
  replace:
    class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1"
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.shared_experts$"
    class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP"
  replace:
    class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2"
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"

- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2
    kwargs:
      generate_device: "npu:0"
      prefill_device: "npu:0"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8
    kwargs:
      prefill_device: "npu"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPUW8A8"
      out_device: "npu"
  recursive: False  # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
    class: ktransformers.operators.experts.KExpertsCPU
  replace:
    class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8

- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2  # optimized MLA implementation
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    name: "^model..*norm"
  replace:
    class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8
    kwargs:
      generate_device: "npu"
      prefill_device: "npu"
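These rules are matched top-down against each module's dotted name (and, where given, its class) and the first applicable replacement is injected. As a quick sanity check of which modules the name patterns catch, the following self-contained sketch (module paths are illustrative) evaluates a few of them:

```python
import re

# Illustrative module paths as they would appear in the model's module tree.
names = [
    "lm_head",
    "model.embed_tokens",
    "model.layers.0.mlp",
    "model.layers.3.mlp",
    "model.layers.3.mlp.shared_experts",
    "model.layers.3.mlp.experts",
    "model.layers.3.self_attn.kv_b_proj",
    "model.layers.3.self_attn.q_a_proj",
]

patterns = {
    "quantized linear": r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$",
    "dense MLP (layers 0-2)": r"^model\.layers\.([0-2])\.mlp$",
    "shared experts": r"^model\.layers\..*\.mlp\.shared_experts$",
    "routed experts": r"^model\.layers\..*\.mlp\.experts$",
}

for label, pat in patterns.items():
    hits = [n for n in names if re.match(pat, n)]
    print(f"{label}: {hits}")

# kv_b_proj is excluded from the quantized-linear rule by the negative lookahead,
# and only layers 0-2 (the dense layers) match the MLP rule. Note the real
# quantized-linear rule additionally requires the module class to be torch.nn.Linear,
# so non-linear submodules that merely match the name are not replaced.
```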
27
setup.py
@ -41,6 +41,13 @@ except ImportError:
    MUSA_HOME=None
KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()


try:
    import torch_npu
    KTRANSFORMERS_BUILD_NPU = torch_npu.npu.is_available()
except (ModuleNotFoundError, ImportError) as e:
    KTRANSFORMERS_BUILD_NPU = False

# Detect the DEV_BACKEND environment variable
dev_backend = os.environ.get("DEV_BACKEND", "").lower()
if dev_backend == "xpu":

@ -237,6 +244,8 @@ class VersionInfo:
            backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
        elif torch.xpu.is_available():
            backend_version = f"xpu"
        elif KTRANSFORMERS_BUILD_NPU:
            backend_version = f"npu{torch_npu.__version__}"
        else:
            raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.")
        package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"

@ -509,6 +518,8 @@ class CMakeBuild(BuildExtension):
            cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
        elif KTRANSFORMERS_BUILD_XPU:
            cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
        elif KTRANSFORMERS_BUILD_NPU:
            cmake_args += ["-DKTRANSFORMERS_USE_NPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
        else:
            raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.")

@ -636,10 +647,12 @@ elif MUSA_HOME is not None:
    )
elif torch.xpu.is_available(): # XPUExtension is not available now.
    ops_module = None
elif KTRANSFORMERS_BUILD_NPU:
    pass
else:
    raise ValueError("Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.")

if not torch.xpu.is_available():
if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:
    ext_modules = [
        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
        ops_module,

@ -660,10 +673,20 @@ if not torch.xpu.is_available():
        ext_modules.append(
            CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
        )
else:
elif torch.xpu.is_available():
    ext_modules = [
        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
    ]
elif KTRANSFORMERS_BUILD_NPU:
    ext_modules = [
        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
    ]
    if with_balance:
        print("using balance_serve")
        ext_modules.append(
            CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
        )


setup(
    name=VersionInfo.PACKAGE_NAME,
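The NPU branch in VersionInfo above tags the built wheel with the torch_npu version instead of a CUDA/ROCm/XPU tag. A toy sketch of how the final package_version string is assembled, using entirely hypothetical version values (not taken from any real build):

```python
# Hypothetical inputs; a real build derives these from the installed toolchain.
flash_version = "0.3.1"
torch_version = "2.1.0"
cpu_instruct = "aarch64"
torch_npu_version = "2.1.0.post8"            # stand-in for torch_npu.__version__

backend_version = f"npu{torch_npu_version}"  # the KTRANSFORMERS_BUILD_NPU branch above
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
print(package_version)                       # 0.3.1+npu2.1.0.post8torch2.1.0aarch64
```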