diff --git a/.gitignore b/.gitignore
index 38bb53c..b6c97c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,4 +26,7 @@ ktransformers/tests/chat_txt.txt
 mmlu_result*
 ktransformers/ktransformers_ext/cuda_musa/
 test_prompt.txt
-csrc/demo
\ No newline at end of file
+csrc/demo
+CMakeFiles
+kvc2/
+sched/
\ No newline at end of file
diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt
index 0ed4ef4..cbee533 100644
--- a/csrc/ktransformers_ext/CMakeLists.txt
+++ b/csrc/ktransformers_ext/CMakeLists.txt
@@ -42,6 +42,7 @@ option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA"
 option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
 option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
 option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
+option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
 
 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures
@@ -306,6 +307,17 @@ elseif (UNIX)
         endif()
     elseif (KTRANSFORMERS_USE_XPU)
         add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
+    elseif (KTRANSFORMERS_USE_NPU)
+        include(CheckLanguage)
+        check_language(CUDA)
+        if(CMAKE_CUDA_COMPILER)
+            message(STATUS "CUDA detected")
+            find_package(CUDAToolkit REQUIRED)
+            include_directories(${CUDAToolkit_INCLUDE_DIRS})
+        endif()
+        message(STATUS "enabling CUDA")
+        enable_language(CUDA)
+        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
     else()
         find_package(CUDA REQUIRED)
         include_directories("${CUDA_INCLUDE_DIRS}")
diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp
index f0aeaa5..14f5bc3 100644
--- a/csrc/ktransformers_ext/ext_bindings.cpp
+++ b/csrc/ktransformers_ext/ext_bindings.cpp
@@ -9,7 +9,7 @@
  **/
 // Python bindings
 #include "cpu_backend/cpuinfer.h"
-#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
+#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU)
 #include "device_launch_parameters.h"
 #endif
 #include "llamafile/flags.h"
diff --git a/doc/zh/DeepSeekR1_tutorial_zh_for_Ascend_NPU.md b/doc/zh/DeepSeekR1_tutorial_zh_for_Ascend_NPU.md
new file mode 100644
index 0000000..4dcc42f
--- /dev/null
+++ b/doc/zh/DeepSeekR1_tutorial_zh_for_Ascend_NPU.md
@@ -0,0 +1,165 @@
+# Deployment
+
+## Physical machine setup
+
+Deploying the full DeepSeek-V3 model requires enough physical host memory to hold the weights of all routed experts, roughly 400 GB.
+
+Currently supported NPU model: **800I A2**.
+
+Complete the hardware installation with support from technical staff.
+
+## OS installation
+
+According to the [Ascend compatibility checker](https://www.hiascend.com/hardware/compatibility), use Ubuntu 22.04 for aarch64 with kernel 5.15.0-25-generic, and disable automatic system updates. The system image can be downloaded from [ubuntu-old-releases](https://mirrors.aliyun.com/oldubuntu-releases/releases/22.04).
+
+## HDK installation
+
+Install [Ascend HDK 25.0.RC1](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=32&cann=8.1.RC1.beta1&driver=Ascend+HDK+25.0.RC1), following the [Ascend community HDK installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1beta1/softwareinst/instg/instg_0005.html?Mode=PmIns&InstallType=local&OS=Ubuntu&Software=cannToolKit).
+
+
+## Conda setup
+
+It is recommended to set up the development environment following the latest [Installation Guide - kTransformers](https://kvcache-ai.github.io/ktransformers/en/install.html). Note that Python 3.11 is required (other versions have not been verified), and the cpufeature package is not needed on the ARM platform.
+
+Install conda/miniconda:
+
+```bash
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
+bash ~/Miniconda3-latest-Linux-aarch64.sh
+```
+
+Set up the Python environment:
+
+```bash
+conda create -n py311 python=3.11
+conda activate py311
+conda install -c conda-forge libstdcxx-ng # provides `GLIBCXX-3.4.32`
+pip3 install numpy==1.26.4 # matches torch/torch_npu
+pip3 install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+pip3 install packaging ninja transformers==4.43.2 fire protobuf attrs decorator cloudpickle ml-dtypes scipy tornado absl-py psutil
+#pip3 install cpufeature # only for x86
+```
+
+## CANN installation
+
+Install [CANN 8.1.RC1.beta1](https://www.hiascend.com/developer/download/community/result?from=firmware&product=4&model=32&cann=8.1.RC1.beta1), following the [Ascend community CANN installation guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1beta1/softwareinst/instg/instg_0007.html?Mode=PmIns&InstallType=local&OS=Ubuntu&Software=cannToolKit).
+
+The Toolkit, Kernel, and NNAL packages are all required.
+
+## torch_npu (op-plugin) installation
+
+Get the latest repository code: [op-plugin Gitee](https://gitee.com/ascend/op-plugin)
+
+Because new operators are involved, the torch_npu wheel published on public PyPI cannot be used directly for now; the required torch_npu package has to be built from an adapted op-plugin, which is not yet publicly available. # TODO
+
+With working network access to GitHub and Gitee, run the following to build and install torch_npu:
+
+```bash
+cd op-plugin
+source /usr/local/Ascend/ascend-toolkit/set_env.sh # adjust to the actual CANN installation path
+source /usr/local/Ascend/nnal/atb/set_env.sh # adjust to the actual NNAL installation path
+bash install.sh --python=3.11 --pytorch=v2.3.1-7.0.0 # the generated torch_npu wheel ends up in {op-plugin repo path}/dist
+```
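+
+With CANN and the adapted torch_npu installed, a quick import check can confirm the environment before building kTransformers. This is only a suggested sanity check (run it in the py311 environment with the CANN environment script sourced); the device-count call assumes the usual `torch_npu.npu` API:
+
+```python
+import torch
+import torch_npu  # registers the NPU backend with torch
+
+print("torch:", torch.__version__)            # expected: 2.3.1 (CPU build)
+print("torch_npu:", torch_npu.__version__)
+print("NPU available:", torch_npu.npu.is_available())
+print("NPU device count:", torch_npu.npu.device_count())
+```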
+
+## Weight preparation
+
+Currently, to meet both performance and accuracy requirements, two sets of weights are needed; they are merged with the provided weight-merging script, and only the merged weights are used in the end.
+
+Q4 weights: [DeepSeek-R1-Q4_K_M](https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/files)
+
+W8A8 weights: [DeepSeek-R1-W8A8](https://modelers.cn/models/MindSpore-Lab/DeepSeek-R1-W8A8/tree/main)
+
+Use [merge_safetensor_gguf.py](../../merge_tensors/merge_safetensor_gguf.py) to merge the Q4 and W8A8 weights:
+
+```bash
+python merge_safetensor_gguf.py --safetensor_path /mnt/weights/DeepSeek-R1-Q4_K_M --gguf_path /mnt/weights/DeepSeek-R1-W8A8 --output_path /mnt/weights/DeepSeek-R1-q4km-w8a8 --safetensors_format w8a8
+```
+
+## Graph sinking deployment
+
+The binaries required for graph sinking (NPU graph mode) are shipped with the repository: [ktransformers/util/npu_graph_so](../../ktransformers/util/npu_graph_so).
+
+Deploying graph sinking requires replacing the corresponding files; taking the ARM platform as an example:
+
+```bash
+mv /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so.bak
+cp ktransformers/util/npu_graph_so/arm/libruntime.so /usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so
+```
+
+To enable graph sinking, add the following environment variables:
+
+```bash
+export CAPTURE_PLUGIN_PATH=ktransformers/util/npu_graph_so/arm
+export TASK_QUEUE_ENABLE=0 # keeps operator dispatch in order
+```
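+
+Before launching, an optional check along these lines can confirm the graph-sinking prerequisites above (the swapped `libruntime.so` and the two environment variables); the path assumes the default toolkit location used in the commands above:
+
+```python
+import os
+from pathlib import Path
+
+runtime = Path("/usr/local/Ascend/ascend-toolkit/latest/lib64/libruntime.so")
+backup = runtime.parent / (runtime.name + ".bak")
+
+print("libruntime.so present:", runtime.exists())
+print("original backed up:", backup.exists())
+print("CAPTURE_PLUGIN_PATH:", os.environ.get("CAPTURE_PLUGIN_PATH"))
+print("TASK_QUEUE_ENABLE:", os.environ.get("TASK_QUEUE_ENABLE"))
+```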
+
+## kTransformers deployment
+
+Deploy the project files onto the machine:
+
+- On the ARM platform, comment out `cpufeature` in `./requirements-local_chat.txt`.
+- On the ARM platform, apply the following replacements:
+  ```bash
+  cp ./for_arm/CMakeLists.txt ./csrc/ktransformers_ext/CMakeLists.txt
+  cp ./for_arm/iqk_mul_mat.inc ./third_party/llamafile/iqk_mul_mat.inc
+  cp ./for_arm/sgemm.cpp ./third_party/llamafile/sgemm.cpp
+  cp ./for_arm/tinyblas_cpu_sgemm.inc ./third_party/llamafile/tinyblas_cpu_sgemm.inc
+  cp ./for_arm/setup.py ./setup.py
+  ```
+- Run `source /usr/local/Ascend/ascend-toolkit/set_env.sh` (adjust to the actual CANN toolkit installation path).
+- Run `apt install cmake libhwloc-dev pkg-config` to install the build dependencies.
+- Run `bash install.sh` and wait for the installation to finish.
+
+An example launch script for local_chat is given below (it uses relative paths, so place it in the project root):
+
+```bash
+#!/bin/bash
+export CAPTURE_PLUGIN_PATH=ktransformers/util/npu_graph_so/arm
+export USE_MERGE=0
+export INF_NAN_MODE_FORCE_DISABLE=1
+export TASK_QUEUE_ENABLE=0
+#export PROF_DECODE=1
+#export PROF_PREFILL=1
+
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+source /usr/local/Ascend/nnal/atb/set_env.sh
+
+torchrun --nproc_per_node 1 \
+    --master_port 25565 \
+    -m ktransformers.local_chat \
+    --cpu_infer 20 \
+    --model_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
+    --gguf_path /mnt/weights/DeepSeek-R1-q4km-w8a8 \
+    --optimize_config_path ./ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml \
+    --use_cuda_graph True \
+    --max_new_tokens 500 \
+    --tp 1
+```
+
+Parameter notes:
+
+- `--model_path`: native kTransformers argument, str; here it points to the merged model directory
+- `--gguf_path`: native kTransformers argument, str; here it also points to the merged model directory
+- `--cpu_infer`: native kTransformers argument, int; controls the actual number of CPU-side worker threads; optional
+- `--optimize_config_path`: native kTransformers argument, str; specifies the optimization rule file to use (mind the relative path); **required** here
+- `--use_cuda_graph`: native kTransformers argument, bool; True enables graph sinking and False disables it; optional
+- `--max_new_tokens`: native kTransformers argument, int; generation stops as soon as the number of output tokens reaches this value; optional
+- `--tp`: newly added argument, int; enables tensor model parallelism; local_chat currently only supports tp equal to the world size (multiple data-parallel ranks are not supported with local_chat); optional
+
+
+# Other issues
+
+## Possible additional dependency issues
+
+ImportError: libhccl.so: cannot open shared object file: No such file or directory
+
+```bash
+source /usr/local/Ascend/ascend-toolkit/set_env.sh # adjust to the actual CANN installation path
+```
+
+ImportError: libascend_hal.so: cannot open shared object file: No such file or directory
+
+```bash
+export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH # adjust to the actual driver installation path
+```
\ No newline at end of file
diff --git a/install_for_npu.sh b/install_for_npu.sh
new file mode 100644
index 0000000..585cb9a
--- /dev/null
+++ b/install_for_npu.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+set -e
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# clear build dirs
+rm -rf build
+rm -rf *.egg-info
+rm -rf csrc/build
+rm -rf csrc/ktransformers_ext/build
+rm -rf csrc/ktransformers_ext/cuda/build
+rm -rf csrc/ktransformers_ext/cuda/dist
+rm -rf csrc/ktransformers_ext/cuda/*.egg-info
+rm -rf ~/.ktransformers
+echo "Installing python dependencies from requirements.txt"
+pip install -r requirements-local_chat.txt
+pip install -r ktransformers/server/requirements.txt
+echo "Installing ktransformers"
+KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v .
--no-build-isolation +pip install third_party/custom_flashinfer/ + +echo "Installation completed successfully" \ No newline at end of file diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 3bd60f9..5ebbfcd 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -67,7 +67,7 @@ attn: page_size: 256 chunk_size: 256 kvc2: - gpu_only: false + gpu_only: true utilization_percentage: 1.0 cpu_memory_size_GB: 500 disk_path: /mnt/data/kvc \ No newline at end of file diff --git a/ktransformers/operators/ascend/ascend_attention.py b/ktransformers/operators/ascend/ascend_attention.py new file mode 100644 index 0000000..20afc45 --- /dev/null +++ b/ktransformers/operators/ascend/ascend_attention.py @@ -0,0 +1,467 @@ +import os +import warnings +from typing import Optional, Tuple + +import torch +import torch_npu +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from transformers.cache_utils import Cache + +from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, allreduce_wrapper +from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.utils import get_compute_capability, get_use_npu_graph, CUR_DEVICE +from ktransformers.util.vendors import device_manager, GPUVendor +from ktransformers.util import utils + + +def apply_rotary_pos_emb_fusion(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + +class MatMulOps(object): + def execute(self, x_input): + """ + :param x, weight, quant_bia, deq_scale + :return: + """ + quant_out = x_input[0] + weight = x_input[1] + quant_bia = x_input[2] + deq_scale = x_input[3] + return [torch_npu.npu_quant_matmul(quant_out, weight.T, deq_scale, bias=quant_bia, output_dtype=torch.float16)] + + +class MatMulOpsAtb(object): + def execute(self, x_input): + """ + :param x, weight, quant_bia, deq_scale + :return: + """ + x = x_input[0] + weight = x_input[1] + quant_bia = x_input[2] + deq_scale = x_input[3] + target_shape = (x.shape[0], x.shape[-2], weight.shape[-2]) + target_tensor = torch.zeros(target_shape, dtype=torch.float16, device=x.device) + torch_npu.torch_npu._npu_matmul_dequant(x, weight, quant_bia, deq_scale, target_tensor) + return [target_tensor] + + +class DynamicQuantOps(object): + """ + :param x, scale, offset + :return + """ + + def execute(self, x_input): + out = torch.empty_like(x_input[0], dtype=torch.int8) + torch_npu._npu_quantize_per_tensor(x_input[0], x_input[1], x_input[2], out) + return [out] + + +class KDeepseekV2AttentionW8A8A2(BaseInjectedModule, DeepseekV2Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + attn_mask: Optional[torch.Tensor] = None + + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + absorb_for_prefill: bool = False, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, 
config, orig_module, prefill_device, generate_device, + **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + self.mla_wrapper = None + tp = get_tensor_parallel_size() + if tp > 1: + self.num_heads //= tp + self.absorb_for_prefill = absorb_for_prefill + + self.use_merge = os.getenv("USE_MERGE", "0") + if self.use_merge == "0": + print("--Use ATB FA-MLA and PA-MLA OP !--") + self.elewise_quant = DynamicQuantOps() + self.matmulDequant_operation = MatMulOpsAtb() + self.matmulDequant_operation_aclnn = MatMulOps() + elif self.use_merge == "1": + print("--Use torch npu FA OP !--") + else: + print("--Use default op! --") + + @allreduce_wrapper + def forward_chunck( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0] + q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight, + self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0] + q_a_proj_out = self.q_a_layernorm(q_a_proj_out) + q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0] + q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight, + self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0] + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0] + compressed_kv = self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight, + self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0] + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + kv_seq_len = k_pe.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(q_pe, position_ids) + q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin) + + # update KV + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + k_pe = k_pe.transpose(1, 2) # k_pe [bsz, 1, q_len, self.qk_rope_head_dim] + compressed_kv = compressed_kv.unsqueeze(2) # compressed_kv [bsz, q_len, self.kv_lora_rank] + compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) + compressed_kv, k_pe = torch.split( + compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:, :, :attention_mask.size(-1), :] + compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:, :, :attention_mask.size(-1), :] + + weight_uk = self.q_absorb + weight_uv = self.out_absorb + + # ATB-MLA-FA+PA + if self.use_merge == "0" and q_len != 1: + current_sqenLen = past_key_value.get_seq_length(self.layer_idx) + attention_mask = attention_mask[0, :, :, :current_sqenLen].squeeze(0).squeeze(0) + + compressed_kv = compressed_kv[:, :, :current_sqenLen, :] # all KV until current chunk + k_pe = k_pe[:, :, :current_sqenLen, :] + + k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1) + k_up = torch.matmul(compressed_kv, weight_uk.mT) + v_up = torch.matmul(compressed_kv, weight_uv) + + qTensor = torch.cat((q_nope, q_pe), dim=-1).transpose(1, 2).contiguous().view( + bsz * q_len, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim)) + kTensor = torch.cat((k_up, k_pe_repeated), dim=-1).transpose(1, 2).contiguous().view( + bsz * current_sqenLen, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim)) + vTensor = v_up.transpose(1, 2).contiguous().view(bsz * current_sqenLen, self.num_heads, self.v_head_dim) + + seq_len_data = [q_len] * bsz + seq_len = torch.tensor(seq_len_data, dtype=torch.int32, device=vTensor.device) + seq_len_host = torch.tensor(seq_len_data, dtype=torch.int32) + + attn_output = torch.ones((qTensor.shape[0], qTensor.shape[1], vTensor.shape[-1]), + dtype=qTensor.dtype, device=vTensor.device) + torch_npu._npu_flash_attention_mla(qTensor, kTensor, vTensor, attention_mask, seq_len, seq_len_host, + self.softmax_scale, self.num_heads, self.num_heads, attn_output) + + if attn_output.size() != (bsz * q_len, self.num_heads, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0] + attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight, + self.o_proj.quant_bias, self.o_proj.deq_scale])[0] + + return attn_output, None, past_key_value + + elif self.use_merge == "0" and q_len == 1: + return self.forward_paged(q_pe=q_pe, + q_nope=q_nope, + compressed_kv_with_k_pe=compressed_kv_with_k_pe, + past_key_value=past_key_value, + cache_position=cache_position) + + if self.use_merge == "1": + k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1) + k_up = torch.matmul(compressed_kv, weight_uk.mT) + v_up = torch.matmul(compressed_kv, weight_uv) + qTensor = torch.cat((q_nope, q_pe), dim=-1) + kTensor = torch.cat((k_up, k_pe_repeated), dim=-1) + vTensor = torch.cat((v_up, 
k_pe_repeated), dim=-1) + + if q_len != 1: + attn_output = torch_npu.npu_prompt_flash_attention( + qTensor, kTensor, vTensor, + num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD") + else: + attn_output = torch_npu.npu_incre_flash_attention( + qTensor, kTensor, vTensor, + num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD") + attn_output = attn_output[:, :, :, :self.v_head_dim] + else: + q_nope = torch.matmul(q_nope, self.q_absorb) + + attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale + + compressed_kv = compressed_kv.squeeze(1) + """ + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + """ + if attention_mask is not None: + """ + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + """ + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q_pe.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) + + attn_output = torch.matmul(attn_output, self.out_absorb) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + def forward_paged( + self, + q_pe: torch.Tensor, + q_nope: torch.Tensor, + compressed_kv_with_k_pe: torch.Tensor, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, _, q_len, _ = q_nope.size() + q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb) # torch.Size([1, 128, 1, 512]) + compressed_kv = compressed_kv_with_k_pe.permute(0, 2, 1, 3) + kvCache = compressed_kv[:, :, :, :self.kv_lora_rank].contiguous() + kRopeCache = compressed_kv[:, :, :, self.kv_lora_rank:].contiguous() + + if get_use_npu_graph(): + from ktransformers.util.npu_graph_runner import get_or_create_runner + npu_graph_runner = get_or_create_runner(CUR_DEVICE) + stream = npu_graph_runner.main_stream + if npu_graph_runner.past_key_value is None: + npu_graph_runner.past_key_value = past_key_value + if npu_graph_runner.workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + q_nope, + kvCache, + kvCache, + query_rope=q_pe, + key_rope=kRopeCache, + num_heads=self.num_heads, + num_key_value_heads=1, + input_layout="BNSD", + atten_mask=None, + scale=self.softmax_scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=past_key_value.page_table_list[self.layer_idx], + block_size=past_key_value.page_size, + actual_seq_lengths_kv=past_key_value.position + ) + npu_graph_runner.workspace = workspace + attn_output = torch.zeros_like(q_nope, 
dtype=torch.float16, device=CUR_DEVICE) + softmax_lse = torch.empty(1, dtype=torch.float16, device=CUR_DEVICE) + npu_graph_runner.ifa_param.append((q_nope, kvCache, q_pe, kRopeCache, self.num_heads, + self.softmax_scale, self.layer_idx, attn_output, softmax_lse)) + eventTmp = torch.npu.ExternalEvent() + npu_graph_runner.event.append(eventTmp) + eventTmp.wait(stream) + eventTmp.reset(stream) + torch.npu.graph_task_group_begin(stream) + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + kvCache, + kvCache, + workspace=npu_graph_runner.workspace, + query_rope=q_pe, + key_rope=kRopeCache, + num_heads=self.num_heads, + num_key_value_heads=1, + input_layout="BNSD", + atten_mask=None, + scale=self.softmax_scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=past_key_value.page_table_list[self.layer_idx], + block_size=past_key_value.page_size, + actual_seq_lengths_kv=past_key_value.position, + out=[attn_output, softmax_lse]) + handle = torch.npu.graph_task_group_end(stream) + npu_graph_runner.handle.append(handle) + else: + attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score( + q_nope, + kvCache, + kvCache, + query_rope=q_pe, + key_rope=kRopeCache, + num_heads=self.num_heads, + num_key_value_heads=1, + input_layout="BNSD", + atten_mask=None, + scale=self.softmax_scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=past_key_value.page_table_list[self.layer_idx], + block_size=past_key_value.page_size, + actual_seq_lengths_kv=past_key_value.position, + ) + + attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb) + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0] + attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight, + self.o_proj.quant_bias, self.o_proj.deq_scale])[0] + return attn_output, None, past_key_value + + def forward_windows( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if q_len <= self.chunck_size: + return self.forward_chunck( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs + ) + + assert output_attentions is False, "output_attentions is not supported when using chunked attention" + attn_output = None + cur_idx = 0 + while cur_idx < q_len: + if attention_mask is not None: + chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] + else: + # generate chunk_mask automatically. 
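+                # The mask is a cached (1, 1, chunck_size, max_cache_len) buffer that is rewritten for every chunk:
+                # the columns of the current chunk get an upper-triangular fill of -65504.0 (the most negative fp16
+                # value) to keep the attention causal, columns after the chunk are fully masked, and columns before
+                # cur_idx are reset to 0 so all previously processed tokens stay visible.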
+ self.attn_mask = \ + torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ + if self.attn_mask is None \ + else self.attn_mask + self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \ + -65504.0 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \ + [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))] + self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -65504.0 + self.attn_mask[:, :, :, :cur_idx] = 0 + chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx)) + + cur_output, _, _ = self.forward_chunck( + hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], + chunk_mask, + position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], + past_key_value, + output_attentions, + use_cache, + cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], + **kwargs + ) + cur_idx += self.chunck_size + if attn_output is None: + attn_output = cur_output + else: + attn_output = torch.cat((attn_output, cur_output), dim=-2) + + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + return self.forward_windows( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + **kwargs, + ) \ No newline at end of file diff --git a/ktransformers/operators/ascend/ascend_experts.py b/ktransformers/operators/ascend/ascend_experts.py new file mode 100644 index 0000000..c7bdb8f --- /dev/null +++ b/ktransformers/operators/ascend/ascend_experts.py @@ -0,0 +1,192 @@ +import bisect + +import acl +import torch +import numpy as np +from torch import nn +from transformers import PretrainedConfig + +from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, get_tensor_parallel_group +from ktransformers.operators.experts import KExpertsCPU, KTransformersExperts, EXPERTS_MAP, KDeepseekV3MoE, cuda_graphs +from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.utils import CUR_DEVICE, get_use_npu_graph, InferenceState + + +class KExpertsCPUW8A8(KExpertsCPU): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cpu", + **kwargs + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.input_tensor_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + self.expert_ids_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) + self.weights_cpu_graph = torch.zeros((1, self.config.num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) + self.output_cpu_graph = torch.zeros((1, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) + self.bsz_tensor_cpu_graph = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True) + + def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): + if get_use_npu_graph(): + 
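+            # Graph-capture path: the MoE caller has already copied the activations, expert ids and routing
+            # weights into the pinned *_cpu_graph host buffers, so the CPU kernel is launched on those fixed
+            # addresses and the caller reads the result back from output_cpu_graph.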
self.cpu_infer.submit(self.moe.forward(self.expert_ids_cpu_graph.size(0), + self.expert_ids_cpu_graph.size(1), + self.expert_ids_cpu_graph.data_ptr(), + self.weights_cpu_graph.data_ptr(), + self.input_tensor_cpu_graph.data_ptr(), + self.output_cpu_graph.data_ptr(), + self.bsz_tensor_cpu_graph.data_ptr())) + self.cpu_infer.sync() + else: + if bsz_tensor is None: + bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) + org_type = input_tensor.dtype + input_tensor = input_tensor.contiguous().cpu() + input_tensor = input_tensor.to(torch.bfloat16) + expert_ids = expert_ids.contiguous().cpu() + weights = weights.contiguous().to(torch.float32).cpu() + bsz_tensor = bsz_tensor.contiguous().cpu() + output = torch.empty_like(input_tensor).contiguous() + self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr())) + self.cpu_infer.sync() + return output.to(org_type).to(device=CUR_DEVICE) + + +EXPERTS_MAP["KExpertsCPUW8A8"] = KExpertsCPUW8A8 + + +class KTransformersExpertsW8A8(KTransformersExperts): + def forward(self, input_tensor, expert_ids, weights): + if self.mode == InferenceState.GENERATE: + assert self.generate_experts is not None, "generate_experts is None" + return self.generate_experts.forward(input_tensor, expert_ids, weights) + elif self.mode == InferenceState.PREFILL: + assert self.prefill_experts is not None, "prefill_experts is None" + return self.prefill_experts.forward(input_tensor, expert_ids, weights) + else: + raise ValueError("load or set_inference_mode before forward") + + +class KDeepseekV3MoEW8A8(KDeepseekV3MoE): + def forward_tp(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + rank = torch.distributed.get_rank() + def share_experts_forward(): + if self.config.n_shared_experts is not None: + return self.shared_experts(identity).squeeze(0) + if rank == 0: + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1) + if get_use_npu_graph(): + from ktransformers.util.npu_graph_runner import get_or_create_runner + npu_graph_runner = get_or_create_runner(CUR_DEVICE) + event = torch.npu.Event() + event.record(npu_graph_runner.main_stream) + with torch.npu.stream(npu_graph_runner.update_stream): + event.wait(npu_graph_runner.update_stream) + y_ = share_experts_forward() if share_experts_forward is not None else None + event.record(npu_graph_runner.update_stream) + org_type = hidden_states.dtype + input_tensor = hidden_states.to(torch.bfloat16) + topk_weight = topk_weight.contiguous().to(torch.float32) + self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight) + self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True) + self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True) + self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True) + + npu_graph_runner.launch_callback( + self.cpu_moe_kexperts, + self.moe_kexperts_param, + 1, npu_graph_runner.stream) + + output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True) + y = output_npu_graph.to(org_type) + event.wait(npu_graph_runner.main_stream) + else: + y = self.moe_kexperts(hidden_states, topk_idx, topk_weight) + y_ = share_experts_forward() if share_experts_forward is not None else 
None + y = y.view(*orig_shape).to(device=hidden_states.device) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + y = torch.zeros(orig_shape, dtype=torch.float16, device=CUR_DEVICE) + y_ = share_experts_forward() if share_experts_forward is not None else None + torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group()) + if self.config.n_shared_experts is not None: + y += y_ + return y + + def forward(self, hidden_states): + tp_size = get_tensor_parallel_size() + world_size = torch.distributed.get_world_size() + if tp_size > 1 and world_size == tp_size: + return self.forward_tp(hidden_states) + identity = hidden_states + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + y_ = None + + # only for generate phase + # if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and False: + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) + if self.config.n_shared_experts is not None: + y_ = self.shared_experts(identity).squeeze(0) + y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) + y += y_ + y.resize_(*orig_shape) + return y + + def share_experts_forward(): + if self.config.n_shared_experts is not None: + return self.shared_experts(identity).squeeze(0) + + cuda_graph_idx = bisect.bisect_left(cuda_graphs, 1) + if get_use_npu_graph(): + from ktransformers.util.npu_graph_runner import get_or_create_runner + npu_graph_runner = get_or_create_runner(CUR_DEVICE) + event = torch.npu.Event() + event.record(npu_graph_runner.main_stream) + with torch.npu.stream(npu_graph_runner.update_stream): + event.wait(npu_graph_runner.update_stream) + y_ = share_experts_forward() if share_experts_forward is not None else None + event.record(npu_graph_runner.update_stream) + org_type = hidden_states.dtype + input_tensor = hidden_states.to(torch.bfloat16) + topk_weight = topk_weight.contiguous().to(torch.float32) + self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight) + self.experts.generate_experts.input_tensor_cpu_graph.copy_(input_tensor, non_blocking=True) + self.experts.generate_experts.expert_ids_cpu_graph.copy_(topk_idx, non_blocking=True) + self.experts.generate_experts.weights_cpu_graph.copy_(topk_weight, non_blocking=True) + + npu_graph_runner.launch_callback( + self.cpu_moe_kexperts, + self.moe_kexperts_param, + 1, npu_graph_runner.stream) + + output_npu_graph = self.experts.generate_experts.output_cpu_graph.to(CUR_DEVICE, non_blocking=True) + y = output_npu_graph.to(org_type) + event.wait(npu_graph_runner.main_stream) + else: + y = self.moe_kexperts(hidden_states, topk_idx, topk_weight) + y_ = share_experts_forward() if share_experts_forward is not None else None + y = y.view(*orig_shape).to(device=hidden_states.device) + + if self.config.n_shared_experts is not None: + y += y_ + return y + + @torch.no_grad() + def cpu_moe_kexperts(self, moe_kexperts_param) -> torch.Tensor: + x, topk_ids, topk_weight = moe_kexperts_param + self.moe_kexperts(x, topk_ids, topk_weight) + + @torch.no_grad() + def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + outs = self.experts(x, topk_ids, topk_weight) + return outs \ 
No newline at end of file diff --git a/ktransformers/operators/ascend/ascend_gate.py b/ktransformers/operators/ascend/ascend_gate.py new file mode 100644 index 0000000..9829b2d --- /dev/null +++ b/ktransformers/operators/ascend/ascend_gate.py @@ -0,0 +1,43 @@ +import torch +import torch_npu +import torch.nn as nn +import torch.nn.functional as F +from ktransformers.operators.gate import KMoEGate +from ktransformers.util import utils + + +class KDeepseekV3GateA2(KMoEGate): + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None): + device = utils.CUR_DEVICE + if device is None: + device = self.device + if w is None: + w = self.load_weights(device=device) + + if isinstance(w, dict): + self.weight_type = w["weight_type"] + self.e_score_correction_bias_type = w["e_score_correction_bias_type"] + self.orig_module.weight = nn.Parameter(w["weight"]) + self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) + else: + raise ValueError("Invalid weight type") + self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device).to(torch.float32)) + self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device).to(torch.float32)) + + def forward(self, hidden_states) -> torch.Tensor: + h = hidden_states.shape[-1] + # compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states.type(torch.float32), self.weight, None) + topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k( + logits, + k=self.top_k, + bias=self.e_score_correction_bias, + k_group=self.topk_group, + group_count=self.n_group, + group_select_mode=1, + renorm=0, + norm_type=1, + routed_scaling_factor=self.routed_scaling_factor, + eps=float(1e-20)) + return topk_idx.type(torch.int64), topk_weight \ No newline at end of file diff --git a/ktransformers/operators/ascend/ascend_layernorm.py b/ktransformers/operators/ascend/ascend_layernorm.py new file mode 100644 index 0000000..465ec0b --- /dev/null +++ b/ktransformers/operators/ascend/ascend_layernorm.py @@ -0,0 +1,38 @@ +import torch +import torch_npu +from torch import nn +from transformers import PretrainedConfig + +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util import utils +from ktransformers.util.custom_gguf import GGUFLoader + + +class KDeepseekV3RMSNormW8A8(BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "npu", + generate_device: str = "npu", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.weight = nn.Parameter(torch.ones(self.orig_module.hidden_size)) + self.bias = nn.Parameter(torch.ones(self.orig_module.hidden_size)) + self.variance_epsilon = self.orig_module.variance_epsilon + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + out = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + self.bias + return out.to(input_dtype) + + def load(self): + self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(utils.CUR_DEVICE) + self.bias = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".bias").to(utils.CUR_DEVICE) + + def unload(self): + if self.weight is not None: + self.weight = None + if self.bias is not None: + self.bias = None \ No newline at end of file diff --git a/ktransformers/operators/ascend/ascend_linear.py 
b/ktransformers/operators/ascend/ascend_linear.py new file mode 100644 index 0000000..1b8e959 --- /dev/null +++ b/ktransformers/operators/ascend/ascend_linear.py @@ -0,0 +1,298 @@ +from abc import abstractmethod + +import torch +import torch_npu +import torch.distributed as dist +from torch import nn +from transformers import PretrainedConfig + +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.operators.linear import KLinearBase, LINEAR_MAP +from ktransformers.util.ascend.ascend_utils import ( + get_safetensors_cut_weight, + get_tensor_parallel_size, + get_tensor_parallel_group +) +from ktransformers.util import utils +from ktransformers.util.custom_gguf import GGUFLoader +from ktransformers.util.utils import InferenceState + + +class KLinearW8A8(KLinearBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cuda", + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + + def load_weight(self, override_key: str | None = None, device: str | None = None): + if override_key is not None: + keys = override_key + else: + keys = [self.key] + fake_tensor = torch.tensor([1]) + for key in keys: + if device is None: + device = utils.CUR_DEVICE + if key + ".weight" in self.gguf_loader.safetensor_loader.tensor_file_map: + if key + ".deq_scale" in self.gguf_loader.safetensor_loader.tensor_file_map: + qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight") + deq_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.deq_scale") + quant_bias = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.quant_bias") + input_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_scale") + input_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_offset") + tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset) + return tensors + elif key + ".weight_scale" in self.gguf_loader.safetensor_loader.tensor_file_map: + if key.endswith("ffn_gate_shexp"): + parts = key.split(".") + layer = parts[1] + gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight") + gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t() + up_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight") + up_weight = get_safetensors_cut_weight(self.key, up_weight).t() + gate_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_scale") + gate_scale = get_safetensors_cut_weight(self.key, gate_scale) + up_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_scale") + up_scale = get_safetensors_cut_weight(self.key, up_scale) + gate_up_weight = torch.cat((gate_weight, up_weight), 1) + gate_up_scale = torch.cat((gate_scale, up_scale), 0) + gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_offset") + up_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_offset") + gate_up_offset = torch.cat((gate_offset, up_offset), 0) + tensors = (gate_up_weight, gate_up_scale, gate_up_offset) + elif key.endswith("ffn_up_shexp"): + return fake_tensor + else: + qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight") + weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_scale") + weight_offset = 
self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_offset") + tensors = (qweight, weight_scale, weight_offset) + return tensors + else: + weight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight") + weight = get_safetensors_cut_weight(self.key, weight) + return weight + else: + raise FileNotFoundError(f"Weight file not found for key {key}") + + @abstractmethod + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = "cuda"): + pass + + @abstractmethod + def unload(self): + pass + + +class KLinearTorchW8A8A2(KLinearW8A8): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cuda", + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.dtype = torch.get_default_dtype() + self.weight = None + self.input_scale = None + self.input_offset = None + self.quant_bias = None + self.deq_scale = None + self.weight_scale = None + self.weight_offset = None + + def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor: + tp = get_tensor_parallel_size() + if tp == 1: + out = torch.zeros((x.shape[0], x.shape[1], self.weight.shape[-1]), dtype=torch.float16, device=x.device) + torch_npu._npu_matmul_pp(x, self.weight, out) + else: + tp_size = get_tensor_parallel_size() + tp_group = get_tensor_parallel_group() + batch_size = x.shape[0] + seq_length = x.shape[1] + lm_sep_size = tp_size + lm_head_group = tp_group + gathered_list = [torch.empty_like(x) for _ in range(lm_sep_size)] + dist.all_gather(gathered_list, x, group=lm_head_group) + input_full = torch.stack(gathered_list, dim=0) + input_full = input_full.squeeze(dim=1) + torch_npu.npu_format_cast_(input_full, 2) + local_logits = torch.zeros((input_full.shape[0], input_full.shape[1], self.weight.shape[-1]), + dtype=torch.float16, device=input_full.device) + torch_npu._npu_matmul_pp(input_full, self.weight, local_logits) + local_logits_transpose = local_logits.transpose(2, 1).reshape(-1, batch_size * seq_length) + del local_logits + output_tensor = torch.empty_like(local_logits_transpose) + dist.all_to_all_single(output_tensor, local_logits_transpose, group=lm_head_group) + del local_logits_transpose + output_tensor = output_tensor.transpose(1, 0) + out = output_tensor.view(batch_size, seq_length, -1) + del output_tensor + return out + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None): + if device is None: + device = self.device + device = utils.CUR_DEVICE + if w is None: + w = self.load_weight() + if isinstance(w, nn.Parameter): + try: + self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T.contiguous() + except: + self.weight = w.to(dtype=self.dtype).T.contiguous() + self.weight = self.weight.to(device) + if self.has_bias: + self.bias = self.bias.to(device) + elif isinstance(w, tuple): + w_list = list(w) + if len(w_list) == 3: + self.weight = w_list[0] + self.weight_scale = w_list[1].view(-1) + self.weight_offset = w_list[2] + self.weight = self.weight.to(utils.CUR_DEVICE) + self.weight_scale = self.weight_scale.to(utils.CUR_DEVICE) + if self.key.endswith("ffn_gate_shexp") is not True: + self.weight = get_safetensors_cut_weight(self.key, self.weight).t() + weight_scale = get_safetensors_cut_weight(self.key, self.weight_scale) + self.weight_scale = weight_scale.clone() + del weight_scale + self.weight_offset = self.weight_offset.to(utils.CUR_DEVICE) + else: + for i in 
range(len(w_list)): + w_list[i] = get_safetensors_cut_weight(self.key, w_list[i]) + w_list[i] = w_list[i].to(utils.CUR_DEVICE) + self.weight = w_list[0] + self.deq_scale = w_list[1] + self.quant_bias = w_list[2] + if "attn_output" in self.key or "ffn_down" in self.key: + if torch.distributed.get_rank(get_tensor_parallel_group()) != 0: + self.quant_bias = torch.zeros_like(self.quant_bias, dtype=self.quant_bias.dtype, device=self.quant_bias.device) + self.input_scale = w_list[3] + self.input_offset = w_list[4] + elif isinstance(w, torch.Tensor): + self.weight = w.T.contiguous() + self.weight.to(device) + if "kv_b" not in self.key: + self.weight = self.weight.to(device) + torch_npu.npu_format_cast_(self.weight, 29) + else: + raise ValueError(f"Invalid weight type {self.key=} {type(w)=}") + + def unload(self): + if self.weight is not None: + self.weight = None + if self.has_bias: + self.bias = None + self.input_scale = None + self.input_offset = None + self.quant_bias = None + self.deq_scale = None + self.weight_scale = None + self.weight_offset = None + + +LINEAR_MAP["KLinearTorchW8A8A2"] = KLinearTorchW8A8A2 + + +class KTransformersLinearW8A8A2(BaseInjectedModule, KLinearW8A8): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + generate_device: str = "cuda", + generate_op: str | None = "KLinearMarlin", + prefill_device: str = "cuda", + prefill_op: str | None = "KLinearTorch", + **kwargs, + ): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) + KLinearW8A8.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + # build all the linear operators + if prefill_op is not None: + assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" + self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) + else: + self.prefill_linear = None + + if generate_op is not None: + assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported" + self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) + else: + self.generate_linear = None + self.mode = InferenceState.UNLOAD + + def forward(self, x, bsz_tensor=None): + if self.mode == InferenceState.PREFILL: + assert self.prefill_linear is not None, "cpu linear is not initialized" + y = self.prefill_linear.forward(x, bsz_tensor) + else: + assert self.generate_linear is not None, "gpu linear is not initialized" + y = self.generate_linear.forward(x, bsz_tensor) + return y + + def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE): + if not mode: + mode = InferenceState.GENERATE + # load to device + if mode == InferenceState.PREFILL: + self.generate_linear.unload() + self.prefill_linear.load(w=w) + self.device = self.prefill_linear.device + self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight + self.input_scale = self.prefill_linear.input_scale + self.input_offset = self.prefill_linear.input_offset + self.quant_bias = self.prefill_linear.quant_bias + self.deq_scale = self.prefill_linear.deq_scale + self.weight_scale = self.prefill_linear.weight_scale + self.weight_offset = self.prefill_linear.weight_offset + elif mode == InferenceState.GENERATE: + self.prefill_linear.unload() + self.generate_linear.load(w=w) + self.device = self.generate_linear.device + self.weight = 
self.generate_linear.weight # modeling_xxx.py may use linear.weight + self.input_scale = self.generate_linear.input_scale + self.input_offset = self.generate_linear.input_offset + self.quant_bias = self.generate_linear.quant_bias + self.deq_scale = self.generate_linear.deq_scale + self.weight_scale = self.generate_linear.weight_scale + self.weight_offset = self.generate_linear.weight_offset + elif mode == InferenceState.UNLOAD: + self.prefill_linear.unload() + self.generate_linear.unload() + self.device = "cpu" + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + self.mode = mode + + def unload(self): + if self.prefill_linear is not None: + self.prefill_linear.unload() + if self.generate_linear is not None: + self.generate_linear.unload() + self.device = self.generate_linear.device + + def set_inference_mode(self, mode: InferenceState): + if not mode: + mode = InferenceState.GENERATE + if mode == InferenceState.GENERATE: + self.load(mode=InferenceState.GENERATE) + elif mode == InferenceState.PREFILL: + self.load(mode=InferenceState.PREFILL) + elif mode == InferenceState.UNLOAD: + self.unload() + else: + raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") \ No newline at end of file diff --git a/ktransformers/operators/ascend/ascend_mlp.py b/ktransformers/operators/ascend/ascend_mlp.py new file mode 100644 index 0000000..afe4935 --- /dev/null +++ b/ktransformers/operators/ascend/ascend_mlp.py @@ -0,0 +1,72 @@ +import torch +import torch_npu + +from ktransformers.util.ascend.ascend_utils import allreduce_wrapper +from ktransformers.util.utils import CUR_DEVICE +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP + + +class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP): + @allreduce_wrapper + def forward(self, x, is_prefill=None, use_cuda_graph=False): + original_dtype = x.dtype + quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) + dynamic_scale = dynamic_scale.view(-1) + gate_x = torch_npu.npu_quant_matmul( + quant_out, + self.orig_module.gate_proj.weight, + self.orig_module.gate_proj.weight_scale, + pertoken_scale=dynamic_scale, + bias=None, + output_dtype=original_dtype, + ) + up_x = torch_npu.npu_quant_matmul( + quant_out, + self.orig_module.up_proj.weight, + self.orig_module.up_proj.weight_scale, + pertoken_scale=dynamic_scale, + bias=None, + output_dtype=original_dtype, + ) + down_x = self.act_fn(gate_x) * up_x + down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x) + down_dynamic_scale = down_dynamic_scale.view(-1) + down_proj = torch_npu.npu_quant_matmul( + down_quant_out, + self.orig_module.down_proj.weight, + self.orig_module.down_proj.weight_scale, + pertoken_scale=down_dynamic_scale, + bias=None, + output_dtype=original_dtype, + ) + return down_proj + + +class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP): + @allreduce_wrapper + def forward(self, x, is_prefill=None, use_cuda_graph=False): + original_dtype = x.dtype + quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) + dynamic_scale = dynamic_scale.view(-1) + gate_up_x = torch_npu.npu_quant_matmul( + quant_out, + self.orig_module.gate_proj.weight, + self.orig_module.gate_proj.weight_scale, + pertoken_scale=dynamic_scale, + bias=None, + output_dtype=original_dtype, + ) + down_x = torch_npu.npu_swiglu(gate_up_x, -1) + + down_quant_out, down_dynamic_scale = 
torch_npu.npu_dynamic_quant(down_x) + down_dynamic_scale = down_dynamic_scale.view(-1) + down_proj = torch_npu.npu_quant_matmul( + down_quant_out, + self.orig_module.down_proj.weight, + self.orig_module.down_proj.weight_scale, + pertoken_scale=down_dynamic_scale, + bias=None, + output_dtype=original_dtype, + ) + return down_proj \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml new file mode 100644 index 0000000..a05551c --- /dev/null +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-800IA2-npu.yaml @@ -0,0 +1,114 @@ +- match: + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding + replace: + class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 + kwargs: + generate_device: "npu" + prefill_device: "npu" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types + kwargs: + generate_device: "npu" + prefill_device: "npu" + generate_op: "KLinearTorchW8A8A2" + prefill_op: "KLinearTorchW8A8A2" + +- match: + name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types + kwargs: + generate_device: "npu" + prefill_device: "npu" + generate_op: "KLinearTorchW8A8A2" + prefill_op: "KLinearTorchW8A8A2" + +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE + replace: + class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8 # mlp module with custom forward function + kwargs: + generate_device: "npu" + prefill_device: "npu" + +- match: + name: "^model\\.layers\\.([0-2])\\.mlp$" + class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP" + replace: + class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1" + kwargs: + generate_device: "npu" + prefill_device: "npu" + +- match: + name: "^model\\.layers\\..*\\.mlp\\.shared_experts$" + class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP" + replace: + class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2" + kwargs: + generate_device: "npu" + prefill_device: "npu" + +- match: + class: ktransformers.models.modeling_deepseek_v3.MoEGate + replace: + class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2 + kwargs: + generate_device: "npu:0" + prefill_device: "npu:0" + +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8 + kwargs: + prefill_device: "npu" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPUW8A8" + out_device: "npu" + recursive: False # don't recursively inject submodules of this module + +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + class: ktransformers.operators.experts.KExpertsCPU + replace: + class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8 + +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2 # optimized MLA implementation + kwargs: + 
generate_device: "npu"
+      prefill_device: "npu"
+      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
+
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
+
+- match:
+    name: "^model..*norm"
+  replace:
+    class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8
+    kwargs:
+      generate_device: "npu"
+      prefill_device: "npu"
\ No newline at end of file
diff --git a/setup.py b/setup.py
index c91d9dc..84663b9 100644
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,13 @@ except ImportError:
     MUSA_HOME=None
 
 KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()
+
+try:
+    import torch_npu
+    KTRANSFORMERS_BUILD_NPU = torch_npu.npu.is_available()
+except (ModuleNotFoundError, ImportError):
+    KTRANSFORMERS_BUILD_NPU = False
+
 # 检测 DEV_BACKEND 环境变量
 dev_backend = os.environ.get("DEV_BACKEND", "").lower()
 if dev_backend == "xpu":
@@ -237,6 +244,8 @@ class VersionInfo:
             backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}"
         elif torch.xpu.is_available():
             backend_version = f"xpu"
+        elif KTRANSFORMERS_BUILD_NPU:
+            backend_version = f"npu{torch_npu.__version__}"
         else:
             raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.")
         package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
@@ -509,6 +518,8 @@ class CMakeBuild(BuildExtension):
             cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"]
         elif KTRANSFORMERS_BUILD_XPU:
             cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
+        elif KTRANSFORMERS_BUILD_NPU:
+            cmake_args += ["-DKTRANSFORMERS_USE_NPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"]
         else:
             raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.")
 
@@ -636,10 +647,12 @@ elif MUSA_HOME is not None:
     )
 elif torch.xpu.is_available(): #XPUExtension is not available now.
     ops_module = None
+elif KTRANSFORMERS_BUILD_NPU:
+    pass
 else:
     raise ValueError("Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.")
 
-if not torch.xpu.is_available():
+if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:
     ext_modules = [
         CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
         ops_module,
@@ -660,10 +673,20 @@ if not torch.xpu.is_available():
         ext_modules.append(
             CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
         )
-else:
+elif torch.xpu.is_available():
     ext_modules = [
         CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
     ]
+elif KTRANSFORMERS_BUILD_NPU:
+    ext_modules = [
+        CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")),
+    ]
+    if with_balance:
+        print("using balance_serve")
+        ext_modules.append(
+            CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve"))
+        )
+
 
 setup(
     name=VersionInfo.PACKAGE_NAME,
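
On the NPU branch only the CPU extension (`cpuinfer_ext`, plus `balance_serve` when enabled) is built, and the wheel is tagged with the torch_npu version. A minimal post-install smoke test along these lines can confirm that the build and its runtime dependencies import cleanly; the module names are taken from this diff, everything else is an assumption:

```python
import importlib

# torch_npu must already be installed; cpuinfer_ext and ktransformers come from this build.
for mod in ("torch", "torch_npu", "cpuinfer_ext", "ktransformers"):
    try:
        importlib.import_module(mod)
        print(f"{mod}: OK")
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        print(f"{mod}: FAILED ({exc})")
```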